Initial Commit 🚀 🚀

This commit is contained in:
Abhishek Kumar 2025-09-09 14:37:32 +05:30
commit 4f2a629340
444 changed files with 76863 additions and 0 deletions

View file

View file

@ -0,0 +1,330 @@
from typing import Annotated, Optional
import httpx
from fastapi import Header, HTTPException
from loguru import logger
from pydantic import ValidationError
from api.constants import DEPLOYMENT_MODE, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.schemas.user_configuration import UserConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import (
DograhLLMModel,
DograhSTTModel,
DograhTTSModel,
DograhVoice,
ServiceProviders,
)
async def get_user(
authorization: Annotated[str | None, Header()] = None,
) -> UserModel:
# ------------------------------------------------------------------
# Check if we're in OSS deployment mode
# ------------------------------------------------------------------
if DEPLOYMENT_MODE == "oss":
return await _handle_oss_auth(authorization)
# ------------------------------------------------------------------
# 1. Validate and fetch the authenticated Stack user
# ------------------------------------------------------------------
stack_user = await stackauth.get_user(authorization)
if stack_user is None:
raise HTTPException(status_code=401, detail="Unauthorized")
# ------------------------------------------------------------------
# 2. Ensure the user has a team (Stack "selected_team_id")
# ------------------------------------------------------------------
selected_team_id: str | None = stack_user.get("selected_team_id")
if not selected_team_id and stack_user.get("selected_team"):
selected_team_id = stack_user["selected_team"].get("id")
if not selected_team_id:
raise HTTPException(status_code=400, detail="No team selected")
# ------------------------------------------------------------------
# 3. Persist/Fetch the local User model
# ------------------------------------------------------------------
try:
user_model = await db_client.get_or_create_user_by_provider_id(stack_user["id"])
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error while creating user from database {e}"
)
# ------------------------------------------------------------------
# 4. Persist Organization (team) and mapping in local database
# ------------------------------------------------------------------
try:
(
organization,
org_was_created,
) = await db_client.get_or_create_organization_by_provider_id(
org_provider_id=selected_team_id, user_id=user_model.id
)
# Check if user's selected organization differs from the current organization
if user_model.selected_organization_id != organization.id:
await db_client.add_user_to_organization(user_model.id, organization.id)
# Update user's selected organization
await db_client.update_user_selected_organization(
user_model.id, organization.id
)
# Update the user_model object to reflect the change
user_model.selected_organization_id = organization.id
# Only create default configuration if organization was just created
# This prevents race conditions where multiple concurrent requests
# might try to create configurations
if org_was_created:
existing_cfg = await db_client.get_user_configurations(user_model.id)
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
mps_config = await create_user_configuration_with_mps_key(
user_model.id, organization.id, stack_user["id"]
)
if mps_config:
await db_client.update_user_configuration(
user_model.id, mps_config
)
except Exception as exc:
raise HTTPException(
status_code=500,
detail=f"Failed to map user to organization: {exc}",
)
return user_model
async def _handle_oss_auth(authorization: str | None) -> UserModel:
"""
Handle authentication for OSS deployment mode.
Uses the authorization token as provider_id and creates user/org if needed.
"""
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header required")
# Remove "Bearer " prefix if present
token = (
authorization.replace("Bearer ", "")
if authorization.startswith("Bearer ")
else authorization
)
if not token:
raise HTTPException(status_code=401, detail="Invalid authorization token")
try:
# Use token as provider_id for OSS mode
user_model = await db_client.get_or_create_user_by_provider_id(
provider_id=token
)
# Create or get organization for OSS user
# Each OSS user gets their own organization using their token as org ID
organization = await db_client.get_or_create_organization_by_provider_id(
provider_id=f"org_{token}"
)
# Ensure user is mapped to their organization
if user_model.selected_organization_id != organization.id:
# add_user_to_organization now handles race conditions with ON CONFLICT DO NOTHING
await db_client.add_user_to_organization(user_model.id, organization.id)
await db_client.update_user_selected_organization(
user_model.id, organization.id
)
user_model.selected_organization_id = organization.id
return user_model
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error while handling OSS authentication: {e}"
)
async def get_user_optional(
authorization: Annotated[str | None, Header()] = None,
) -> UserModel | None:
"""
Same as get_user but returns None instead of raising 401 if unauthorized.
Useful for endpoints that need to work both with and without auth.
"""
try:
return await get_user(authorization)
except HTTPException as e:
if e.status_code == 401:
return None
raise
async def _handle_oss_auth(authorization: str | None) -> UserModel:
"""
Handle authentication for OSS deployment mode.
Uses the authorization token as provider_id and creates user/org if needed.
"""
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header required")
# Remove "Bearer " prefix if present
token = (
authorization.replace("Bearer ", "")
if authorization.startswith("Bearer ")
else authorization
)
if not token:
raise HTTPException(status_code=401, detail="Invalid authorization token")
try:
# Use token as provider_id for OSS mode
user_model = await db_client.get_or_create_user_by_provider_id(
provider_id=token
)
# Create or get organization for OSS user
# Each OSS user gets their own organization using their token as org ID
(
organization,
org_was_created,
) = await db_client.get_or_create_organization_by_provider_id(
org_provider_id=f"org_{token}", user_id=user_model.id
)
# Ensure user is mapped to their organization
if user_model.selected_organization_id != organization.id:
# add_user_to_organization now handles race conditions with ON CONFLICT DO NOTHING
await db_client.add_user_to_organization(user_model.id, organization.id)
await db_client.update_user_selected_organization(
user_model.id, organization.id
)
user_model.selected_organization_id = organization.id
# Only create default configuration if organization was just created
# This prevents race conditions where multiple concurrent requests
# might try to create configurations
if org_was_created:
existing_cfg = await db_client.get_user_configurations(user_model.id)
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
mps_config = await create_user_configuration_with_mps_key(
user_model.id, organization.id, token
)
if mps_config:
await db_client.update_user_configuration(
user_model.id, mps_config
)
return user_model
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error while handling OSS authentication: {e}"
)
async def create_user_configuration_with_mps_key(
user_id: int, organization_id: int, user_provider_id: str
) -> Optional[UserConfiguration]:
"""Create user configuration using MPS service key.
Args:
user_id: The user's ID
organization_id: The organization's ID
user_provider_id: The user's provider ID (for created_by field)
Returns:
UserConfiguration with MPS-provided API keys or None if failed
"""
async with httpx.AsyncClient() as client:
# Use MPS API URL from constants
if DEPLOYMENT_MODE == "oss":
# For OSS mode, create a temporary service key without authentication
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"description": "Auto-generated key for OSS user",
"expires_in_days": 7, # Short-lived for OSS
"created_by": user_provider_id,
},
timeout=10.0,
)
else:
# For authenticated mode, use the secret key and organization ID
if not DOGRAH_MPS_SECRET_KEY:
logger.warning(
"Warning: DOGRAH_MPS_SECRET_KEY not set for authenticated mode"
)
raise ValidationError("Missing DOGRAH_MPS_SECRET_KEY in non oss mode")
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"description": f"Auto-generated key for organization {organization_id}",
"organization_id": organization_id,
"expires_in_days": 90, # Longer-lived for authenticated users
"created_by": user_provider_id,
},
headers={"X-Secret-Key": DOGRAH_MPS_SECRET_KEY},
timeout=10.0,
)
if response.status_code == 200:
data = response.json()
service_key = data.get("service_key")
if service_key:
# Create configuration JSON for storage in database
# The service_factory will use this to instantiate actual services
configuration = {
"llm": {
"provider": ServiceProviders.DOGRAH.value,
"api_key": service_key,
"model": DograhLLMModel.DEFAULT.value, # Default model
},
"tts": {
"provider": ServiceProviders.DOGRAH.value,
"api_key": service_key,
"model": DograhTTSModel.DEFAULT.value, # Default model
"voice": DograhVoice.DEFAULT.value, # Default voice
},
"stt": {
"provider": ServiceProviders.DOGRAH.value,
"api_key": service_key,
"model": DograhSTTModel.DEFAULT.value, # Default model
},
}
user_config = UserConfiguration(**configuration)
return user_config
else:
logger.warning(
f"Failed to get MPS service key: {response.status_code} - {response.text}"
)
async def get_superuser(
authorization: Annotated[str | None, Header()] = None,
) -> UserModel:
"""
Dependency to check if the authenticated user is a superuser.
Raises HTTPException if user is not authenticated or not a superuser.
"""
user = await get_user(authorization)
if not user.is_superuser:
raise HTTPException(
status_code=403, detail="Access denied. Superuser privileges required."
)
return user

View file

@ -0,0 +1,122 @@
import os
import aiohttp
class StackAuth:
def __init__(self):
self.project_id = os.environ.get("STACK_AUTH_PROJECT_ID")
self.secret_server_key = os.environ.get("STACK_SECRET_SERVER_KEY")
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _strip_bearer(self, access_token: str | None) -> str | None:
"""Remove the leading "Bearer " prefix from the token if present."""
if not access_token:
return None
if access_token.startswith("Bearer "):
return access_token.split(" ", 1)[1]
return access_token
async def get_user(self, access_token: str):
if not access_token:
return None
access_token = self._strip_bearer(access_token)
url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/users/me"
headers = {
"x-stack-access-type": "server",
"x-stack-project-id": self.project_id,
"x-stack-secret-server-key": self.secret_server_key,
"x-stack-access-token": access_token,
}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
response = await response.json()
if "id" in response:
return response
else:
return None
async def impersonate(self, stack_user_id: str):
url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/auth/sessions"
headers = {
"x-stack-access-type": "server",
"x-stack-project-id": self.project_id,
"x-stack-secret-server-key": self.secret_server_key,
}
data = {
"user_id": stack_user_id,
"expires_in_millis": 3600000,
"is_impersonation": True,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
response = await response.json()
return response
# ------------------------------------------------------------------
# Team & user management helpers
# ------------------------------------------------------------------
# async def create_team(
# self,
# access_token: str,
# display_name: str,
# profile_image_url: str | None = None,
# client_metadata: dict | None = None,
# ) -> dict:
# """Create a new team for the authenticated user and return the API response."""
# token = self._strip_bearer(access_token)
# if token is None:
# raise ValueError("Access token required to create team")
# url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/teams"
# headers = {
# "x-stack-access-type": "server",
# "x-stack-project-id": self.project_id,
# "x-stack-secret-server-key": self.secret_server_key,
# "x-stack-access-token": token,
# "Content-Type": "application/json",
# }
# payload: dict = {
# "display_name": display_name,
# "creator_user_id": "me",
# }
# if profile_image_url is not None:
# payload["profile_image_url"] = profile_image_url
# if client_metadata is not None:
# payload["client_metadata"] = client_metadata
# async with aiohttp.ClientSession() as session:
# async with session.post(url, headers=headers, json=payload) as response:
# return await response.json()
# async def update_user(self, access_token: str, data: dict) -> dict:
# """Patch the current user with supplied data and return the API response."""
# token = self._strip_bearer(access_token)
# if token is None:
# raise ValueError("Access token required to update user")
# url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/users/me"
# headers = {
# "x-stack-access-type": "server",
# "x-stack-project-id": self.project_id,
# "x-stack-secret-server-key": self.secret_server_key,
# "x-stack-access-token": token,
# "Content-Type": "application/json",
# }
# async with aiohttp.ClientSession() as session:
# async with session.patch(url, headers=headers, json=data) as response:
# return await response.json()
stackauth = StackAuth()

View file

@ -0,0 +1,5 @@
"""Campaign service package"""
from .rate_limiter import rate_limiter
__all__ = ["rate_limiter"]

View file

@ -0,0 +1,329 @@
import asyncio
import time
from datetime import UTC, datetime
from typing import Optional
from loguru import logger
from api.db import db_client
from api.db.models import QueuedRunModel, WorkflowRunModel
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
from api.services.campaign.rate_limiter import rate_limiter
from api.services.telephony.twilio import TwilioService
class CampaignCallDispatcher:
"""Manages rate-limited and concurrent-limited call dispatching"""
def __init__(self):
self._twilio_service = None
self.default_concurrent_limit = 20
@property
def twilio_service(self):
"""Lazy initialization of TwilioService"""
if self._twilio_service is None:
self._twilio_service = TwilioService()
return self._twilio_service
async def get_org_concurrent_limit(self, organization_id: int) -> int:
"""Get the concurrent call limit for an organization."""
try:
config = await db_client.get_configuration(
organization_id,
OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value,
)
if config and config.value:
return int(config.value["value"])
except Exception as e:
logger.warning(
f"Error getting concurrent limit for org {organization_id}: {e}"
)
return self.default_concurrent_limit
async def process_batch(self, campaign_id: int, batch_size: int = 10) -> int:
"""
Processes a batch of queued runs with priority for scheduled retries
Returns: number of processed runs
"""
# Get campaign details
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
# Check if campaign is in running state
if campaign.state != "running":
logger.info(
f"Campaign {campaign_id} is not in running state: {campaign.state}"
)
return 0
# First, get any scheduled retries that are due
scheduled_runs = await db_client.get_scheduled_queued_runs(
campaign_id=campaign_id,
scheduled_before=datetime.now(UTC),
limit=batch_size,
)
remaining_slots = batch_size - len(scheduled_runs)
# Then get regular queued runs
regular_runs = []
if remaining_slots > 0:
regular_runs = await db_client.get_queued_runs(
campaign_id=campaign_id,
state="queued",
scheduled_for=False, # Exclude scheduled runs
limit=remaining_slots,
)
queued_runs = scheduled_runs + regular_runs
if not queued_runs:
logger.info(f"No more queued runs for campaign {campaign_id}")
return 0
processed_count = 0
for queued_run in queued_runs:
try:
# Apply rate limiting
await self.apply_rate_limit(
campaign.organization_id, campaign.rate_limit_per_second
)
# Dispatch the call
workflow_run = await self.dispatch_call(queued_run, campaign)
# Update queued run as processed
await db_client.update_queued_run(
queued_run_id=queued_run.id,
state="processed",
workflow_run_id=workflow_run.id,
processed_at=datetime.now(UTC),
)
processed_count += 1
# Update campaign processed count
await db_client.update_campaign(
campaign_id=campaign_id, processed_rows=campaign.processed_rows + 1
)
except Exception as e:
logger.warning(f"Error processing queued run {queued_run.id}: {e}")
# Mark the queued run as failed to prevent infinite retry loops
try:
await db_client.update_queued_run(
queued_run_id=queued_run.id,
state="failed",
processed_at=datetime.now(UTC),
)
logger.info(
f"Marked queued run {queued_run.id} as failed due to error: {e}"
)
except Exception as update_error:
logger.error(
f"Failed to mark queued run {queued_run.id} as failed: {update_error}"
)
return processed_count
async def dispatch_call(
self, queued_run: QueuedRunModel, campaign: any
) -> Optional[WorkflowRunModel]:
"""Creates workflow run and initiates call with concurrent limiting"""
# Get concurrent limit for organization
max_concurrent = await self.get_org_concurrent_limit(campaign.organization_id)
# Track wait time for alerting
wait_start = time.time()
slot_id = None
# Wait until we can acquire a concurrent slot
while True:
slot_id = await rate_limiter.try_acquire_concurrent_slot(
campaign.organization_id, max_concurrent
)
if slot_id:
break
# Check if we've been waiting too long
wait_time = time.time() - wait_start
if wait_time > 600: # 10 minutes
logger.error(
f"Waiting for concurrent slot for {wait_time:.1f}s, "
f"org: {campaign.organization_id}, campaign: {campaign.id}"
)
logger.debug(
f"Attempting to get a slot for {campaign.organization_id} {campaign.id}"
)
# Wait before retrying
await asyncio.sleep(1)
# Get workflow details
workflow = await db_client.get_workflow_by_id(campaign.workflow_id)
if not workflow:
# Release slot before raising
await rate_limiter.release_concurrent_slot(
campaign.organization_id, slot_id
)
raise ValueError(f"Workflow {campaign.workflow_id} not found")
# Merge context variables (queued_run context already includes retry info if applicable)
initial_context = {
**workflow.template_context_variables,
**queued_run.context_variables,
"campaign_id": campaign.id,
}
# Extract phone number
phone_number = queued_run.context_variables.get("phone_number")
if not phone_number:
# Release slot before raising
await rate_limiter.release_concurrent_slot(
campaign.organization_id, slot_id
)
raise ValueError(f"No phone number in queued run {queued_run.id}")
# Create workflow run with queued_run_id tracking
workflow_run_name = f"WR-CAMPAIGN-{campaign.id}-{queued_run.id}"
try:
workflow_run = await db_client.create_workflow_run(
name=workflow_run_name,
workflow_id=campaign.workflow_id,
mode=WorkflowRunMode.TWILIO.value,
user_id=campaign.created_by,
initial_context=initial_context,
campaign_id=campaign.id,
queued_run_id=queued_run.id, # Link to queued run for retry tracking
)
# Store slot_id mapping in Redis for cleanup later
await rate_limiter.store_workflow_slot_mapping(
workflow_run.id, campaign.organization_id, slot_id
)
except Exception as e:
# Release slot on error
await rate_limiter.release_concurrent_slot(
campaign.organization_id, slot_id
)
raise
# Add "retry" tag if this is a retry call
if queued_run.context_variables.get("is_retry"):
retry_reason = queued_run.context_variables.get("retry_reason", "unknown")
await db_client.update_workflow_run(
run_id=workflow_run.id,
gathered_context={
"call_tags": ["retry", f"retry_reason_{retry_reason}"]
},
)
# Initiate call via Twilio
try:
call_result = await self.twilio_service.initiate_call(
to_number=phone_number,
workflow_run_id=workflow_run.id,
organization_id=campaign.organization_id,
url_args={
"workflow_id": campaign.workflow_id,
"user_id": campaign.created_by,
"workflow_run_id": workflow_run.id,
"campaign_id": campaign.id,
},
)
logger.info(
f"Call initiated for workflow run {workflow_run.id}, SID: {call_result.get('sid')}"
)
except Exception as e:
logger.error(
f"Failed to initiate call for workflow run {workflow_run.id}: {e}"
)
# Update workflow run as failed
twilio_callback_logs = workflow_run.logs.get("twilio_status_callbacks", [])
twilio_callback_log = {
"status": "failed",
"timestamp": datetime.now(UTC).isoformat(),
"data": {"error": str(e)},
}
twilio_callback_logs.append(twilio_callback_log)
await db_client.update_workflow_run(
run_id=workflow_run.id,
is_completed=True,
gathered_context={
"error": str(e),
},
logs={
"twilio_status_callbacks": twilio_callback_logs,
},
)
# Release concurrent slot on failure
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
if mapping:
org_id, slot_id = mapping
await rate_limiter.release_concurrent_slot(org_id, slot_id)
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
raise
return workflow_run
async def apply_rate_limit(self, organization_id: int, rate_limit: int) -> None:
"""
Enforces rate limiting - waits if necessary to comply with rate limit
Example usage:
```
# This will wait up to 1 second if needed to respect rate limit
await self.apply_rate_limit(org_id, 1) # 1 call per second
await twilio.initiate_call(...) # Now safe to call
```
"""
max_wait = 1.0 # Maximum time to wait for a slot
start_time = time.time()
while True:
# Try to acquire token
if await rate_limiter.acquire_token(organization_id, rate_limit):
return # Got permission to proceed
# Check how long to wait
wait_time = await rate_limiter.get_next_available_slot(
organization_id, rate_limit
)
# Don't wait forever
if time.time() - start_time + wait_time > max_wait:
raise TimeoutError("Rate limit timeout - try again later")
# Wait for next available slot
await asyncio.sleep(wait_time)
async def release_call_slot(self, workflow_run_id: int) -> bool:
"""
Release concurrent slot when a call completes.
Called by Twilio webhooks or workflow completion handlers.
"""
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id)
if mapping:
org_id, slot_id = mapping
success = await rate_limiter.release_concurrent_slot(org_id, slot_id)
if success:
await rate_limiter.delete_workflow_slot_mapping(workflow_run_id)
logger.info(
f"Released concurrent slot for workflow run {workflow_run_id}"
)
return success
return False
# Global instance
campaign_call_dispatcher = CampaignCallDispatcher()

View file

@ -0,0 +1,258 @@
"""Campaign event protocol for orchestrator communication.
Defines message formats and helpers for campaign event publishing and handling.
"""
import json
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, Optional
class CampaignEventType(str, Enum):
"""Types of campaign events."""
# Batch processing events
BATCH_COMPLETED = "batch_completed"
BATCH_FAILED = "batch_failed"
# Sync events
SYNC_STARTED = "sync_started"
SYNC_COMPLETED = "sync_completed"
SYNC_FAILED = "sync_failed"
# Campaign lifecycle events
CAMPAIGN_STARTED = "campaign_started"
CAMPAIGN_PAUSED = "campaign_paused"
CAMPAIGN_RESUMED = "campaign_resumed"
CAMPAIGN_COMPLETED = "campaign_completed"
CAMPAIGN_FAILED = "campaign_failed"
# Retry events
RETRY_NEEDED = "retry_needed"
RETRY_SCHEDULED = "retry_scheduled"
RETRY_FAILED = "retry_failed"
class RetryReason(str, Enum):
"""Reasons for retry."""
BUSY = "busy"
NO_ANSWER = "no_answer"
VOICEMAIL = "voicemail"
FAILED = "failed"
ERROR = "error"
@dataclass
class BaseCampaignEvent:
"""Base class for all campaign events."""
type: str
campaign_id: int = 0
timestamp: Optional[str] = None
def __post_init__(self):
if self.timestamp is None:
from datetime import UTC, datetime
self.timestamp = datetime.now(UTC).isoformat()
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, data: str):
return cls(**json.loads(data))
@dataclass
class BatchCompletedEvent(BaseCampaignEvent):
"""Event sent when a batch processing completes."""
type: str = CampaignEventType.BATCH_COMPLETED
processed_count: int = 0
failed_count: int = 0
batch_size: int = 0
metadata: Optional[Dict[str, Any]] = None
def __post_init__(self):
super().__post_init__()
if self.metadata is None:
self.metadata = {}
@dataclass
class BatchFailedEvent(BaseCampaignEvent):
"""Event sent when a batch processing fails."""
type: str = CampaignEventType.BATCH_FAILED
error: str = ""
processed_count: int = 0
metadata: Optional[Dict[str, Any]] = None
def __post_init__(self):
super().__post_init__()
if self.metadata is None:
self.metadata = {}
@dataclass
class SyncStartedEvent(BaseCampaignEvent):
"""Event sent when campaign source sync starts."""
type: str = CampaignEventType.SYNC_STARTED
source_type: str = ""
source_id: str = ""
@dataclass
class SyncCompletedEvent(BaseCampaignEvent):
"""Event sent when campaign source sync completes."""
type: str = CampaignEventType.SYNC_COMPLETED
total_rows: int = 0
source_type: str = ""
source_id: str = ""
metadata: Optional[Dict[str, Any]] = None
def __post_init__(self):
super().__post_init__()
if self.metadata is None:
self.metadata = {}
@dataclass
class SyncFailedEvent(BaseCampaignEvent):
"""Event sent when campaign source sync fails."""
type: str = CampaignEventType.SYNC_FAILED
error: str = ""
source_type: str = ""
source_id: str = ""
@dataclass
class CampaignStartedEvent(BaseCampaignEvent):
"""Event sent when a campaign starts."""
type: str = CampaignEventType.CAMPAIGN_STARTED
workflow_id: int = 0
total_rows: Optional[int] = None
@dataclass
class CampaignPausedEvent(BaseCampaignEvent):
"""Event sent when a campaign is paused."""
type: str = CampaignEventType.CAMPAIGN_PAUSED
processed_rows: int = 0
failed_rows: int = 0
@dataclass
class CampaignResumedEvent(BaseCampaignEvent):
"""Event sent when a campaign is resumed."""
type: str = CampaignEventType.CAMPAIGN_RESUMED
processed_rows: int = 0
failed_rows: int = 0
@dataclass
class CampaignCompletedEvent(BaseCampaignEvent):
"""Event sent when a campaign completes."""
type: str = CampaignEventType.CAMPAIGN_COMPLETED
total_rows: int = 0
processed_rows: int = 0
failed_rows: int = 0
duration_seconds: Optional[float] = None
@dataclass
class CampaignFailedEvent(BaseCampaignEvent):
"""Event sent when a campaign fails."""
type: str = CampaignEventType.CAMPAIGN_FAILED
error: str = ""
processed_rows: int = 0
failed_rows: int = 0
@dataclass
class RetryNeededEvent(BaseCampaignEvent):
"""Event sent when a call needs retry."""
type: str = CampaignEventType.RETRY_NEEDED
workflow_run_id: int = 0
queued_run_id: int = 0
reason: str = "" # RetryReason value
metadata: Optional[Dict[str, Any]] = None
def __post_init__(self):
super().__post_init__()
if self.metadata is None:
self.metadata = {}
@dataclass
class RetryScheduledEvent(BaseCampaignEvent):
"""Event sent when a retry is scheduled."""
type: str = CampaignEventType.RETRY_SCHEDULED
queued_run_id: int = 0
retry_run_id: int = 0
retry_count: int = 0
scheduled_for: str = "" # ISO timestamp
reason: str = "" # RetryReason value
@dataclass
class RetryFailedEvent(BaseCampaignEvent):
"""Event sent when max retries reached."""
type: str = CampaignEventType.RETRY_FAILED
queued_run_id: int = 0
retry_count: int = 0
last_reason: str = "" # RetryReason value
def parse_campaign_event(data: str) -> Any:
"""Parse a campaign event message."""
try:
parsed = json.loads(data)
event_type = parsed.get("type")
# Map event types to their classes
event_class_map = {
CampaignEventType.BATCH_COMPLETED: BatchCompletedEvent,
CampaignEventType.BATCH_FAILED: BatchFailedEvent,
CampaignEventType.SYNC_STARTED: SyncStartedEvent,
CampaignEventType.SYNC_COMPLETED: SyncCompletedEvent,
CampaignEventType.SYNC_FAILED: SyncFailedEvent,
CampaignEventType.CAMPAIGN_STARTED: CampaignStartedEvent,
CampaignEventType.CAMPAIGN_PAUSED: CampaignPausedEvent,
CampaignEventType.CAMPAIGN_RESUMED: CampaignResumedEvent,
CampaignEventType.CAMPAIGN_COMPLETED: CampaignCompletedEvent,
CampaignEventType.CAMPAIGN_FAILED: CampaignFailedEvent,
CampaignEventType.RETRY_NEEDED: RetryNeededEvent,
CampaignEventType.RETRY_SCHEDULED: RetryScheduledEvent,
CampaignEventType.RETRY_FAILED: RetryFailedEvent,
}
event_class = event_class_map.get(event_type)
if event_class:
return event_class(**parsed)
# Unknown event type
from loguru import logger
logger.warning(f"Unknown campaign event type: {event_type}")
return None
except Exception as e:
from loguru import logger
logger.error(f"Failed to parse campaign event: {e}, data: {data}")
return None

View file

@ -0,0 +1,121 @@
"""Campaign event publisher for orchestrator communication.
Handles publishing of campaign events to Redis pub/sub channels.
"""
from typing import Dict, Optional
import redis.asyncio as aioredis
from loguru import logger
from api.constants import REDIS_URL
from api.enums import RedisChannel
from api.services.campaign.campaign_event_protocol import (
BatchCompletedEvent,
CampaignCompletedEvent,
RetryNeededEvent,
SyncCompletedEvent,
)
class CampaignEventPublisher:
"""Helper class for publishing campaign events."""
def __init__(self, redis_client):
self.redis = redis_client
async def publish_batch_completed(
self,
campaign_id: int,
processed_count: int,
failed_count: int = 0,
batch_size: int = 0,
metadata: Optional[Dict] = None,
):
"""Publish batch completed event."""
event = BatchCompletedEvent(
campaign_id=campaign_id,
processed_count=processed_count,
failed_count=failed_count,
batch_size=batch_size,
metadata=metadata,
)
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
async def publish_sync_completed(
self,
campaign_id: int,
total_rows: int,
source_type: str = "",
source_id: str = "",
metadata: Optional[Dict] = None,
):
"""Publish sync completed event."""
event = SyncCompletedEvent(
campaign_id=campaign_id,
total_rows=total_rows,
source_type=source_type,
source_id=source_id,
metadata=metadata,
)
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
async def publish_retry_needed(
self,
workflow_run_id: int,
reason: str,
campaign_id: Optional[int] = None,
queued_run_id: Optional[int] = None,
metadata: Optional[Dict] = None,
):
"""Publish retry needed event."""
event = RetryNeededEvent(
campaign_id=campaign_id or 0,
workflow_run_id=workflow_run_id,
queued_run_id=queued_run_id or 0,
reason=reason,
metadata=metadata or {},
)
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
logger.info(
f"Published retry event for workflow_run {workflow_run_id}, "
f"reason: {reason}, campaign: {campaign_id}"
)
async def publish_campaign_completed(
self,
campaign_id: int,
total_rows: int,
processed_rows: int,
failed_rows: int,
duration_seconds: Optional[float] = None,
):
"""Publish campaign completed event."""
event = CampaignCompletedEvent(
campaign_id=campaign_id,
total_rows=total_rows,
processed_rows=processed_rows,
failed_rows=failed_rows,
duration_seconds=duration_seconds,
)
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
# Global publisher instance with lazy Redis connection
async def get_campaign_event_publisher() -> CampaignEventPublisher:
"""Get or create the campaign event publisher."""
global _campaign_publisher
global _campaign_redis_client
if "_campaign_publisher" not in globals():
_campaign_redis_client = await aioredis.from_url(
REDIS_URL, decode_responses=True
)
_campaign_publisher = CampaignEventPublisher(_campaign_redis_client)
return _campaign_publisher

View file

@ -0,0 +1,563 @@
"""Campaign Orchestrator Service.
This service ensures continuous campaign processing by listening to events
and scheduling batches immediately upon completion. It also monitors campaigns
for final completion after 1 hour of inactivity and handles retry events.
"""
from api.logging_config import setup_logging
logging_queue_listener = setup_logging()
import asyncio
import signal
from datetime import UTC, datetime, timedelta
from typing import Dict
import redis.asyncio as aioredis
from loguru import logger
from api.constants import REDIS_URL
from api.db import db_client
from api.db.models import CampaignModel, QueuedRunModel
from api.enums import RedisChannel
from api.services.campaign.campaign_event_protocol import (
CampaignCompletedEvent,
CampaignEventType,
RetryNeededEvent,
parse_campaign_event,
)
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
class CampaignOrchestrator:
"""Orchestrates campaign processing, retry handling, and completion detection."""
def __init__(self, redis_client: aioredis.Redis):
self.redis = redis_client
self.completion_check_interval = 60 # 1 minute
self.completion_timeout = 3600 # 1 hour
self._processing_locks: Dict[int, datetime] = {} # prevent duplicate scheduling
self._last_activity: Dict[
int, datetime
] = {} # track last activity per campaign
self._batch_in_progress: Dict[
int, datetime
] = {} # track batches that have been scheduled but not completed
self._running = False
self._pubsub = None
async def run(self):
"""Main service with two concurrent tasks."""
self._running = True
logger.info("Campaign Orchestrator starting...")
try:
# Task 1: Listen for events and react immediately
event_task = asyncio.create_task(self._listen_for_events())
# Task 2: Periodically check for stale campaigns
completion_task = asyncio.create_task(self._monitor_completion())
# Wait for both tasks
await asyncio.gather(event_task, completion_task)
except asyncio.CancelledError:
logger.info("Campaign Orchestrator cancelled")
raise
except Exception as e:
logger.error(f"Campaign Orchestrator error: {e}")
raise
finally:
await self.shutdown()
async def _listen_for_events(self):
"""Listen for campaign events and react immediately."""
self._pubsub = self.redis.pubsub()
await self._pubsub.subscribe(RedisChannel.CAMPAIGN_EVENTS.value)
logger.info(f"Subscribed to {RedisChannel.CAMPAIGN_EVENTS.value} channel")
async for message in self._pubsub.listen():
if not self._running:
break
if message["type"] == "message":
try:
event = parse_campaign_event(message["data"])
if event:
await self._handle_event(event)
else:
logger.error(
f"Failed to parse campaign event: {message['data']}"
)
except Exception as e:
logger.error(f"Error handling campaign event: {e}")
async def _handle_event(self, event):
"""Handle campaign events including retry events."""
# Handle RetryNeededEvent
if isinstance(event, RetryNeededEvent):
await self._handle_retry_event(event)
return
# All events should have campaign_id
if not hasattr(event, "campaign_id") or not event.campaign_id:
logger.warning(f"Event missing campaign_id: {type(event).__name__}")
return
campaign_id = event.campaign_id
event_type = event.type
logger.debug(f"campaign_id: {campaign_id} - Received event: {event_type}")
if event_type == CampaignEventType.BATCH_COMPLETED:
# Clear the batch in progress flag
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
logger.debug(
f"campaign_id: {campaign_id} - Batch completed, cleared in-progress flag"
)
# Immediately schedule next batch
await self._schedule_next_batch(campaign_id)
self._last_activity[campaign_id] = datetime.now(UTC)
elif event_type == CampaignEventType.SYNC_COMPLETED:
# Start processing after sync
logger.info(
f"campaign_id: {campaign_id} - Sync completed, starting processing"
)
await self._schedule_next_batch(campaign_id)
self._last_activity[campaign_id] = datetime.now(UTC)
async def _handle_retry_event(self, event: RetryNeededEvent):
"""Process retry event and schedule if eligible (from campaign_retry_manager)."""
# Check retry eligibility
campaign_id = event.campaign_id
if not campaign_id:
logger.debug("Skipping non-campaign retry event")
return
# Get campaign configuration
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
logger.error(f"campaign_id: {campaign_id} - Campaign not found")
return
retry_config = campaign.retry_config or {}
if not retry_config.get("enabled", True):
logger.info(f"campaign_id: {campaign_id} - Retry disabled")
return
# Check if this reason should be retried
reason = event.reason
if reason == "busy" and not retry_config.get("retry_on_busy", True):
logger.info(f"campaign_id: {campaign_id} - Skipping retry for busy signal")
return
if reason == "no_answer" and not retry_config.get("retry_on_no_answer", True):
logger.info(f"campaign_id: {campaign_id} - Skipping retry for no-answer")
return
if reason == "voicemail" and not retry_config.get("retry_on_voicemail", True):
logger.info(f"campaign_id: {campaign_id} - Skipping retry for voicemail")
return
# Get the original queued run
queued_run = await db_client.get_queued_run_by_id(event.queued_run_id)
if not queued_run:
logger.error(
f"campaign_id: {campaign_id} - Queued run {event.queued_run_id} not found"
)
return
max_retries = retry_config.get("max_retries", 1)
if queued_run.retry_count >= max_retries:
await self._mark_final_failure(queued_run, reason)
logger.info(
f"campaign_id: {campaign_id} - Max retries ({max_retries}) reached for queued run {queued_run.id}"
)
return
# Create scheduled retry entry
retry_delay = retry_config.get("retry_delay_seconds", 120)
await self._schedule_retry(queued_run, reason, retry_delay)
# Update last activity
self._last_activity[campaign_id] = datetime.now(UTC)
async def _schedule_retry(
self, original_run: QueuedRunModel, reason: str, delay_seconds: int
):
"""Create a new queued run for retry."""
campaign_id = original_run.campaign_id
# Create retry context
retry_context = {
**original_run.context_variables,
"is_retry": True,
"retry_attempt": original_run.retry_count + 1,
"retry_reason": reason,
}
logger.debug(
f"campaign_id: {campaign_id} - Scheduling retry for {reason} in {delay_seconds}s, "
f"retry attempt {original_run.retry_count + 1}"
)
# Create retry entry with unique source_uuid
retry_run = await db_client.create_queued_run(
campaign_id=campaign_id,
source_uuid=f"{original_run.source_uuid}_retry_{original_run.retry_count + 1}",
context_variables=retry_context,
state="queued",
retry_count=original_run.retry_count + 1,
parent_queued_run_id=original_run.id,
scheduled_for=datetime.now(UTC) + timedelta(seconds=delay_seconds),
retry_reason=reason,
)
logger.info(
f"campaign_id: {campaign_id} - Scheduled retry {retry_run.id} for {reason} in {delay_seconds}s, "
f"retry attempt {retry_run.retry_count}"
)
async def _mark_final_failure(self, queued_run: QueuedRunModel, reason: str):
"""Mark a queued run as finally failed after max retries."""
campaign_id = queued_run.campaign_id
# Update the campaign's failed_rows counter
campaign = await db_client.get_campaign_by_id(campaign_id)
if campaign:
await db_client.update_campaign(
campaign_id=campaign_id, failed_rows=campaign.failed_rows + 1
)
logger.info(
f"campaign_id: {campaign_id} - Queued run {queued_run.id} finally failed after max retries, "
f"last reason: {reason}"
)
async def _schedule_next_batch(self, campaign_id: int):
"""Schedule next batch immediately if work available."""
# Prevent duplicate scheduling with in-memory lock
if campaign_id in self._processing_locks:
lock_time = self._processing_locks[campaign_id]
if (datetime.now(UTC) - lock_time).total_seconds() < 5:
logger.debug(
f"campaign_id: {campaign_id} - Batch already scheduled recently"
)
return
# Set lock
self._processing_locks[campaign_id] = datetime.now(UTC)
try:
# Check campaign status
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
logger.error(f"campaign_id: {campaign_id} - Campaign not found")
return
if campaign.state not in ["running", "syncing"]:
logger.info(
f"campaign_id: {campaign_id} - Campaign not in running state: {campaign.state}"
)
return
# Check for available work (queued runs + due retries)
has_work = await self._has_pending_work(campaign_id)
if has_work:
# Schedule batch immediately
await enqueue_job(
FunctionNames.PROCESS_CAMPAIGN_BATCH,
campaign_id,
10, # batch_size
)
logger.info(f"campaign_id: {campaign_id} - Scheduled next batch")
# Set batch in progress flag
self._batch_in_progress[campaign_id] = datetime.now(UTC)
# Update database
await db_client.update_campaign(
campaign_id=campaign_id,
last_batch_scheduled_at=datetime.now(UTC),
last_activity_at=datetime.now(UTC),
)
else:
logger.info(
f"campaign_id: {campaign_id} - No pending work to process, "
f"campaign may complete or wait for retries"
)
except Exception as e:
logger.error(f"campaign_id: {campaign_id} - Error scheduling batch: {e}")
finally:
# Release lock after a short delay
asyncio.create_task(self._release_lock_after_delay(campaign_id, 5))
async def _release_lock_after_delay(self, campaign_id: int, delay: int):
"""Release processing lock after delay."""
await asyncio.sleep(delay)
if campaign_id in self._processing_locks:
del self._processing_locks[campaign_id]
logger.debug(f"campaign_id: {campaign_id} - Released processing lock")
async def _monitor_completion(self):
"""Periodically check for campaigns that should be marked complete."""
while self._running:
try:
await self._check_stale_campaigns()
except Exception as e:
logger.error(f"Completion monitoring failed: {e}")
await asyncio.sleep(self.completion_check_interval)
async def _check_stale_campaigns(self):
"""Check all running campaigns for completion or orphaned work."""
logger.debug("Checking for stale campaigns...")
campaigns = await db_client.get_campaigns_by_status(statuses=["running"])
for campaign in campaigns:
try:
campaign_id = campaign.id
# Check if batch is stuck (initiated > 5 minutes ago but no completion)
if campaign_id in self._batch_in_progress:
batch_start_time = self._batch_in_progress[campaign_id]
time_since_batch_start = (
datetime.now(UTC) - batch_start_time
).total_seconds()
if time_since_batch_start > 300: # 5 minutes
logger.warning(
f"campaign_id: {campaign_id} - Batch stuck for {time_since_batch_start:.0f}s, "
f"clearing flag and checking for more work"
)
del self._batch_in_progress[campaign_id]
# Check if there's work to be done
if await self._has_pending_work(campaign_id):
logger.info(
f"campaign_id: {campaign_id} - Found pending work after stuck batch, "
f"scheduling new batch"
)
await self._schedule_next_batch(campaign_id)
continue
# Check for orphaned work (e.g., newly created retries with no batch in progress)
if campaign_id not in self._batch_in_progress:
has_work = await self._has_pending_work(campaign_id)
if has_work:
logger.info(
f"campaign_id: {campaign_id} - Found orphaned work (likely new retries), "
f"scheduling batch to process"
)
await self._schedule_next_batch(campaign_id)
continue
# Check if campaign should be marked complete
if await self._should_mark_complete(campaign):
await self._complete_campaign(campaign)
except Exception as e:
logger.error(
f"campaign_id: {campaign.id} - Completion check failed: {e}"
)
async def _should_mark_complete(self, campaign: CampaignModel) -> bool:
"""Check if campaign has no activity for 1 hour."""
campaign_id = campaign.id
# Don't mark complete if batch is in progress
if campaign_id in self._batch_in_progress:
logger.debug(
f"campaign_id: {campaign_id} - Batch in progress, not marking complete"
)
return False
# Check for any pending work
has_work = await self._has_pending_work(campaign_id)
if has_work:
return False
# Check in-memory last activity
last_activity = self._last_activity.get(campaign_id)
if not last_activity:
# Fall back to database
last_activity = campaign.last_activity_at
if not last_activity:
# No activity recorded, use last batch scheduled time
last_activity = campaign.last_batch_scheduled_at
if not last_activity:
# No activity at all, use started_at
last_activity = campaign.started_at
if last_activity:
time_since = datetime.now(UTC) - last_activity
if time_since.total_seconds() < self.completion_timeout:
return False
logger.info(
f"campaign_id: {campaign_id} - No activity for {self.completion_timeout}s, "
f"marking complete"
)
return True
async def _has_pending_work(self, campaign_id: int) -> bool:
"""Check if campaign has any work to do."""
# Check queued runs
queued_count = await db_client.get_queued_runs_count(
campaign_id=campaign_id, states=["queued"]
)
if queued_count > 0:
logger.debug(f"campaign_id: {campaign_id} - Has {queued_count} queued runs")
return True
# Check scheduled retries that are due
scheduled_count = await db_client.get_scheduled_runs_count(
campaign_id=campaign_id, scheduled_before=datetime.now(UTC)
)
if scheduled_count > 0:
logger.debug(
f"campaign_id: {campaign_id} - Has {scheduled_count} scheduled retries due"
)
return True
return False
async def _complete_campaign(self, campaign: CampaignModel):
"""Mark campaign as complete."""
campaign_id = campaign.id
try:
# Double-check no pending work
if await self._has_pending_work(campaign_id):
logger.info(
f"campaign_id: {campaign_id} - Found pending work, not completing"
)
return
# Update campaign status
await db_client.update_campaign(
campaign_id=campaign_id,
state="completed",
completed_at=datetime.now(UTC),
)
logger.info(f"campaign_id: {campaign_id} - Campaign marked as completed")
# Publish completion event using typed event
completion_event = CampaignCompletedEvent(
campaign_id=campaign_id,
total_rows=campaign.total_rows or 0,
processed_rows=campaign.processed_rows,
failed_rows=campaign.failed_rows,
)
# Calculate duration if started_at is available
if campaign.started_at:
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
completion_event.duration_seconds = duration
await self.redis.publish(
RedisChannel.CAMPAIGN_EVENTS.value, completion_event.to_json()
)
# Clean up in-memory state
if campaign_id in self._last_activity:
del self._last_activity[campaign_id]
if campaign_id in self._processing_locks:
del self._processing_locks[campaign_id]
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
except Exception as e:
logger.error(
f"campaign_id: {campaign_id} - Failed to complete campaign: {e}"
)
async def shutdown(self):
"""Clean shutdown of the orchestrator."""
logger.info("Campaign Orchestrator shutting down...")
self._running = False
if self._pubsub:
try:
await self._pubsub.unsubscribe(RedisChannel.CAMPAIGN_EVENTS.value)
await self._pubsub.close()
except Exception as e:
logger.error(f"Error closing pubsub: {e}")
logger.info("Campaign Orchestrator shutdown complete")
async def main():
"""Main entry point for Campaign Orchestrator service."""
# Setup Redis connection
redis = await aioredis.from_url(REDIS_URL, decode_responses=True)
# Create and run orchestrator
orchestrator = CampaignOrchestrator(redis)
# Create a shutdown event for clean coordination
shutdown_event = asyncio.Event()
# Setup signal handlers
loop = asyncio.get_event_loop()
def signal_handler(signum):
logger.info(f"Received shutdown signal {signum}")
shutdown_event.set()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, lambda s=sig: signal_handler(s))
# Run orchestrator with shutdown monitoring
orchestrator_task = asyncio.create_task(orchestrator.run())
shutdown_task = asyncio.create_task(shutdown_event.wait())
try:
# Wait for either orchestrator to complete or shutdown signal
done, _ = await asyncio.wait(
[orchestrator_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED
)
# If shutdown was triggered, stop the orchestrator
if shutdown_task in done:
logger.info("Shutdown signal received, stopping orchestrator...")
orchestrator._running = False
# Wait for orchestrator to finish gracefully
try:
await asyncio.wait_for(orchestrator_task, timeout=5.0)
except asyncio.TimeoutError:
logger.warning("Orchestrator shutdown timeout, cancelling...")
orchestrator_task.cancel()
try:
await orchestrator_task
except asyncio.CancelledError:
pass
except KeyboardInterrupt:
logger.info("Keyboard interrupt received")
finally:
# Ensure clean shutdown
await orchestrator.shutdown()
await redis.aclose()
logger.info("Campaign Orchestrator service stopped")
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,256 @@
import time
import uuid
from typing import Optional
import redis.asyncio as aioredis
from loguru import logger
from api.constants import REDIS_URL
class RateLimiter:
"""Sliding window rate limiter to enforce strict per-second limits and concurrent call limits"""
def __init__(self):
self.redis_client: Optional[aioredis.Redis] = None
self.stale_call_timeout = 1800 # 30 minutes in seconds
async def _get_redis(self) -> aioredis.Redis:
"""Get or create Redis connection"""
if self.redis_client is None:
self.redis_client = await aioredis.from_url(
REDIS_URL, decode_responses=True
)
return self.redis_client
async def acquire_token(self, organization_id: int, rate_limit: int = 1) -> bool:
"""
Enforces strict rate limit: max N calls per rolling second window
Returns True if allowed, False if rate limited
"""
redis_client = await self._get_redis()
key = f"rate_limit:{organization_id}"
now = time.time()
window_start = now - 1.0 # 1 second sliding window
# Lua script for atomic sliding window operation
lua_script = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window_start = tonumber(ARGV[2])
local max_requests = tonumber(ARGV[3])
-- Remove timestamps older than window
redis.call('ZREMRANGEBYSCORE', key, 0, window_start)
-- Count requests in current window
local current_requests = redis.call('ZCARD', key)
if current_requests < max_requests then
-- Add current timestamp
redis.call('ZADD', key, now, now)
redis.call('EXPIRE', key, 2) -- Expire after 2 seconds
return 1
else
return 0
end
"""
try:
result = await redis_client.eval(
lua_script, 1, key, now, window_start, rate_limit
)
return bool(result)
except Exception as e:
logger.error(f"Rate limiter error: {e}")
# On error, be conservative and deny
return False
async def get_next_available_slot(
self, organization_id: int, rate_limit: int = 1
) -> float:
"""
Returns seconds until next available slot
Useful for implementing retry with backoff
"""
redis_client = await self._get_redis()
key = f"rate_limit:{organization_id}"
try:
# Get oldest timestamp in current window
oldest = await redis_client.zrange(key, 0, 0, withscores=True)
if not oldest:
return 0.0 # Can call immediately
oldest_time = oldest[0][1]
next_available = oldest_time + 1.0 # 1 second after oldest
wait_time = max(0, next_available - time.time())
return wait_time
except Exception as e:
logger.error(f"Rate limiter get_next_available_slot error: {e}")
return 1.0 # Default wait time on error
async def try_acquire_concurrent_slot(
self, organization_id: int, max_concurrent: int = 20
) -> Optional[str]:
"""
Try to acquire a concurrent call slot.
Returns a unique slot_id if successful, None if limit reached.
"""
redis_client = await self._get_redis()
concurrent_key = f"concurrent_calls:{organization_id}"
now = time.time()
stale_cutoff = now - self.stale_call_timeout
# Lua script for atomic operation
lua_script = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local max_concurrent = tonumber(ARGV[2])
local stale_cutoff = tonumber(ARGV[3])
local slot_id = ARGV[4]
-- Remove stale entries (older than 30 minutes)
redis.call('ZREMRANGEBYSCORE', key, 0, stale_cutoff)
-- Get current count
local current_count = redis.call('ZCARD', key)
if current_count < max_concurrent then
-- Add new slot
redis.call('ZADD', key, now, slot_id)
redis.call('EXPIRE', key, 3600) -- Expire after 1 hour
return slot_id
else
return nil
end
"""
# Generate unique slot ID (timestamp + random component)
slot_id = f"{int(now * 1000)}_{uuid.uuid4().hex[:8]}"
try:
result = await redis_client.eval(
lua_script,
1,
concurrent_key,
now,
max_concurrent,
stale_cutoff,
slot_id,
)
return result
except Exception as e:
logger.error(f"Concurrent limiter error: {e}")
return None
async def release_concurrent_slot(self, organization_id: int, slot_id: str) -> bool:
"""
Release a concurrent call slot.
Returns True if slot was released, False otherwise.
"""
if not slot_id:
return False
redis_client = await self._get_redis()
concurrent_key = f"concurrent_calls:{organization_id}"
try:
removed = await redis_client.zrem(concurrent_key, slot_id)
if removed:
logger.debug(
f"Released concurrent slot {slot_id} for org {organization_id}"
)
return bool(removed)
except Exception as e:
logger.error(f"Error releasing concurrent slot: {e}")
return False
async def get_concurrent_count(self, organization_id: int) -> int:
"""
Get current number of active concurrent calls for an organization.
Automatically cleans up stale entries.
"""
redis_client = await self._get_redis()
concurrent_key = f"concurrent_calls:{organization_id}"
try:
# Clean up stale entries first
stale_cutoff = time.time() - self.stale_call_timeout
await redis_client.zremrangebyscore(concurrent_key, 0, stale_cutoff)
# Get current count
count = await redis_client.zcard(concurrent_key)
return count
except Exception as e:
logger.error(f"Error getting concurrent count: {e}")
return 0
async def store_workflow_slot_mapping(
self, workflow_run_id: int, organization_id: int, slot_id: str
) -> bool:
"""
Store the mapping between workflow_run_id and its concurrent slot.
Used for cleanup when calls complete.
"""
redis_client = await self._get_redis()
mapping_key = f"workflow_slot_mapping:{workflow_run_id}"
try:
# Store as a hash with TTL
await redis_client.hset(
mapping_key, mapping={"org_id": organization_id, "slot_id": slot_id}
)
# Set expiry to match stale timeout
await redis_client.expire(mapping_key, self.stale_call_timeout)
return True
except Exception as e:
logger.error(f"Error storing workflow slot mapping: {e}")
return False
async def get_workflow_slot_mapping(
self, workflow_run_id: int
) -> Optional[tuple[int, str]]:
"""
Get the concurrent slot mapping for a workflow run.
Returns (organization_id, slot_id) tuple or None if not found.
"""
redis_client = await self._get_redis()
mapping_key = f"workflow_slot_mapping:{workflow_run_id}"
try:
mapping = await redis_client.hgetall(mapping_key)
if mapping and "org_id" in mapping and "slot_id" in mapping:
return (int(mapping["org_id"]), mapping["slot_id"])
return None
except Exception as e:
logger.error(f"Error getting workflow slot mapping: {e}")
return None
async def delete_workflow_slot_mapping(self, workflow_run_id: int) -> bool:
"""
Delete the workflow slot mapping after releasing the slot.
"""
redis_client = await self._get_redis()
mapping_key = f"workflow_slot_mapping:{workflow_run_id}"
try:
deleted = await redis_client.delete(mapping_key)
return bool(deleted)
except Exception as e:
logger.error(f"Error deleting workflow slot mapping: {e}")
return False
async def close(self):
"""Close Redis connection"""
if self.redis_client:
await self.redis_client.close()
self.redis_client = None
# Global rate limiter instance
rate_limiter = RateLimiter()

View file

@ -0,0 +1,122 @@
from datetime import UTC, datetime
from typing import Any, Dict
from loguru import logger
from api.db import db_client
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
class CampaignRunnerService:
"""Orchestrates campaign execution"""
async def start_campaign(self, campaign_id: int) -> None:
"""Entry point - updates state to 'syncing' and enqueues sync task"""
# Get campaign
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
if campaign.state != "created":
raise ValueError(
f"Campaign must be in 'created' state to start, current state: {campaign.state}"
)
# Update campaign state to syncing
await db_client.update_campaign(
campaign_id=campaign_id,
state="syncing",
started_at=datetime.now(UTC),
source_sync_status="in_progress",
)
# Enqueue the sync task
await enqueue_job(FunctionNames.SYNC_CAMPAIGN_SOURCE, campaign_id)
logger.info(f"Campaign {campaign_id} started, syncing source data")
async def pause_campaign(self, campaign_id: int) -> None:
"""Pauses active campaign processing"""
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
if campaign.state not in ["running", "syncing"]:
raise ValueError(
f"Campaign must be in 'running' or 'syncing' state to pause, current state: {campaign.state}"
)
# Update state to paused
await db_client.update_campaign(campaign_id=campaign_id, state="paused")
logger.info(f"Campaign {campaign_id} paused")
async def resume_campaign(self, campaign_id: int) -> None:
"""Resumes paused campaign"""
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
if campaign.state != "paused":
raise ValueError(
f"Campaign must be in 'paused' state to resume, current state: {campaign.state}"
)
# Update state to running
await db_client.update_campaign(campaign_id=campaign_id, state="running")
# Enqueue process batch task to continue processing
await enqueue_job(FunctionNames.PROCESS_CAMPAIGN_BATCH, campaign_id)
logger.info(f"Campaign {campaign_id} resumed")
async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:
"""Returns detailed campaign status"""
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
# Count failed calls from workflow runs
failed_calls = await self._count_failed_campaign_calls(campaign_id)
return {
"campaign_id": campaign_id,
"state": campaign.state,
"total_rows": campaign.total_rows or 0,
"processed_rows": campaign.processed_rows,
"failed_calls": failed_calls,
"progress_percentage": (
(campaign.processed_rows / campaign.total_rows * 100)
if campaign.total_rows and campaign.total_rows > 0
else 0
),
"source_sync": {
"status": campaign.source_sync_status,
"last_synced_at": campaign.source_last_synced_at,
"error": campaign.source_sync_error,
},
"rate_limit": campaign.rate_limit_per_second,
"started_at": campaign.started_at,
"completed_at": campaign.completed_at,
}
async def _count_failed_campaign_calls(self, campaign_id: int) -> int:
"""Count failed calls by examining workflow_run Twilio callbacks"""
# Get all workflow runs for this campaign
workflow_runs = await db_client.get_workflow_runs_by_campaign(campaign_id)
failed_count = 0
for run in workflow_runs:
callbacks = run.logs.get("twilio_status_callbacks", [])
if callbacks:
# Check final status
final_status = callbacks[-1].get("status", "").lower()
if final_status in ["failed", "busy", "no-answer"]:
failed_count += 1
return failed_count
# Global instance
campaign_runner_service = CampaignRunnerService()

View file

@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from loguru import logger
class CampaignSourceSyncService(ABC):
"""Base class for campaign data source synchronization"""
@abstractmethod
async def sync_source_data(self, campaign_id: int) -> int:
"""
Fetches data from source and creates queued_runs
Each record gets a unique source_uuid based on source type
Returns: number of records synced
"""
pass
@abstractmethod
async def validate_source_schema(self, source_config: Dict[str, Any]) -> bool:
"""Validates required fields exist in source"""
pass
async def get_source_credentials(
self, organization_id: int, source_type: str
) -> Dict[str, Any]:
"""Gets OAuth tokens or API credentials via Nango"""
# This would be implemented to work with Nango service
# For now, returning placeholder
logger.info(
f"Getting credentials for org {organization_id}, source {source_type}"
)
return {}
def get_sync_service(source_type: str) -> CampaignSourceSyncService:
"""Returns appropriate sync service based on source type"""
from .sources.google_sheets import GoogleSheetsSyncService
services = {
"google-sheet": GoogleSheetsSyncService,
# Add more as needed: "hubspot": HubSpotSyncService,
}
service_class = services.get(source_type)
if not service_class:
raise ValueError(f"Unknown source type: {source_type}")
return service_class()

View file

@ -0,0 +1,5 @@
"""Campaign source sync services"""
from .google_sheets import GoogleSheetsSyncService
__all__ = ["GoogleSheetsSyncService"]

View file

@ -0,0 +1,180 @@
import re
from typing import Any, Dict, List
import httpx
from loguru import logger
from api.db import db_client
from api.services.campaign.source_sync import CampaignSourceSyncService
from api.services.integrations.nango import NangoService
class GoogleSheetsSyncService(CampaignSourceSyncService):
"""Implementation for Google Sheets synchronization"""
def __init__(self):
self.nango_service = NangoService()
self.sheets_api_base = "https://sheets.googleapis.com/v4/spreadsheets"
async def sync_source_data(self, campaign_id: int) -> int:
"""
Fetches data from Google Sheets and creates queued_runs
"""
# Get campaign
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
# 1. Get Google Sheets integration for the organization
integrations = await db_client.get_integrations_by_organization_id(
campaign.organization_id
)
integration = None
for intg in integrations:
if intg.provider == "google-sheet" and intg.is_active:
integration = intg
break
if not integration:
raise ValueError("Google Sheets integration not found or inactive")
# 2. Get OAuth token via Nango using the integration_id (which is the Nango connection ID)
token_data = await self.nango_service.get_access_token(
connection_id=integration.integration_id, provider_config_key="google-sheet"
)
access_token = token_data["credentials"]["access_token"]
# 3. Extract sheet ID from URL
sheet_id = self._extract_sheet_id(campaign.source_id)
# 4. Get sheet metadata (to find data range)
metadata = await self._get_sheet_metadata(sheet_id, access_token)
if not metadata.get("sheets"):
raise ValueError("No sheets found in the spreadsheet")
sheet_name = metadata["sheets"][0]["properties"]["title"]
# 5. Fetch all data from sheet
sheet_data = await self._fetch_sheet_data(
sheet_id,
f"{sheet_name}!A:Z", # Get all columns A-Z
access_token,
)
# 6. Convert to queued_runs
if not sheet_data or len(sheet_data) < 2:
logger.warning(f"No data found in sheet for campaign {campaign_id}")
return 0
headers = sheet_data[0] # First row is headers
rows = sheet_data[1:] # Rest is data
queued_runs = []
for idx, row_values in enumerate(rows, 1):
# Pad row to match headers length
padded_row = row_values + [""] * (len(headers) - len(row_values))
# Create context variables dict
context_vars = dict(zip(headers, padded_row))
# Skip if no phone number
if not context_vars.get("phone_number"):
logger.debug(f"Skipping row {idx}: no phone_number")
continue
# Generate unique source UUID
source_uuid = f"sheet_{sheet_id}_row_{idx}"
queued_runs.append(
{
"campaign_id": campaign_id,
"source_uuid": source_uuid,
"context_variables": context_vars,
"state": "queued",
}
)
# 7. Bulk insert
if queued_runs:
await db_client.bulk_create_queued_runs(queued_runs)
logger.info(
f"Created {len(queued_runs)} queued runs for campaign {campaign_id}"
)
# 8. Update campaign total_rows
await db_client.update_campaign(
campaign_id=campaign_id,
total_rows=len(queued_runs),
source_sync_status="completed",
)
return len(queued_runs)
async def _fetch_sheet_data(
self, sheet_id: str, range: str, access_token: str
) -> List[List[str]]:
"""Fetch data from Google Sheets API"""
url = f"{self.sheets_api_base}/{sheet_id}/values/{range}"
headers = {"Authorization": f"Bearer {access_token}"}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
return data.get("values", [])
async def _get_sheet_metadata(
self, sheet_id: str, access_token: str
) -> Dict[str, Any]:
"""Get sheet metadata including sheet names"""
url = f"{self.sheets_api_base}/{sheet_id}"
headers = {"Authorization": f"Bearer {access_token}"}
logger.debug(f"Fetching sheet metadata from URL: {url}")
logger.debug(f"Using sheet_id: {sheet_id}")
async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error {e.response.status_code} for URL: {url}")
logger.error(f"Response body: {e.response.text}")
raise
except Exception as e:
logger.error(f"Error fetching sheet metadata: {e}")
raise
def _extract_sheet_id(self, sheet_url: str) -> str:
"""
Extract sheet ID from various Google Sheets URL formats:
- https://docs.google.com/spreadsheets/d/{id}/edit
- https://docs.google.com/spreadsheets/d/{id}/edit#gid=0
"""
pattern = r"/spreadsheets/d/([a-zA-Z0-9-_]+)"
match = re.search(pattern, sheet_url)
if match:
return match.group(1)
raise ValueError(f"Invalid Google Sheets URL: {sheet_url}")
async def validate_source_schema(self, source_config: Dict[str, Any]) -> bool:
"""Validate that required columns exist"""
required_columns = ["phone_number", "first_name", "last_name"]
# Fetch just the header row
sheet_id = self._extract_sheet_id(source_config["source_id"])
access_token = source_config["access_token"]
headers = await self._fetch_sheet_data(
sheet_id,
"A1:Z1", # Just first row
access_token,
)
if not headers:
return False
header_row = headers[0]
return all(col in header_row for col in required_columns)

View file

View file

@ -0,0 +1,153 @@
from typing import Dict, Optional, TypedDict
import openai
from deepgram import (
DeepgramClient,
LiveOptions,
)
from groq import Groq
# try:
# from pyneuphonic import Neuphonic
# except ImportError:
# Neuphonic = None
from api.schemas.user_configuration import (
UserConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
class APIKeyStatus(TypedDict):
model: str
message: str
class APIKeyStatusResponse(TypedDict):
status: list[APIKeyStatus]
class UserConfigurationValidator:
def __init__(self):
self._provider_api_key_validity_status: Dict[str, bool] = {}
self._validator_map = {
ServiceProviders.OPENAI.value: self._check_openai_api_key,
ServiceProviders.DEEPGRAM.value: self._check_deepgram_api_key,
ServiceProviders.GROQ.value: self._check_groq_api_key,
ServiceProviders.ELEVENLABS.value: self._validate_elevenlabs_api_key,
ServiceProviders.GOOGLE.value: self._check_google_api_key,
ServiceProviders.AZURE.value: self._check_azure_api_key,
ServiceProviders.CARTESIA.value: self._check_cartesia_api_key,
ServiceProviders.DOGRAH.value: self._check_dograh_api_key,
}
async def validate(self, configuration: UserConfiguration) -> APIKeyStatusResponse:
status_list = []
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
if status_list:
raise ValueError(status_list)
return {"status": [{"model": "all", "message": "ok"}]}
def _validate_service(
self, service_config: Optional[ServiceConfig], service_name: str
) -> list[APIKeyStatus]:
"""Validate a service configuration and return any error statuses."""
if not service_config:
return [{"model": service_name, "message": "API key is missing"}]
provider = service_config.provider
api_key = service_config.api_key
if not self._check_api_key(provider, api_key):
return [{"model": service_name, "message": f"Invalid {provider} API key"}]
return []
def _check_api_key(self, provider: str, api_key: str) -> bool:
"""Check if an API key for a provider is valid."""
validator = self._validator_map.get(provider)
if not validator:
return False
return validator(provider, api_key)
def _check_openai_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
client = openai.OpenAI(api_key=api_key)
try:
client.models.list()
self._provider_api_key_validity_status[model] = True
except openai.AuthenticationError:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
deepgram = DeepgramClient(api_key)
dg_connection = deepgram.listen.websocket.v("1")
try:
options = LiveOptions(
model="nova-2",
language="en-US",
smart_format=True,
)
connected = dg_connection.start(options)
self._provider_api_key_validity_status[model] = connected
finally:
dg_connection.finish()
return self._provider_api_key_validity_status[model]
def _check_groq_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
client = Groq(api_key=api_key)
try:
client.models.list()
self._provider_api_key_validity_status[model] = True
except Exception:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_google_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_azure_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_cartesia_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_dograh_api_key(self, model: str, api_key: str) -> bool:
return True
# def _check_neuphonic_api_key(self, model: str, api_key: str) -> bool:
# if not Neuphonic:
# self._provider_api_key_validity_status[model] = False
# return self._provider_api_key_validity_status[model]
# if model in self._provider_api_key_validity_status:
# return self._provider_api_key_validity_status[model]
# client = Neuphonic(api_key=api_key)
# try:
# response = client.voices.list() # get's all available voices
# voices = response.data["voices"]
# self._provider_api_key_validity_status[model] = True
# except Exception:
# self._provider_api_key_validity_status[model] = False
# return self._provider_api_key_validity_status[model]

View file

@ -0,0 +1,34 @@
from __future__ import annotations
"""Utilities for building default service configurations for a new user.
The defaults follow the same provider choices exposed by `/user/configurations/defaults`.
Values for `api_key` are pulled from environment variables named *{PROVIDER}_API_KEY*.
If an environment variable is missing, that particular provider configuration is
left as ``None``.
"""
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
ElevenlabsTTSConfiguration,
OpenAILLMService,
ServiceProviders,
)
# Mapping of service to (provider enum, configuration class)
_DEFAULTS = {
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
}
# Public mapping of service name -> default provider
DEFAULT_SERVICE_PROVIDERS = {
field: provider for field, (provider, _) in _DEFAULTS.items()
}
__all__ = [
"DEFAULT_SERVICE_PROVIDERS",
]

View file

@ -0,0 +1,69 @@
from __future__ import annotations
"""Utilities for masking API keys before they are sent to the client.
The rules are simple:
1. Only expose the last *visible* characters (default 4) of a key.
2. Incoming masked keys are considered a placeholder if they equal the mask of
the already-stored key, we treat them as *unchanged* and keep the real value
in storage.
"""
from typing import Any, Dict, Optional
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.registry import ServiceConfig
VISIBLE_CHARS = 4 # number of trailing characters to reveal
MASK_CHAR = "*"
def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str:
"""Return a masked representation of *real_key*.
Example:
>>> mask_key("sk-1234567890abcdef")
'****************cdef'
"""
if real_key is None:
return ""
if visible <= 0 or visible >= len(real_key):
# mask entire key or nothing to mask edge-cases
return MASK_CHAR * len(real_key)
masked_part = MASK_CHAR * (len(real_key) - visible)
return f"{masked_part}{real_key[-visible:]}"
def is_mask_of(masked: str, real_key: str) -> bool:
"""Return *True* if *masked* equals the mask of *real_key* under the current rules."""
return mask_key(real_key) == masked
# ---------------------------------------------------------------------------
# High-level helpers for UserConfiguration objects
# ---------------------------------------------------------------------------
def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, Any]]:
if service_cfg is None:
return None
# Work on a dict copy so we don't mutate original models
data = service_cfg.model_dump()
if "api_key" in data and data["api_key"]:
data["api_key"] = mask_key(data["api_key"])
return data
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
"""Return a JSON-serialisable dict of *config* with every api_key masked."""
return {
"llm": _mask_service(config.llm),
"tts": _mask_service(config.tts),
"stt": _mask_service(config.stt),
"test_phone_number": config.test_phone_number,
"timezone": config.timezone,
}

View file

@ -0,0 +1,75 @@
from __future__ import annotations
"""Helpers for merging incoming user-configuration updates with what is already
stored, while honouring masked API keys.
"""
from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import is_mask_of
SERVICE_FIELDS = ("llm", "tts", "stt")
def merge_user_configurations(
existing: UserConfiguration, incoming_partial: Dict[str, dict]
) -> UserConfiguration:
"""Merge *incoming_partial* onto *existing* and return a new UserConfiguration.
*incoming_partial* is the body of the PUT request (already `model_dump()`ed or
extracted via Pydantic `model_dump`).
Rules:
1. If a service block is absent in the request, keep the existing one.
2. If provider unchanged and the api_key field is either missing or equal to
the masked placeholder, preserve the existing real key.
3. If provider changes, the incoming api_key is used verbatim (validation
will fail later if it is missing).
4. Non-service top-level fields (e.g. `test_phone_number`) are overwritten
when supplied.
"""
merged = existing.model_dump(exclude_none=True)
def _merge_service_block(service_name: str):
incoming_cfg = incoming_partial.get(service_name)
if incoming_cfg is None:
return # nothing to do
old_cfg = merged.get(service_name, {})
provider_changed = (
old_cfg.get("provider") is not None
and incoming_cfg.get("provider") is not None
and incoming_cfg.get("provider") != old_cfg.get("provider")
)
incoming_api_key = incoming_cfg.get("api_key")
if not provider_changed:
# conditional preservation of api_key
if incoming_api_key is not None:
if (
old_cfg
and "api_key" in old_cfg
and is_mask_of(incoming_api_key, old_cfg["api_key"])
):
incoming_cfg["api_key"] = old_cfg["api_key"]
else:
if "api_key" in old_cfg:
incoming_cfg["api_key"] = old_cfg["api_key"]
merged[service_name] = incoming_cfg
for service in SERVICE_FIELDS:
_merge_service_block(service)
# other simple fields
if "test_phone_number" in incoming_partial:
merged["test_phone_number"] = incoming_partial["test_phone_number"]
if "timezone" in incoming_partial:
merged["timezone"] = incoming_partial["timezone"]
return UserConfiguration.model_validate(merged)

View file

@ -0,0 +1,356 @@
from enum import Enum, auto
from typing import Annotated, Dict, Literal, Type, TypeVar, Union
from pydantic import BaseModel, Field, computed_field
class ServiceType(Enum):
LLM = auto()
TTS = auto()
STT = auto()
class ServiceProviders(str, Enum):
OPENAI = "openai"
DEEPGRAM = "deepgram"
GROQ = "groq"
CARTESIA = "cartesia"
# NEUPHONIC = "neuphonic"
ELEVENLABS = "elevenlabs"
GOOGLE = "google"
AZURE = "azure"
DOGRAH = "dograh"
class BaseServiceConfiguration(BaseModel):
provider: Literal[
ServiceProviders.OPENAI,
ServiceProviders.DEEPGRAM,
ServiceProviders.GROQ,
ServiceProviders.ELEVENLABS,
ServiceProviders.GOOGLE,
ServiceProviders.AZURE,
ServiceProviders.DOGRAH,
]
api_key: str
class BaseLLMConfiguration(BaseServiceConfiguration):
model: str
class BaseTTSConfiguration(BaseServiceConfiguration):
model: str
class BaseSTTConfiguration(BaseServiceConfiguration):
model: str
# Unified registry for all service types
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
ServiceType.LLM: {},
ServiceType.TTS: {},
ServiceType.STT: {},
}
T = TypeVar("T", bound=BaseServiceConfiguration)
def register_service(service_type: ServiceType):
"""Generic decorator for registering service configurations"""
def decorator(cls: Type[T]) -> Type[T]:
# Get provider from class attributes or field defaults
provider = getattr(cls, "provider", None)
if provider is None:
# Try to get from model fields
provider = cls.model_fields.get("provider", None)
if provider is not None:
provider = provider.default
if provider is None:
raise ValueError(f"Provider not specified for {cls.__name__}")
REGISTRY[service_type][provider] = cls
return cls
return decorator
# Convenience decorators
def register_llm(cls: Type[BaseLLMConfiguration]):
return register_service(ServiceType.LLM)(cls)
def register_tts(cls: Type[BaseTTSConfiguration]):
return register_service(ServiceType.TTS)(cls)
def register_stt(cls: Type[BaseSTTConfiguration]):
return register_service(ServiceType.STT)(cls)
###################################################### LLM ########################################################################
class OpenAIModel(str, Enum):
GPT3_5_TURBO = "gpt-3.5-turbo"
GPT4_1 = "gpt-4.1"
GPT4_1_MINI = "gpt-4.1-mini"
GPT4_1_NANO = "gpt-4.1-nano"
GPT5 = "gpt-5"
GPT5_MINI = "gpt-5-mini"
GPT5_NANO = "gpt-5-nano"
@register_llm
class OpenAILLMService(BaseLLMConfiguration):
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
model: OpenAIModel = OpenAIModel.GPT4_1
api_key: str
class GoogleModel(str, Enum):
GEMINI_2_0_FLASH = "gemini-2.0-flash"
GEMINI_2_0_FLASH_LITE = "gemini-2.0-flash-lite"
GEMINI_2_5_FLASH = "gemini-2.5-flash"
GEMINI_2_5_FLASH_LITE = "gemini-2.5-flash-lite"
@register_llm
class GoogleLLMService(BaseLLMConfiguration):
provider: Literal[ServiceProviders.GOOGLE] = ServiceProviders.GOOGLE
model: GoogleModel = GoogleModel.GEMINI_2_0_FLASH
api_key: str
class GroqModel(str, Enum):
LLAMA_3_3_70B = "llama-3.3-70b-versatile"
DEEPSEEK_R1_DISTILL_LLAMA_70B = "deepseek-r1-distill-llama-70b"
QUEN_QWQ_32B = "qwen-qwq-32b"
LLAMA_4_SCOUT_17B_16E_INSTRUCT = "meta-llama/llama-4-scout-17b-16e-instruct"
LLAMA_4_MAVERICK_17B_128E_INSTRUCT = "meta-llama/llama-4-maverick-17b-128e-instruct"
GEMMA2_9B_IT = "gemma2-9b-it"
LLAMA_3_1_8B_INSTANT = "llama-3.1-8b-instant"
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
@register_llm
class GroqLLMService(BaseLLMConfiguration):
provider: Literal[ServiceProviders.GROQ] = ServiceProviders.GROQ
model: GroqModel = GroqModel.LLAMA_3_3_70B
api_key: str
class AzureModel(str, Enum):
GPT4_1_MINI = "gpt-4.1-mini"
@register_llm
class AzureLLMService(BaseLLMConfiguration):
provider: Literal[ServiceProviders.AZURE] = ServiceProviders.AZURE
model: AzureModel = AzureModel.GPT4_1_MINI
api_key: str
endpoint: str
# Dograh LLM Service
class DograhLLMModel(str, Enum):
DEFAULT = "default"
@register_llm
class DograhLLMService(BaseLLMConfiguration):
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: DograhLLMModel = DograhLLMModel.DEFAULT
api_key: str
LLMConfig = Annotated[
Union[
OpenAILLMService,
GroqLLMService,
GoogleLLMService,
AzureLLMService,
DograhLLMService,
],
Field(discriminator="provider"),
]
###################################################### TTS ########################################################################
class DeepgramVoice(str, Enum):
HELENA = "aura-2-helena-en"
THALIA = "aura-2-thalia-en"
@register_tts
class DeepgramTTSConfiguration(BaseServiceConfiguration):
provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
voice: DeepgramVoice = DeepgramVoice.HELENA
api_key: str
@computed_field
@property
def model(self) -> str:
# Deepgram model's name is inferred using the voice name.
# It can either contain aura-2 or aura-1
if "aura-2" in self.voice:
return "aura-2"
elif "aura-1" in self.voice:
return "aura-1"
else:
# Default fallback
return "aura-2"
class ElevenlabsVoice(str, Enum):
ALEXANDRA = "Alexandra - 3dzJXoCYueSQiptQ6euE"
AMY = "Amy - oGn4Ha2pe2vSJkmIJgLQ"
ANGELA = "Angela - FUfBrNit0NNZAwb58KWH"
ARIA = "Aria - 9BWtsMINqrJLrRacOk9x"
CHELSEA = "Chelsea - NHRgOEwqx5WZNClv5sat"
CHRISTINA = "Christina - X03mvPuTfprif8QBAVeJ"
CLARA = "Clara - ZIlrSGI4jZqobxRKprJz"
CLYDE = "Clyde - 2EiwWnXFnvU5JabPnv8n"
DAVE = "Dave - CYw3kZ02Hs0563khs1Fj"
DOMI = "Domi - AZnzlk1XvdvUeBnXmlld"
DREW = "Drew - 29vD33N1CtxCmqQRPOHJ"
EVE = "Eve - BZgkqPqms7Kj9ulSkVzn"
FIN = "Fin - D38z5RcWu1voky8WS1ja"
HOPE_BESTIE = "Hope_Bestie - uYXf8XasLslADfZ2MB4u"
HOPE_NATURAL = "Hope_Natural - OYTbf65OHHFELVut7v2H"
JARNATHAN = "Jarnathan - c6SfcYrb2t09NHXiT80T"
JENNA = "Jenna - C2BkQxlGNzBn7WD2bqfR"
JESSICA = "Jessica - cgSgspJ2msm6clMCkdW9"
JUNIPER = "Juniper - aMSt68OGf4xUZAnLpTU8"
LAUREN = "Lauren - 3liN8q8YoeB9Hk6AboKe"
LINA = "Lina - oWjuL7HSoaEJRMDMP3HD"
OLIVIA = "Olivia - 1rviaVF7GGGkTU36HNpz"
PAUL = "Paul - 5Q0t7uMcjvnagumLfvZi"
RACHEL = "Rachel - 21m00Tcm4TlvDq8ikWAM"
ROGER = "Roger - CwhRBWXzGAHq8TQ4Fs17"
SAMI_REAL = "Sami_Real - O4cGUVdAocn0z4EpQ9yF"
SARAH = "Sarah - EXAVITQu4vr4xnSDxMaL"
class ElevenlabsModel(str, Enum):
FLASH_2 = "eleven_flash_v2_5"
@register_tts
class ElevenlabsTTSConfiguration(BaseServiceConfiguration):
provider: Literal[ServiceProviders.ELEVENLABS] = ServiceProviders.ELEVENLABS
voice: ElevenlabsVoice = ElevenlabsVoice.RACHEL
speed: float = Field(default=1.0, ge=0.1, le=2.0, description="Speed of the voice")
model: ElevenlabsModel = ElevenlabsModel.FLASH_2
api_key: str
class OpenAIVoice(str, Enum):
ALLY = "alloy"
class OpenAITTSModel(str, Enum):
GPT_4o_MINI = "gpt-4o-mini-tts"
@register_tts
class OpenAITTSService(BaseTTSConfiguration):
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
model: OpenAITTSModel = OpenAITTSModel.GPT_4o_MINI
voice: OpenAIVoice = OpenAIVoice.ALLY
api_key: str
# class NeuphonicVoice(str, Enum):
# EMILY = "Emily - fc854436-2dac-4d21-aa69-ae17b54e98eb"
# @register_tts
# class NeuphonicTTSService(BaseTTSConfiguration):
# provider: Literal[ServiceProviders.NEUPHONIC] = ServiceProviders.NEUPHONIC
# voice: NeuphonicVoice = NeuphonicVoice.EMILY
# model: str = "NA"
# api_key: str
# Dograh TTS Service
class DograhVoice(str, Enum):
DEFAULT = "default"
class DograhTTSModel(str, Enum):
DEFAULT = "default"
@register_tts
class DograhTTSService(BaseTTSConfiguration):
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: DograhTTSModel = DograhTTSModel.DEFAULT
voice: DograhVoice = DograhVoice.DEFAULT
api_key: str
TTSConfig = Annotated[
Union[
DeepgramTTSConfiguration,
OpenAITTSService,
ElevenlabsTTSConfiguration,
DograhTTSService,
],
Field(discriminator="provider"),
]
###################################################### STT ########################################################################
class DeepgramSTTModel(str, Enum):
NOVA_3_GENERAL = "nova-3-general"
@register_stt
class DeepgramSTTConfiguration(BaseSTTConfiguration):
provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
model: DeepgramSTTModel = DeepgramSTTModel.NOVA_3_GENERAL
api_key: str
@register_stt
class CartesiaSTTConfiguration(BaseSTTConfiguration):
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
api_key: str
class OpenAISTTModel(str, Enum):
GPT_4o_TRANSCRIBE = "gpt-4o-transcribe"
@register_stt
class OpenAISTTConfiguration(BaseSTTConfiguration):
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
model: OpenAISTTModel = OpenAISTTModel.GPT_4o_TRANSCRIBE
api_key: str
# Dograh STT Service
class DograhSTTModel(str, Enum):
DEFAULT = "default"
@register_stt
class DograhSTTService(BaseSTTConfiguration):
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: DograhSTTModel = DograhSTTModel.DEFAULT
api_key: str
STTConfig = Annotated[
Union[DeepgramSTTConfiguration, OpenAISTTConfiguration, DograhSTTService],
Field(discriminator="provider"),
]
ServiceConfig = Annotated[
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
]

View file

@ -0,0 +1,9 @@
from .base import BaseFileSystem
from .minio import MinioFileSystem
from .s3 import S3FileSystem
__all__ = [
"BaseFileSystem",
"S3FileSystem",
"MinioFileSystem",
]

View file

@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import Any, BinaryIO, Dict, Optional
class BaseFileSystem(ABC):
"""Abstract base class for filesystem operations."""
@abstractmethod
async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
"""Create a new file with the given content.
Args:
file_path: Path where the file should be created
content: File content as a binary stream
Returns:
bool: True if file was created successfully, False otherwise
"""
pass
@abstractmethod
async def aupload_file(self, local_path: str, destination_path: str) -> bool:
"""Upload a file from local path to destination.
Args:
local_path: Path to the local file
destination_path: Path where the file should be uploaded
Returns:
bool: True if file was uploaded successfully, False otherwise
"""
pass
@abstractmethod
async def aget_signed_url(
self, file_path: str, expiration: int = 3600
) -> Optional[str]:
"""Generate a signed URL for temporary access to a file.
Args:
file_path: Path to the file
expiration: URL expiration time in seconds (default: 1 hour)
Returns:
Optional[str]: Signed URL if successful, None otherwise
"""
pass
@abstractmethod
async def aget_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a file.
Args:
file_path: Path to the file
Returns:
Optional[Dict[str, Any]]: File metadata if successful, None otherwise
Contains: size, created_at, modified_at, etag, etc.
"""
pass

View file

@ -0,0 +1,95 @@
import asyncio
import os
from datetime import datetime
from typing import BinaryIO, Optional
import aiofiles
from .base import BaseFileSystem
class LocalFileSystem(BaseFileSystem):
"""Local filesystem implementation."""
def __init__(self, base_path: str):
"""Initialize local filesystem.
Args:
base_path: Base directory path for file operations
"""
self.base_path = base_path
os.makedirs(base_path, exist_ok=True)
def _get_full_path(self, file_path: str) -> str:
"""Get the full path by joining with base path."""
return os.path.join(self.base_path, file_path)
async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
try:
full_path = self._get_full_path(file_path)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
async with aiofiles.open(full_path, "wb") as f:
await f.write(await content.read())
return True
except Exception:
return False
async def create_temp_file(self, file_path: str) -> bool:
try:
full_path = self._get_full_path(file_path)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
return True
except Exception:
return False
async def aupload_file(self, local_path: str, destination_path: str) -> bool:
try:
full_dest_path = self._get_full_path(destination_path)
os.makedirs(os.path.dirname(full_dest_path), exist_ok=True)
async with (
aiofiles.open(local_path, "rb") as src,
aiofiles.open(full_dest_path, "wb") as dst,
):
await dst.write(await src.read())
return True
except Exception:
return False
async def aget_signed_url(
self, file_path: str, expiration: int = 3600
) -> Optional[str]:
# For local filesystem, we'll create a temporary symlink with expiration
try:
full_path = self._get_full_path(file_path)
if not os.path.exists(full_path):
return None
# Create a temporary directory for symlinks
temp_dir = os.path.join(self.base_path, ".temp_links")
os.makedirs(temp_dir, exist_ok=True)
# Generate a unique temporary filename
temp_filename = (
f"{datetime.now().timestamp()}_{os.path.basename(file_path)}"
)
temp_path = os.path.join(temp_dir, temp_filename)
# Create symlink
os.symlink(full_path, temp_path)
# Schedule deletion after expiration
async def delete_after_expiration():
await asyncio.sleep(expiration)
try:
os.remove(temp_path)
except Exception:
pass
asyncio.create_task(delete_after_expiration())
return f"/files/{temp_filename}"
except Exception:
return None

View file

@ -0,0 +1,137 @@
import asyncio
from datetime import timedelta
from typing import Any, BinaryIO, Dict, Optional
from minio import Minio
from minio.error import S3Error
from .base import BaseFileSystem
class MinioFileSystem(BaseFileSystem):
"""MinIO implementation of the filesystem interface for OSS users.
Handles both internal (container-to-container) and external (browser) access:
- endpoint: Used for API operations (uploads, downloads from code)
- public_endpoint: Used for generating browser-accessible presigned URLs
Auto-detection logic:
1. If MINIO_PUBLIC_ENDPOINT env var is set, use it (for production/custom domains)
2. If endpoint is "minio:9000" (Docker internal), auto-use "localhost:9000" for browser
3. Otherwise, endpoint works for both (e.g., "localhost:9000" in local non-Docker setup)
"""
def __init__(
self,
endpoint: str = "localhost:9000",
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "voice-audio",
secure: bool = False,
public_endpoint: Optional[str] = None,
):
self.bucket_name = bucket_name
self.endpoint = endpoint
self.public_endpoint = public_endpoint or endpoint
self.secure = secure
self.access_key = access_key
self.secret_key = secret_key
# Client for internal operations (uploads, etc.)
self.client = Minio(
endpoint, access_key=access_key, secret_key=secret_key, secure=secure
)
# Ensure bucket exists (using internal client)
try:
if not self.client.bucket_exists(self.bucket_name):
self.client.make_bucket(self.bucket_name)
except Exception as e:
# Bucket might already exist or we might be in a restricted environment
pass
async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
try:
data = await content.read()
def _put():
self.client.put_object(
self.bucket_name,
file_path,
data=bytes(data),
length=len(data),
)
await asyncio.to_thread(_put)
return True
except S3Error:
return False
async def aupload_file(self, local_path: str, destination_path: str) -> bool:
try:
def _fput():
self.client.fput_object(self.bucket_name, destination_path, local_path)
await asyncio.to_thread(_fput)
return True
except S3Error:
return False
async def aget_signed_url(
self, file_path: str, expiration: int = 3600, force_inline: bool = False
) -> Optional[str]:
try:
def _presign():
response_headers = None
if force_inline and file_path.endswith(".txt"):
response_headers = {
"response-content-type": "text/plain",
"response-content-disposition": "inline",
}
# Generate URL with the main client
url = self.client.presigned_get_object(
self.bucket_name,
file_path,
expires=timedelta(seconds=expiration),
response_headers=response_headers,
)
# If we have different public endpoint, replace it in the URL
if self.endpoint != self.public_endpoint:
# Simple string replacement since presigned URLs are just strings
# Replace the endpoint in the URL
url = url.replace(
f"://{self.endpoint}/", f"://{self.public_endpoint}/"
)
url = url.replace(
f"Host={self.endpoint}", f"Host={self.public_endpoint}"
)
return url
url = await asyncio.to_thread(_presign)
return url
except S3Error:
return None
async def aget_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
"""Get MinIO object metadata."""
try:
def _stat():
return self.client.stat_object(self.bucket_name, file_path)
stat = await asyncio.to_thread(_stat)
return {
"size": stat.size,
"created_at": stat.last_modified,
"modified_at": stat.last_modified,
"etag": stat.etag.strip('"') if stat.etag else None,
"content_type": stat.content_type,
"storage_class": None, # MinIO doesn't have storage classes like S3
}
except S3Error:
return None

View file

@ -0,0 +1,99 @@
from typing import Any, BinaryIO, Dict, Optional
import aioboto3
from botocore.exceptions import ClientError
from .base import BaseFileSystem
class S3FileSystem(BaseFileSystem):
"""S3 implementation of the filesystem interface."""
def __init__(self, bucket_name: str, region_name: str = "us-east-1"):
"""Initialize S3 filesystem.
Args:
bucket_name: Name of the S3 bucket
region_name: AWS region name
"""
self.bucket_name = bucket_name
self.region_name = region_name
self.session = aioboto3.Session()
async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
try:
async with self.session.client(
"s3", region_name=self.region_name
) as s3_client:
await s3_client.put_object(
Bucket=self.bucket_name, Key=file_path, Body=await content.read()
)
return True
except ClientError:
return False
async def aupload_file(self, local_path: str, destination_path: str) -> bool:
try:
async with self.session.client(
"s3", region_name=self.region_name
) as s3_client:
await s3_client.upload_file(
local_path, self.bucket_name, destination_path
)
return True
except ClientError:
return False
async def aget_signed_url(
self, file_path: str, expiration: int = 3600, force_inline: bool = False
) -> Optional[str]:
"""Generate a presigned GET url for the given object.
For transcript text files we force the response headers so that the
browser renders the content **inline** instead of triggering a file
download. We do this by asking S3 to override the content type &
disposition on the response.
"""
try:
async with self.session.client(
"s3", region_name=self.region_name
) as s3_client:
params = {"Bucket": self.bucket_name, "Key": file_path}
# Make transcripts viewable inline in the browser when requested
if force_inline and file_path.endswith(".txt"):
params.update(
{
"ResponseContentType": "text/plain",
"ResponseContentDisposition": "inline",
}
)
url = await s3_client.generate_presigned_url(
"get_object",
Params=params,
ExpiresIn=expiration,
)
return url
except ClientError:
return None
async def aget_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
"""Get S3 object metadata."""
try:
async with self.session.client(
"s3", region_name=self.region_name
) as s3_client:
response = await s3_client.head_object(
Bucket=self.bucket_name, Key=file_path
)
return {
"size": response.get("ContentLength"),
"created_at": response.get("LastModified"),
"modified_at": response.get("LastModified"),
"etag": response.get("ETag", "").strip('"'),
"content_type": response.get("ContentType"),
"storage_class": response.get("StorageClass"),
}
except ClientError:
return None

View file

@ -0,0 +1,219 @@
# Gender Prediction Service
An internal service for predicting gender from first names using SSA (Social Security Administration) baby names data with GenderAPI fallback for uncertain predictions.
## Overview
This service provides gender prediction with:
- **Local model** built from 145 years of SSA data (1880-2024)
- **104,819 unique names** with confidence scores
- **Compressed storage** (2.21 MB model file)
- **GenderAPI fallback** for unknown or low-confidence names
## Data Source
The SSA baby names data is already downloaded in the `names/` directory from:
https://catalog.data.gov/dataset/baby-names-from-social-security-card-applications-national-data
## Building the Model
### Prerequisites
- Python 3.11+
- SSA data files in `names/` directory (already included)
### Build Steps
```bash
# Navigate to the gender service directory
cd dograh/api/services/gender/
# Run the model builder
python build_model.py
```
This will:
1. Process all 145 year files (yob1880.txt to yob2024.txt)
2. Aggregate name counts across all years
3. Calculate confidence scores based on gender ratios
4. Generate a compressed `model.txt` file (~2.21 MB)
### Model Output
The builder generates `model.txt` with:
- **Version**: Model version number
- **Metadata**: Build date, statistics, thresholds
- **Names**: Compressed array format `[male_count, female_count, confidence]`
Example output:
```
Model saved to: .../services/gender/model.txt
File size: 2.21 MB
Model statistics:
Total names: 104,819
High confidence names (≥0.85): 1,711
Confidence percentage: 1.6%
```
## Using the Service
### Basic Usage
```python
from services.gender.gender_service import GenderService
# Initialize the service
service = GenderService()
# Predict gender for a single name
result = await service.predict("John")
print(f"Gender: {result.gender}") # "male"
print(f"Confidence: {result.confidence}") # 0.996
print(f"Source: {result.source}") # "model"
# Get salutation for a name
greeting = await service.get_salutation("John")
print(f"Salutation: {greeting}") # "Mr."
greeting = await service.get_salutation("Mary")
print(f"Salutation: {greeting}") # "Ms."
greeting = await service.get_salutation("Unknown")
print(f"Salutation: {greeting}") # "Dear"
# Clean up
await service.close()
```
### Configuration Options
```python
# Custom configuration
service = GenderService(
model_path="custom/path/to/model.txt", # Default: ./model.txt
confidence_threshold=0.85, # Default: 0.85
gender_api_key="your-api-key", # Default: from GENDER_API_KEY env
gender_api_url="https://..." # Default: GenderAPI v2 endpoint
)
```
### Salutation Generation
```python
# Get appropriate salutation based on gender
salutation = await service.get_salutation("John") # "Mr."
salutation = await service.get_salutation("Mary") # "Ms."
salutation = await service.get_salutation("Unknown") # "Dear"
# Custom confidence threshold for salutation
salutation = await service.get_salutation(
"Taylor", # Ambiguous name
confidence_threshold=0.9 # Higher threshold
) # Returns "Dear" due to low confidence
# Salutation logic:
# - "Mr." for male with confidence >= threshold
# - "Ms." for female with confidence >= threshold
# - "Dear" for unknown gender or low confidence
```
### Batch Predictions
```python
# Predict multiple names at once
names = ["Alice", "Bob", "Charlie", "Diana"]
results = await service.batch_predict(names)
for name, result in zip(names, results):
print(f"{name}: {result.gender} ({result.confidence:.2f})")
```
### Response Format
```python
class GenderPrediction:
gender: "male" | "female" | "unknown" # Predicted gender
confidence: float # 0.0 to 1.0
source: "model" | "genderapi" # Prediction source
```
### Service Statistics
```python
# Get service statistics
stats = await service.get_stats()
print(f"Total names: {stats['model']['total_names']:,}")
print(f"High confidence: {stats['model']['high_confidence_names']:,}")
print(f"Cached names in Redis: {stats['cache']['cached_names']}")
print(f"Cache TTL: {stats['cache']['ttl_seconds']} seconds")
print(f"API enabled: {stats['api']['enabled']}")
```
### Cache Management
```python
# Clear Redis cache
await service.clear_cache()
```
## Environment Variables
```bash
# Required: Redis connection URL
export REDIS_URL=redis://localhost:6379
# Optional: Set GenderAPI key for fallback
export GENDERAPI_API_KEY=your-api-key-here
# Optional: Override confidence threshold (default: 0.85)
export CONFIDENCE_THRESHOLD=0.85
```
## How It Works
1. **Name normalization**: Converts to lowercase, strips whitespace
2. **Local model check**: Looks up name in pre-built model
3. **Confidence evaluation**: If confidence ≥ 0.85, returns local prediction
4. **Redis cache check**: Checks Redis for previously fetched API results
5. **API fallback**: For unknown/low-confidence names, calls GenderAPI
6. **Redis caching**: Stores API responses in Redis with 30-day TTL
## Testing
Run the test suite to verify the service:
```bash
python test_service.py
```
This tests:
- High-confidence predictions
- Ambiguous names
- Edge cases (empty strings, special characters)
- International names (with API key)
- Batch predictions
## Model Updates
The model should be rebuilt annually when new SSA data is released:
1. Download new year file (e.g., yob2025.txt) to `names/` directory
2. Run `python build_model.py` to rebuild
3. Test with `python test_service.py`
4. Commit the updated `model.txt`
## Performance
- **Model size**: 2.21 MB (compressed JSON)
- **Load time**: < 100ms
- **Prediction time**: < 1ms (local), < 5ms (Redis cache), < 500ms (API)
- **Memory usage**: ~10 MB for model in memory
- **Cache**: Redis-based with 30-day TTL
- **Scalability**: Shared cache across all service instances
## Limitations
- Based on US SSA data (may not work well for non-US names)
- Historical bias in older data
- Unisex names have lower confidence
- Requires GenderAPI key for comprehensive coverage

View file

View file

@ -0,0 +1,164 @@
"""
Build gender prediction model from SSA baby names data.
Generates a compressed JSON model file.
"""
import json
from collections import defaultdict
from datetime import datetime
from math import log10
from pathlib import Path
from api.services.gender.constants import CONFIDENCE_THRESHOLD
def calculate_confidence(male_count: int, female_count: int) -> float:
"""Calculate confidence score for gender prediction."""
total = male_count + female_count
# Minimum sample size requirement
if total < 100:
return 0.0
# Calculate gender ratio
ratio = max(male_count, female_count) / total
# Apply logarithmic scaling for sample size
# Max confidence at 100,000 occurrences
sample_weight = min(1.0, log10(total) / 5)
# Final confidence
return round(ratio * sample_weight, 4)
def build_model():
"""Build gender prediction model from SSA data."""
# Initialize counters
name_stats = defaultdict(lambda: {"M": 0, "F": 0})
# Get the path to names directory
names_dir = Path(__file__).parent / "names"
if not names_dir.exists():
raise FileNotFoundError(f"Names directory not found: {names_dir}")
file_count = 0
# Process all year files
for year_file in sorted(names_dir.glob("yob*.txt")):
file_count += 1
with open(year_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split(",")
if len(parts) != 3:
continue
name, gender, count = parts
name = name.lower()
count = int(count)
# Simple aggregation - no year weighting
name_stats[name][gender] += count
print(f"Processed {file_count} year files")
# Build compressed model format
names_data = {}
high_confidence_count = 0
for name, stats in name_stats.items():
male_count = stats["M"]
female_count = stats["F"]
if male_count == 0 and female_count == 0:
continue
# Calculate confidence
confidence = calculate_confidence(male_count, female_count)
# Store as compact array: [male_count, female_count, confidence]
names_data[name] = [male_count, female_count, confidence]
if confidence >= CONFIDENCE_THRESHOLD:
high_confidence_count += 1
# Create final model structure with metadata
model = {
"version": "1.0",
"metadata": {
"confidence_threshold": CONFIDENCE_THRESHOLD,
"total_names": len(names_data),
"high_confidence_names": high_confidence_count,
"build_date": datetime.now().isoformat(),
"source_files": file_count,
},
"names": names_data,
}
return model
def save_model(model, output_path="model.txt"):
"""Save model to compressed JSON file."""
output_file = Path(__file__).parent / output_path
# Write compressed JSON (no indentation for smaller size)
with open(output_file, "w", encoding="utf-8") as f:
json.dump(model, f, separators=(",", ":"))
# Calculate file size
file_size_mb = output_file.stat().st_size / (1024 * 1024)
print(f"\nModel saved to: {output_file}")
print(f"File size: {file_size_mb:.2f} MB")
print(f"\nModel statistics:")
print(f" Total names: {model['metadata']['total_names']:,}")
print(
f" High confidence names (≥{CONFIDENCE_THRESHOLD}): {model['metadata']['high_confidence_names']:,}"
)
print(
f" Confidence percentage: {model['metadata']['high_confidence_names'] / model['metadata']['total_names'] * 100:.1f}%"
)
def test_model(model):
"""Test model with sample names."""
print("\nSample predictions:")
test_names = [
"john",
"mary",
"alex",
"taylor",
"michael",
"sarah",
"jordan",
"casey",
]
for name in test_names:
if name in model["names"]:
male_count, female_count, confidence = model["names"][name]
gender = "male" if male_count > female_count else "female"
print(
f" {name.capitalize():10} -> {gender:6} (confidence: {confidence:.3f}, M:{male_count:,} F:{female_count:,})"
)
else:
print(f" {name.capitalize():10} -> not found in dataset")
if __name__ == "__main__":
print("Building gender prediction model from SSA data...")
print("=" * 50)
try:
model = build_model()
save_model(model)
test_model(model)
print("\n✓ Model build complete!")
except Exception as e:
print(f"\n✗ Error building model: {e}")
raise

View file

@ -0,0 +1,10 @@
"""
Hyperparameters and configuration for gender prediction service.
"""
# Confidence threshold for using local model predictions
CONFIDENCE_THRESHOLD = 0.85
# Redis cache configuration
REDIS_CACHE_TTL = 86400 * 30 # 30 days in seconds
REDIS_KEY_PREFIX = "genderservice:"

View file

@ -0,0 +1,391 @@
"""
Gender prediction service with local model and GenderAPI fallback.
Internal service for use within Dograh platform.
"""
import json
import os
import time
from pathlib import Path
from typing import Literal, Optional
import httpx
import redis.asyncio as aioredis
from loguru import logger
from pydantic import BaseModel, Field
from api.constants import REDIS_URL
from api.services.gender.constants import (
CONFIDENCE_THRESHOLD,
REDIS_CACHE_TTL,
REDIS_KEY_PREFIX,
)
class GenderPrediction(BaseModel):
"""Gender prediction result."""
gender: Literal["male", "female", "unknown"] = Field(
..., description="Predicted gender"
)
confidence: float = Field(..., ge=0, le=1, description="Confidence score (0-1)")
source: Literal["model", "genderapi"] = Field(
..., description="Source of prediction"
)
class GenderService:
"""
Internal service for predicting gender from names.
Uses local SSA-based model with GenderAPI fallback.
"""
def __init__(
self,
model_path: Optional[str] = None,
confidence_threshold: float = CONFIDENCE_THRESHOLD,
gender_api_key: Optional[str] = None,
gender_api_url: str = "https://gender-api.com/v2/gender",
):
"""
Initialize the gender service.
Args:
model_path: Path to the model file (default: ./model.txt)
confidence_threshold: Minimum confidence to use local model
gender_api_key: API key for GenderAPI (falls back to env var)
gender_api_url: GenderAPI endpoint URL
"""
self.confidence_threshold = confidence_threshold
self.gender_api_key = gender_api_key or os.getenv("GENDERAPI_API_KEY")
self.gender_api_url = gender_api_url
# Load model
if model_path is None:
model_path = Path(__file__).parent / "model.txt"
else:
model_path = Path(model_path)
self.model = self._load_model(model_path)
self._http_client = None
self._redis_client: Optional[aioredis.Redis] = None
def _load_model(self, model_path: Path) -> dict:
"""Load the compressed gender prediction model."""
if not model_path.exists():
logger.warning(f"Warning: Model file not found at {model_path}")
return {"metadata": {}, "names": {}}
try:
with open(model_path, "r", encoding="utf-8") as f:
model = json.load(f)
# Validate model structure
if "names" not in model or "metadata" not in model:
raise ValueError("Invalid model format")
logger.debug(
f"Loaded gender prediction model with {model['metadata'].get('total_names', 0):,} names"
)
return model
except Exception as e:
logger.error(f"Error loading gender prediction model: {e}")
return {"metadata": {}, "names": {}}
@property
def http_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client for API calls."""
if self._http_client is None:
self._http_client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_keepalive_connections=5),
)
return self._http_client
async def _get_redis(self) -> aioredis.Redis:
"""Get or create Redis connection."""
if self._redis_client is None:
self._redis_client = await aioredis.from_url(
REDIS_URL, decode_responses=True
)
return self._redis_client
async def predict(
self, first_name: str, last_name: Optional[str] = None
) -> GenderPrediction:
"""
Predict gender for a given name.
Args:
first_name: First name to predict gender for
last_name: Last name (optional, not used in v1.0)
Returns:
GenderPrediction with gender, confidence, and source
"""
if not first_name:
return GenderPrediction(gender="unknown", confidence=0.0, source="model")
# Normalize name for lookup
normalized_name = first_name.lower().strip()
# Step 1: Check local model
if normalized_name in self.model["names"]:
male_count, female_count, confidence = self.model["names"][normalized_name]
# Use local model if confidence meets threshold
if confidence >= self.confidence_threshold:
gender = "male" if male_count > female_count else "female"
logger.debug(
f"GenderService: Local Prediction {first_name} - {gender} with confidence: {confidence}"
)
return GenderPrediction(
gender=gender, confidence=confidence, source="model"
)
else:
logger.debug(
f"GenderService: Low Confidence Local Prediction {first_name} - with confidence: {confidence}"
)
# Step 2: Check Redis cache for previous API responses
try:
redis_client = await self._get_redis()
cache_key = f"{REDIS_KEY_PREFIX}{normalized_name}"
cached_data = await redis_client.get(cache_key)
if cached_data:
cached_result = json.loads(cached_data)
logger.debug(
f"GenderService: Redis Cache Hit {first_name} - {cached_result['gender']} with confidence: {cached_result['confidence']}"
)
return GenderPrediction(**cached_result, source="genderapi")
except Exception as e:
logger.warning(f"Redis cache check failed: {e}")
# Step 3: Fallback to GenderAPI
if self.gender_api_key:
try:
result = await self._call_gender_api(first_name)
# Cache the result in Redis
try:
redis_client = await self._get_redis()
cache_key = f"{REDIS_KEY_PREFIX}{normalized_name}"
cache_data = json.dumps(
{"gender": result.gender, "confidence": result.confidence}
)
await redis_client.setex(cache_key, REDIS_CACHE_TTL, cache_data)
except Exception as e:
logger.warning(f"Failed to cache result in Redis: {e}")
# No need for additional debug log here as _call_gender_api logs with timing
return result
except Exception as e:
# Error already logged in _call_gender_api with timing
pass
# Step 4: Return best guess from model or unknown
if normalized_name in self.model["names"]:
male_count, female_count, confidence = self.model["names"][normalized_name]
gender = "male" if male_count > female_count else "female"
return GenderPrediction(
gender=gender, confidence=confidence, source="model"
)
# Final fallback: unknown
return GenderPrediction(gender="unknown", confidence=0.0, source="model")
async def _call_gender_api(self, first_name: str) -> GenderPrediction:
"""
Call GenderAPI for gender prediction.
Args:
first_name: First name to predict
Returns:
GenderPrediction from API response
"""
headers = {
"Authorization": f"Bearer {self.gender_api_key}",
"Content-Type": "application/json",
}
payload = {"first_name": first_name}
try:
# Track API call timing
start_time = time.perf_counter()
response = await self.http_client.post(
self.gender_api_url, headers=headers, json=payload
)
response.raise_for_status()
# Calculate elapsed time
elapsed_time = (
time.perf_counter() - start_time
) * 1000 # Convert to milliseconds
data = response.json()
# Map GenderAPI response format
gender = data.get("gender", "unknown").lower()
if gender not in ["male", "female"]:
gender = "unknown"
# GenderAPI returns accuracy as probability
confidence = data.get("probability", 0)
# Log the API call with timing
logger.info(
f"GenderAPI call for '{first_name}': {gender} with confidence {confidence:.2f} "
f"(took {elapsed_time:.2f}ms)"
)
return GenderPrediction(
gender=gender, confidence=confidence, source="genderapi"
)
except httpx.HTTPStatusError as e:
# Log error with timing if we got a response
elapsed_time = (
(time.perf_counter() - start_time) * 1000
if "start_time" in locals()
else 0
)
logger.error(
f"GenderAPI HTTP error for '{first_name}': {e.response.status_code} "
f"(took {elapsed_time:.2f}ms)"
)
if e.response.status_code == 401:
raise ValueError("Invalid GenderAPI key")
elif e.response.status_code == 429:
raise ValueError("GenderAPI rate limit exceeded")
else:
raise ValueError(f"GenderAPI HTTP error: {e.response.status_code}")
except httpx.TimeoutException as e:
elapsed_time = (
(time.perf_counter() - start_time) * 1000
if "start_time" in locals()
else 0
)
logger.error(
f"GenderAPI timeout for '{first_name}' after {elapsed_time:.2f}ms"
)
raise ValueError(f"GenderAPI request timed out")
except Exception as e:
elapsed_time = (
(time.perf_counter() - start_time) * 1000
if "start_time" in locals()
else 0
)
logger.error(
f"GenderAPI unexpected error for '{first_name}': {str(e)} "
f"(took {elapsed_time:.2f}ms)"
)
raise
async def get_salutation(
self,
first_name: str,
last_name: Optional[str] = None,
confidence_threshold: Optional[float] = None,
) -> str:
"""
Get appropriate salutation based on gender prediction.
Args:
first_name: First name to predict gender for
last_name: Last name (optional, not used in v1.0)
confidence_threshold: Optional override for confidence threshold
Returns:
"Mr." for male, "Ms." for female, "Dear" for unknown/low confidence
"""
if not first_name:
return "Dear"
# Get gender prediction
prediction = await self.predict(first_name, last_name)
# Return salutation based on gender and confidence
if prediction.gender == "unknown":
return "Dear"
elif prediction.gender == "male":
return "Mr."
else: # female
return "Ms."
async def batch_predict(self, names: list[str]) -> list[GenderPrediction]:
"""
Predict gender for multiple names.
Args:
names: List of first names
Returns:
List of GenderPrediction results
"""
results = []
for name in names:
result = await self.predict(name)
results.append(result)
return results
async def close(self):
"""Close HTTP and Redis clients and cleanup resources."""
if self._http_client:
await self._http_client.aclose()
self._http_client = None
if self._redis_client:
await self._redis_client.close()
self._redis_client = None
async def get_stats(self) -> dict:
"""Get statistics about the service and model."""
metadata = self.model.get("metadata", {})
# Get Redis cache stats
cache_stats = {}
try:
redis_client = await self._get_redis()
# Count keys matching our prefix pattern
keys = await redis_client.keys(f"{REDIS_KEY_PREFIX}*")
cache_stats = {
"cached_names": len(keys),
"cache_type": "redis",
"ttl_seconds": REDIS_CACHE_TTL,
}
except Exception as e:
logger.warning(f"Failed to get Redis stats: {e}")
cache_stats = {"cached_names": 0, "cache_type": "redis", "error": str(e)}
return {
"model": {
"version": self.model.get("version", "unknown"),
"total_names": metadata.get("total_names", 0),
"high_confidence_names": metadata.get("high_confidence_names", 0),
"confidence_threshold": self.confidence_threshold,
"build_date": metadata.get("build_date", "unknown"),
},
"cache": cache_stats,
"api": {"enabled": bool(self.gender_api_key), "url": self.gender_api_url},
}
async def clear_cache(self):
"""Clear the Redis cache for gender predictions."""
try:
redis_client = await self._get_redis()
keys = await redis_client.keys(f"{REDIS_KEY_PREFIX}*")
if keys:
await redis_client.delete(*keys)
logger.info(f"Cleared {len(keys)} entries from Redis cache")
else:
logger.debug("No cache entries to clear")
except Exception as e:
logger.error(f"Failed to clear Redis cache: {e}")

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,248 @@
"""
Test script for the gender prediction service.
"""
import asyncio
from api.services.gender.gender_service import GenderService
async def test_local_model():
"""Test predictions using local model only."""
print("\n" + "=" * 60)
print("Testing Local Model Predictions")
print("=" * 60)
# Initialize service without API key (local model only)
service = GenderService()
# Test high-confidence names
high_confidence_names = [
"John",
"Mary",
"Michael",
"Sarah",
"Robert",
"Lisa",
"William",
"Jennifer",
"David",
"Patricia",
]
print("\nHigh-confidence predictions (should use local model):")
print("-" * 50)
for name in high_confidence_names:
result = await service.predict(name)
print(
f" {name:12} -> {result.gender:6} (conf: {result.confidence:.3f}, source: {result.source})"
)
# Test ambiguous names
ambiguous_names = [
"Taylor",
"Jordan",
"Casey",
"Alex",
"Morgan",
"Blake",
"Avery",
"Riley",
"Quinn",
"Sage",
]
print("\nAmbiguous names (lower confidence):")
print("-" * 50)
for name in ambiguous_names:
result = await service.predict(name)
status = "" if result.source == "model" and result.confidence >= 0.85 else ""
print(
f" {status} {name:12} -> {result.gender:6} (conf: {result.confidence:.3f}, source: {result.source})"
)
# Test unknown names
unknown_names = ["Xyzabc", "Qwerty", "Abcdef"]
print("\nUnknown names (not in dataset):")
print("-" * 50)
for name in unknown_names:
result = await service.predict(name)
print(
f" {name:12} -> {result.gender:7} (conf: {result.confidence:.3f}, source: {result.source})"
)
# Get service statistics
stats = await service.get_stats()
print("\nService Statistics:")
print("-" * 50)
print(f" Model version: {stats['model']['version']}")
print(f" Total names: {stats['model']['total_names']:,}")
print(f" High confidence: {stats['model']['high_confidence_names']:,}")
print(f" Threshold: {stats['model']['confidence_threshold']}")
print(f" Cache type: {stats['cache'].get('cache_type', 'unknown')}")
print(f" Cached names: {stats['cache'].get('cached_names', 0)}")
print(f" API enabled: {stats['api']['enabled']}")
await service.close()
async def test_with_api():
"""Test with GenderAPI integration (requires API key)."""
print("\n" + "=" * 60)
print("Testing with GenderAPI Integration")
print("=" * 60)
# Check if API key is available
import os
api_key = os.getenv("GENDER_API_KEY")
if not api_key:
print("\n⚠️ No GENDER_API_KEY found in environment")
print(" Skipping API integration tests")
print(" To test API fallback, set: export GENDER_API_KEY=your_key")
return
service = GenderService(gender_api_key=api_key)
# Test names that might need API fallback
test_names = [
"Priya", # Indian name, might not be in SSA data
"Hiroshi", # Japanese name
"Giovanni", # Italian name
"Olga", # Russian name
"Chen", # Chinese name
]
print("\nInternational names (may use API fallback):")
print("-" * 50)
for name in test_names:
result = await service.predict(name)
print(
f" {name:12} -> {result.gender:6} (conf: {result.confidence:.3f}, source: {result.source})"
)
# Test batch prediction
print("\nBatch prediction test:")
print("-" * 50)
batch_names = ["Alice", "Bob", "Charlie", "Diana", "Eve"]
results = await service.batch_predict(batch_names)
for name, result in zip(batch_names, results):
print(f" {name:12} -> {result.gender:6} (conf: {result.confidence:.3f})")
await service.close()
async def test_salutation():
"""Test salutation generation."""
print("\n" + "=" * 60)
print("Testing Salutation Generation")
print("=" * 60)
service = GenderService()
# Test high-confidence names
test_cases = [
("John", "Mr."),
("Mary", "Ms."),
("Michael", "Mr."),
("Sarah", "Ms."),
("Robert", "Mr."),
("Jennifer", "Ms."),
]
print("\nHigh-confidence salutations:")
print("-" * 50)
for name, expected in test_cases:
salutation = await service.get_salutation(name)
status = "" if salutation == expected else ""
print(f" {status} {name:12} -> {salutation:4} (expected: {expected})")
# Test ambiguous/unknown names
ambiguous_cases = [
"Xyzabc", # Unknown name
"Qwerty", # Unknown name
"", # Empty string
" ", # Whitespace
"123", # Numbers
]
print("\nUnknown/ambiguous names (should return 'Dear'):")
print("-" * 50)
for name in ambiguous_cases:
salutation = await service.get_salutation(name)
display_name = f"'{name}'" if name else "(empty)"
status = "" if salutation == "Dear" else ""
print(f" {status} {display_name:12} -> {salutation}")
# Test with custom confidence threshold
print("\nCustom confidence threshold test:")
print("-" * 50)
# Taylor has confidence ~0.744, should be "Dear" with high threshold
salutation_default = await service.get_salutation("Taylor")
salutation_high = await service.get_salutation("Taylor", confidence_threshold=0.9)
print(f" Taylor (default threshold): {salutation_default}")
print(f" Taylor (0.9 threshold): {salutation_high}")
await service.close()
async def test_edge_cases():
"""Test edge cases and error handling."""
print("\n" + "=" * 60)
print("Testing Edge Cases")
print("=" * 60)
service = GenderService()
# Test empty/invalid inputs
edge_cases = [
"", # Empty string
" ", # Whitespace
"123", # Numbers
"John-Paul", # Hyphenated
"Mary Ann", # Space in name
"O'Brien", # Apostrophe
"José", # Accented
]
print("\nEdge case inputs:")
print("-" * 50)
for name in edge_cases:
result = await service.predict(name)
display_name = f"'{name}'" if name else "(empty)"
print(
f" {display_name:12} -> {result.gender:7} (conf: {result.confidence:.3f})"
)
# Test case insensitivity
print("\nCase insensitivity test:")
print("-" * 50)
case_variants = ["john", "JOHN", "John", "JoHn"]
for name in case_variants:
result = await service.predict(name)
print(f" {name:12} -> {result.gender:6} (conf: {result.confidence:.3f})")
await service.close()
async def main():
"""Run all tests."""
print("\n" + "=" * 60)
print("Gender Prediction Service Test Suite")
print("=" * 60)
# Run tests
await test_local_model()
await test_salutation()
await test_edge_cases()
await test_with_api()
print("\n" + "=" * 60)
print("✓ All tests completed!")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

View file

View file

@ -0,0 +1,253 @@
import hashlib
import json
import os
from typing import Any, Dict
import httpx
from fastapi import HTTPException
from loguru import logger
from pydantic import BaseModel
from api.db import db_client
NANGO_ALLOWED_INTEGRATIONS = [
i.strip() for i in os.environ.get("NANGO_ALLOWED_INTEGRATIONS", "slack").split(",")
]
class NangoWebhookRequest(BaseModel):
type: str
connectionId: str
providerConfigKey: str
authMode: str
provider: str
environment: str
operation: str
endUser: dict # Contains endUserId and organizationId
success: bool
class NangoService:
def __init__(self):
self.base_url = "https://api.nango.dev"
self.secret_key = os.getenv("NANGO_API_KEY")
def _verify_webhook_signature(
self, request_body: str, signature: str = None
) -> bool:
"""
Verify the webhook signature using SHA256 hash.
Args:
request_body: The raw request body as string
signature: The signature from request headers (optional for now)
Returns:
True if signature is valid
"""
expected_signature = self.secret_key + request_body
expected_hash = hashlib.sha256(expected_signature.encode("utf-8")).hexdigest()
return expected_hash == signature
async def create_session(
self, user_id: str, organization_id: int
) -> Dict[str, Any]:
"""
Create a Nango session for the given user and organization.
Args:
user_id: The end user ID
organization_id: The organization ID
Returns:
Response from Nango API
"""
if not self.secret_key:
raise ValueError("NANGO_SECRET_KEY environment variable is not set")
headers = {
"Authorization": f"Bearer {self.secret_key}",
"Content-Type": "application/json",
}
payload = {
"end_user": {"id": user_id},
"organization": {"id": str(organization_id)},
"allowed_integrations": NANGO_ALLOWED_INTEGRATIONS,
}
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/connect/sessions", headers=headers, json=payload
)
if response.status_code != 201:
raise httpx.HTTPStatusError(
f"Nango API error: {response.status_code}",
request=response.request,
response=response,
)
return response.json()
async def process_webhook(
self, raw_body: bytes, signature: str = None
) -> Dict[str, str]:
"""
Process incoming Nango webhook request.
Args:
raw_body: The raw request body as bytes
signature: Optional signature from request headers
Returns:
Dict with status and message
"""
# Decode and parse the request body
try:
body_text = raw_body.decode("utf-8")
webhook_json = json.loads(body_text) if body_text else {}
logger.debug(f"received webhook from nango: {webhook_json}")
except json.JSONDecodeError as e:
logger.error(f"JSON decode error: {e} body_text: {body_text}")
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
# Verify webhook signature
if not self._verify_webhook_signature(body_text, signature):
raise HTTPException(status_code=401, detail="Invalid webhook signature")
# Parse webhook data
try:
webhook_data = NangoWebhookRequest(**webhook_json)
except Exception as e:
logger.error(f"Failed to parse webhook data: {e}")
raise HTTPException(
status_code=400, detail=f"Invalid webhook format: {str(e)}"
)
# Extract user and organization IDs from the webhook payload
end_user = webhook_data.endUser
if (
not end_user
or "endUserId" not in end_user
or "organizationId" not in end_user
):
raise HTTPException(
status_code=400, detail="Missing endUser information in webhook payload"
)
user_id = int(end_user["endUserId"])
organization_id = int(end_user["organizationId"])
# Use the connectionId as the integration_id since it's unique per integration
integration_id = webhook_data.connectionId
# Initialize connection_details
connection_details = {}
# Fetch connection details if type is auth and provider is slack
if webhook_data.type == "auth":
connection_details = await self._fetch_connection_details(
integration_id, webhook_data.provider
)
# Create the integration in the database
integration = await db_client.create_integration(
integration_id=integration_id,
organisation_id=organization_id,
provider=webhook_data.provider,
created_by=user_id,
is_active=True,
connection_details=connection_details,
)
return {
"status": "success",
"message": f"Integration created successfully with ID: {integration.id}",
}
async def _fetch_connection_details(
self, connection_id: str, provider_key: str
) -> Dict[str, Any]:
"""
Fetch connection details from Nango API for a given connection ID.
Args:
connection_id: The connection ID from the webhook
Returns:
Connection details as a dictionary
"""
headers = {
"Authorization": f"Bearer {self.secret_key}",
"Content-Type": "application/json",
}
url = f"{self.base_url}/connection/{connection_id}/?provider_config_key={provider_key}"
async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code != 200:
logger.error(
f"Failed to fetch connection details: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Nango API error while fetching connection: {response.status_code}",
request=response.request,
response=response,
)
connection_details = response.json()
return connection_details
except httpx.HTTPError as e:
logger.error(f"HTTP error while fetching connection details: {e}")
# Return empty dict if API call fails, but log the error
return {}
async def get_access_token(
self, connection_id: str, provider_config_key: str
) -> Dict[str, Any]:
"""
Get the latest access token for a connection from Nango.
Args:
connection_id: The connection ID
provider_config_key: The provider config key (e.g., 'google-sheet')
Returns:
Dict containing access token and other connection details
"""
headers = {
"Authorization": f"Bearer {self.secret_key}",
"Content-Type": "application/json",
}
url = f"{self.base_url}/connection/{connection_id}?provider_config_key={provider_config_key}"
async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code != 200:
logger.error(
f"Failed to get access token: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Nango API error: {response.status_code}",
request=response.request,
response=response,
)
return response.json()
except httpx.HTTPError as e:
logger.error(f"HTTP error while getting access token: {e}")
raise
# Create a singleton instance
nango_service = NangoService()

View file

@ -0,0 +1,3 @@
from .orchestrator import LoopTalkTestOrchestrator
__all__ = ["LoopTalkTestOrchestrator"]

View file

@ -0,0 +1,220 @@
"""
Audio streaming processor for LoopTalk real-time audio monitoring.
This processor captures audio from both actor and adversary agents and streams
it to connected WebRTC clients for real-time monitoring.
"""
import asyncio
from typing import Dict, Set
from loguru import logger
from pipecat.audio.utils import mix_audio
from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
OutputAudioRawFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class LoopTalkAudioStreamer(FrameProcessor):
"""
Processes audio frames from LoopTalk conversations and streams to WebRTC clients.
This processor sits in the pipeline and captures all audio frames, then
forwards them to connected WebRTC clients for real-time monitoring.
"""
def __init__(
self,
test_session_id: str,
role: str, # "actor" or "adversary"
**kwargs,
):
super().__init__(**kwargs)
self._test_session_id = test_session_id
self._role = role
self._listeners: Set[asyncio.Queue] = set()
self._sample_rate = 16000 # Default sample rate
self._num_channels = 1
def add_listener(self, queue: asyncio.Queue):
"""Add a listener queue for streaming audio."""
self._listeners.add(queue)
def remove_listener(self, queue: asyncio.Queue):
"""Remove a listener queue."""
self._listeners.discard(queue)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process audio frames and stream to listeners."""
await super().process_frame(frame, direction)
# Capture both input and output audio
if isinstance(frame, (InputAudioRawFrame, OutputAudioRawFrame)):
# Extract audio data
audio_data = frame.audio
sample_rate = frame.sample_rate
num_channels = frame.num_channels
# Store sample rate for reference
if sample_rate:
self._sample_rate = sample_rate
if num_channels:
self._num_channels = num_channels
# Stream to all listeners
if self._listeners and audio_data:
# Create a packet with metadata
packet = {
"test_session_id": self._test_session_id,
"role": self._role,
"audio": audio_data,
"sample_rate": sample_rate,
"num_channels": num_channels,
"is_input": isinstance(frame, InputAudioRawFrame),
}
# Send to all listeners without blocking
for queue in list(self._listeners):
try:
queue.put_nowait(packet)
except asyncio.QueueFull:
logger.warning(
f"Audio queue full for session {self._test_session_id}"
)
except Exception as e:
logger.error(f"Error streaming audio: {e}")
self._listeners.discard(queue)
elif self._listeners and not audio_data:
logger.warning(
f"Audio streamer {self._role} received frame with no audio data"
)
elif audio_data and not self._listeners:
# This is expected early in the session before WebSocket connects
pass
# Always forward the frame
await self.push_frame(frame, direction)
class LoopTalkAudioMixer:
"""
Mixes audio from actor and adversary streams for combined playback.
This class manages the mixing of two audio streams (actor and adversary)
to create a combined audio stream for monitoring.
"""
def __init__(self, test_session_id: str):
self._test_session_id = test_session_id
self._actor_buffer = bytearray()
self._adversary_buffer = bytearray()
self._listeners: Set[asyncio.Queue] = set()
self._sample_rate = 16000
self._num_channels = 1
self._buffer_size = 8000 # 0.5 seconds at 16kHz
def add_listener(self, queue: asyncio.Queue):
"""Add a listener for mixed audio."""
self._listeners.add(queue)
def remove_listener(self, queue: asyncio.Queue):
"""Remove a listener."""
self._listeners.discard(queue)
async def add_audio(
self, role: str, audio_data: bytes, sample_rate: int, num_channels: int
):
"""Add audio data from actor or adversary."""
if role == "actor":
self._actor_buffer.extend(audio_data)
elif role == "adversary":
self._adversary_buffer.extend(audio_data)
# Update audio parameters
self._sample_rate = sample_rate
self._num_channels = num_channels
# Check if we have enough data to mix
await self._check_and_mix()
async def _check_and_mix(self):
"""Check buffers and mix audio when enough data is available."""
# Mix when we have at least buffer_size in both buffers
while (
len(self._actor_buffer) >= self._buffer_size
and len(self._adversary_buffer) >= self._buffer_size
):
# Extract chunks
actor_chunk = bytes(self._actor_buffer[: self._buffer_size])
adversary_chunk = bytes(self._adversary_buffer[: self._buffer_size])
# Remove from buffers
del self._actor_buffer[: self._buffer_size]
del self._adversary_buffer[: self._buffer_size]
# Mix audio
mixed_audio = mix_audio(actor_chunk, adversary_chunk)
# Stream to listeners
if self._listeners and mixed_audio:
packet = {
"test_session_id": self._test_session_id,
"role": "mixed",
"audio": mixed_audio,
"sample_rate": self._sample_rate,
"num_channels": self._num_channels,
"is_input": False,
}
for queue in list(self._listeners):
try:
queue.put_nowait(packet)
except asyncio.QueueFull:
logger.warning(
f"Mixed audio queue full for session {self._test_session_id}"
)
except Exception as e:
logger.error(f"Error streaming mixed audio: {e}")
self._listeners.discard(queue)
# Global registry for audio streamers and mixers
_audio_streamers: Dict[str, Dict[str, LoopTalkAudioStreamer]] = {}
_audio_mixers: Dict[str, LoopTalkAudioMixer] = {}
def get_or_create_audio_streamer(
test_session_id: str, role: str
) -> LoopTalkAudioStreamer:
"""Get or create an audio streamer for a test session and role."""
if test_session_id not in _audio_streamers:
_audio_streamers[test_session_id] = {}
if role not in _audio_streamers[test_session_id]:
_audio_streamers[test_session_id][role] = LoopTalkAudioStreamer(
test_session_id=test_session_id, role=role
)
return _audio_streamers[test_session_id][role]
def get_or_create_audio_mixer(test_session_id: str) -> LoopTalkAudioMixer:
"""Get or create an audio mixer for a test session."""
if test_session_id not in _audio_mixers:
_audio_mixers[test_session_id] = LoopTalkAudioMixer(test_session_id)
return _audio_mixers[test_session_id]
def cleanup_audio_streamers(test_session_id: str):
"""Clean up audio streamers and mixers for a test session."""
if test_session_id in _audio_streamers:
del _audio_streamers[test_session_id]
if test_session_id in _audio_mixers:
del _audio_mixers[test_session_id]
logger.info(f"Cleaned up audio streamers for test session {test_session_id}")

View file

@ -0,0 +1 @@
"""Core modules for LoopTalk orchestration."""

View file

@ -0,0 +1,167 @@
"""Pipeline building logic for LoopTalk agents."""
from typing import Any, Dict
from loguru import logger
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.filters.stt_mute_filter import (
STTMuteConfig,
STTMuteFilter,
STTMuteStrategy,
)
from pipecat.transports import InternalTransport
from api.db.db_client import DBClient
from api.services.looptalk.audio_streamer import get_or_create_audio_streamer
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.pipeline_builder import (
create_pipeline_components,
create_pipeline_task,
)
from api.services.pipecat.pipeline_engine_callbacks_processor import (
PipelineEngineCallbacksProcessor,
)
from api.services.pipecat.service_factory import (
create_llm_service,
create_stt_service,
create_tts_service,
)
from api.services.workflow.dto import ReactFlowDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
class LoopTalkPipelineBuilder:
"""Builds pipelines for LoopTalk agents."""
def __init__(self, db_client: DBClient):
"""Initialize the pipeline builder.
Args:
db_client: Database client for fetching user configurations
"""
self.db_client = db_client
async def create_agent_pipeline(
self,
transport: InternalTransport,
workflow: Any,
test_session_id: int,
agent_id: str,
role: str,
) -> Dict[str, Any]:
"""Create a pipeline for an agent (actor or adversary).
Args:
transport: Internal transport for the agent
workflow: Workflow model from database
test_session_id: ID of the test session
agent_id: Unique identifier for the agent
role: Either "actor" or "adversary"
Returns:
Dictionary containing pipeline task, engine, and components
"""
# Get user configuration from database
user_config = await self.db_client.get_user_configurations(workflow.user_id)
# Create pipeline components
audio_config = AudioConfig(
transport_in_sample_rate=16000,
transport_out_sample_rate=16000,
vad_sample_rate=16000,
pipeline_sample_rate=16000,
)
# Create services
stt = create_stt_service(user_config)
llm = create_llm_service(user_config)
tts = create_tts_service(user_config, audio_config)
logger.debug(f"Created services for {role}: STT={stt}, LLM={llm}, TTS={tts}")
audio_buffer, audio_synchronizer, transcript, context = (
create_pipeline_components(audio_config)
)
context_aggregator = llm.create_context_aggregator(context)
# Get workflow graph
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
)
# Create engine
engine = PipecatEngine(
task=None, # Will be set after creating the task
llm=llm,
context=context,
tts=tts,
workflow=workflow_graph,
call_context_vars={},
audio_buffer=audio_buffer,
workflow_run_id=None, # LoopTalk doesn't have workflow runs
)
# Create STT mute filter
stt_mute_filter = STTMuteFilter(
config=STTMuteConfig(
strategies={STTMuteStrategy.FIRST_SPEECH},
)
)
# Create pipeline engine callback processor
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
max_call_duration_seconds=300,
max_duration_end_task_callback=engine.create_max_duration_callback(),
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
generation_started_callback=engine.create_generation_started_callback(),
)
# Get aggregators
user_context_aggregator = context_aggregator.user()
assistant_context_aggregator = context_aggregator.assistant()
# Register processors with synchronizer for merged audio
audio_synchronizer.register_processors(
audio_buffer.input(), audio_buffer.output()
)
# Get audio streamer for real-time streaming
audio_streamer = get_or_create_audio_streamer(str(test_session_id), role)
# Create pipeline
pipeline = Pipeline(
[
transport.input(),
audio_buffer.input(), # Record input audio
audio_streamer, # Stream audio to connected clients
stt_mute_filter,
stt,
transcript.user(),
user_context_aggregator,
llm,
pipeline_engine_callback_processor,
tts,
transport.output(),
audio_buffer.output(), # Record output audio
transcript.assistant(),
assistant_context_aggregator,
]
)
# Create pipeline task with unique conversation ID for tracing
conversation_id = f"{test_session_id}-{role}-{agent_id}"
task = create_pipeline_task(pipeline, conversation_id, audio_config)
# Set the task on the engine
engine.task = task
return {
"task": task,
"engine": engine,
"audio_buffer": audio_buffer,
"audio_synchronizer": audio_synchronizer,
"transcript": transcript,
"assistant_context_aggregator": assistant_context_aggregator,
"audio_streamer": audio_streamer,
}

View file

@ -0,0 +1,216 @@
"""Recording management for LoopTalk sessions."""
import wave
from pathlib import Path
from typing import Dict, Optional, Tuple
from loguru import logger
from api.enums import StorageBackend
from api.services.storage import storage_fs
class RecordingManager:
"""Manages audio recording and transcript files for LoopTalk sessions."""
def __init__(self, base_dir: Path):
"""Initialize the recording manager.
Args:
base_dir: Base directory for temporary recordings
"""
self.base_dir = base_dir
self.base_dir.mkdir(parents=True, exist_ok=True)
def get_recording_paths(self, test_session_id: int, role: str) -> Dict[str, Path]:
"""Get file paths for recordings.
Args:
test_session_id: ID of the test session
role: Either "actor" or "adversary"
Returns:
Dictionary with paths for audio, transcript, and temp audio files
"""
session_dir = self.base_dir / f"session_{test_session_id}"
session_dir.mkdir(parents=True, exist_ok=True)
return {
"audio": session_dir / f"{role}_audio.wav",
"transcript": session_dir / f"{role}_transcript.txt",
"temp_audio": session_dir / f"{role}_audio_temp.pcm",
}
def convert_pcm_to_wav(
self,
test_session_id: int,
role: str,
sample_rate: int = 16000,
num_channels: int = 1,
) -> Optional[Path]:
"""Convert PCM audio file to WAV format.
Args:
test_session_id: ID of the test session
role: Either "actor" or "adversary"
sample_rate: Sample rate of the audio
num_channels: Number of audio channels
Returns:
Path to the WAV file if successful, None otherwise
"""
paths = self.get_recording_paths(test_session_id, role)
# Check if PCM file exists
if not paths["temp_audio"].exists():
logger.warning(f"No audio recorded for {role} in session {test_session_id}")
return None
try:
# Read PCM data
with open(paths["temp_audio"], "rb") as f:
pcm_data = f.read()
# Write WAV file
with wave.open(str(paths["audio"]), "wb") as wav_file:
wav_file.setnchannels(num_channels)
wav_file.setsampwidth(2) # 16-bit audio
wav_file.setframerate(sample_rate)
wav_file.writeframes(pcm_data)
# Remove temporary PCM file
paths["temp_audio"].unlink()
logger.info(
f"Converted audio to WAV for {role} in session {test_session_id}: {paths['audio']}"
)
return paths["audio"]
except Exception as e:
logger.error(
f"Failed to convert audio to WAV for {role} in session {test_session_id}: {e}"
)
return None
async def upload_recording_to_s3(
self, test_session_id: int, role: str
) -> Tuple[Optional[str], Optional[str]]:
"""Upload recording and transcript to S3.
Args:
test_session_id: ID of the test session
role: Either "actor" or "adversary"
Returns:
Tuple of (audio_url, transcript_url) or (None, None) if failed
"""
paths = self.get_recording_paths(test_session_id, role)
audio_url = None
transcript_url = None
# Import here to avoid circular imports
current_backend = StorageBackend.get_current_backend()
logger.info(
f"LOOPTALK UPLOAD: Using {current_backend.label} (code: {current_backend.code}) for session {test_session_id}, role: {role}"
)
# Upload audio if exists
if paths["audio"].exists():
audio_key = f"looptalk/recordings/{test_session_id}/{role}_audio.wav"
try:
success = await storage_fs.aupload_file(str(paths["audio"]), audio_key)
if success:
audio_url = audio_key
logger.info(
f"Uploaded {role} audio to {current_backend.label}: {audio_key}"
)
else:
logger.error(
f"Failed to upload {role} audio to {current_backend.label}"
)
except Exception as e:
logger.error(
f"Error uploading {role} audio to {current_backend.label}: {e}"
)
# Upload transcript if exists
if paths["transcript"].exists():
transcript_key = (
f"looptalk/transcripts/{test_session_id}/{role}_transcript.txt"
)
try:
success = await storage_fs.aupload_file(
str(paths["transcript"]), transcript_key
)
if success:
transcript_url = transcript_key
logger.info(
f"Uploaded {role} transcript to {current_backend.label}: {transcript_key}"
)
else:
logger.error(
f"Failed to upload {role} transcript to {current_backend.label}"
)
except Exception as e:
logger.error(
f"Error uploading {role} transcript to {current_backend.label}: {e}"
)
return audio_url, transcript_url
def cleanup_session_files(self, test_session_id: int):
"""Clean up local files for a session.
Args:
test_session_id: ID of the test session
"""
session_dir = self.base_dir / f"session_{test_session_id}"
if session_dir.exists():
try:
import shutil
shutil.rmtree(session_dir)
logger.debug(f"Cleaned up local files for session {test_session_id}")
except Exception as e:
logger.error(f"Failed to clean up session files: {e}")
def get_recording_info(self, test_session_id: int) -> Dict[str, any]:
"""Get information about recordings for a test session.
Args:
test_session_id: ID of the test session
Returns:
Dictionary with recording information
"""
session_dir = self.base_dir / f"session_{test_session_id}"
info = {
"test_session_id": test_session_id,
"recording_dir": str(session_dir),
"files": {},
}
for role in ["actor", "adversary"]:
paths = self.get_recording_paths(test_session_id, role)
role_info = {}
# Check audio file
if paths["audio"].exists():
role_info["audio"] = {
"path": str(paths["audio"]),
"size_bytes": paths["audio"].stat().st_size,
}
# Check transcript file
if paths["transcript"].exists():
role_info["transcript"] = {
"path": str(paths["transcript"]),
"size_bytes": paths["transcript"].stat().st_size,
}
if role_info:
info["files"][role] = role_info
return info

View file

@ -0,0 +1,184 @@
"""Session management for LoopTalk test sessions."""
import asyncio
from datetime import UTC, datetime
from typing import Any, Dict, Optional
from loguru import logger
class SessionManager:
"""Manages running LoopTalk test sessions."""
def __init__(self):
"""Initialize the session manager."""
self._running_sessions: Dict[int, Dict[str, Any]] = {}
self._disconnect_handlers: Dict[int, asyncio.Task] = {}
def add_session(self, test_session_id: int, session_info: Dict[str, Any]):
"""Add a new session to the manager.
Args:
test_session_id: ID of the test session
session_info: Dictionary containing session information
"""
self._running_sessions[test_session_id] = session_info
def get_session(self, test_session_id: int) -> Optional[Dict[str, Any]]:
"""Get session information.
Args:
test_session_id: ID of the test session
Returns:
Session information dictionary or None if not found
"""
return self._running_sessions.get(test_session_id)
def remove_session(self, test_session_id: int):
"""Remove a session from the manager.
Args:
test_session_id: ID of the test session
"""
if test_session_id in self._running_sessions:
del self._running_sessions[test_session_id]
# Cancel any disconnect handler for this session
if test_session_id in self._disconnect_handlers:
handler = self._disconnect_handlers.pop(test_session_id)
if not handler.done():
handler.cancel()
def get_active_count(self) -> int:
"""Get the number of currently active sessions."""
return len(self._running_sessions)
def get_active_info(self) -> Dict[str, Any]:
"""Get information about all active sessions."""
return {
"count": len(self._running_sessions),
"sessions": [
{
"test_session_id": session_id,
"conversation_id": info["conversation"].id,
"start_time": info["start_time"],
"duration_seconds": int(
(datetime.now(UTC) - info["start_time"]).total_seconds()
),
}
for session_id, info in self._running_sessions.items()
],
}
async def handle_agent_disconnect(
self, test_session_id: int, disconnected_role: str, stop_callback: callable
):
"""Handle when one agent disconnects.
This will cancel the other agent as well to ensure clean shutdown.
Args:
test_session_id: ID of the test session
disconnected_role: Role that disconnected ("actor" or "adversary")
stop_callback: Callback to stop the session
"""
logger.info(
f"Handling {disconnected_role} disconnect for session {test_session_id}"
)
# Check if we already have a disconnect handler running
if test_session_id in self._disconnect_handlers:
logger.debug(
f"Disconnect handler already running for session {test_session_id}"
)
return
# Create a task to handle the disconnect
async def _handle_disconnect():
try:
# Wait a short time to avoid race conditions
await asyncio.sleep(0.5)
# Check if session still exists
session_info = self.get_session(test_session_id)
if not session_info:
logger.debug(f"Session {test_session_id} already stopped")
return
# Stop the session (which will cancel both agents)
logger.info(
f"Stopping session {test_session_id} due to {disconnected_role} disconnect"
)
await stop_callback(test_session_id)
except asyncio.CancelledError:
logger.debug(
f"Disconnect handler cancelled for session {test_session_id}"
)
raise
except Exception as e:
logger.error(
f"Error handling disconnect for session {test_session_id}: {e}"
)
# Store the task so we can cancel it if needed
self._disconnect_handlers[test_session_id] = asyncio.create_task(
_handle_disconnect()
)
def update_audio_metadata(
self,
test_session_id: int,
role: str,
sample_rate: Optional[int] = None,
num_channels: Optional[int] = None,
):
"""Update audio metadata for a role in a session.
Args:
test_session_id: ID of the test session
role: Either "actor" or "adversary"
sample_rate: Sample rate of the audio
num_channels: Number of audio channels
"""
if test_session_id not in self._running_sessions:
return
if "audio_metadata" not in self._running_sessions[test_session_id]:
self._running_sessions[test_session_id]["audio_metadata"] = {}
if role not in self._running_sessions[test_session_id]["audio_metadata"]:
self._running_sessions[test_session_id]["audio_metadata"][role] = {}
metadata = self._running_sessions[test_session_id]["audio_metadata"][role]
if sample_rate is not None:
metadata["sample_rate"] = sample_rate
if num_channels is not None:
metadata["num_channels"] = num_channels
def get_audio_metadata(self, test_session_id: int, role: str) -> Dict[str, Any]:
"""Get audio metadata for a role in a session.
Args:
test_session_id: ID of the test session
role: Either "actor" or "adversary"
Returns:
Dictionary with sample_rate and num_channels
"""
default = {"sample_rate": 16000, "num_channels": 1}
if test_session_id not in self._running_sessions:
return default
metadata = (
self._running_sessions.get(test_session_id, {})
.get("audio_metadata", {})
.get(role, {})
)
return {
"sample_rate": metadata.get("sample_rate", 16000),
"num_channels": metadata.get("num_channels", 1),
}

View file

@ -0,0 +1,553 @@
import asyncio
import os
import uuid
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, Optional
from loguru import logger
from pipecat.pipeline.task import PipelineTask
from pipecat.transports import (
InternalTransport,
InternalTransportManager,
)
from pipecat.utils.context import set_current_run_id
from api.db.db_client import DBClient
from api.services.pipecat.transport_setup import create_internal_transport
from .core.pipeline_builder import LoopTalkPipelineBuilder
from .core.recording_manager import RecordingManager
from .core.session_manager import SessionManager
class LoopTalkTestOrchestrator:
"""Orchestrates LoopTalk testing sessions with agent-to-agent conversations."""
def __init__(
self, db_client: DBClient, network_latency_seconds: Optional[float] = None
):
self.db_client = db_client
self.transport_manager = InternalTransportManager()
self.session_manager = SessionManager()
self.pipeline_builder = LoopTalkPipelineBuilder(db_client)
self.recording_manager = RecordingManager(Path("/tmp/looptalk_recordings"))
# Default network latency (can be overridden per session)
# Priority: constructor param > env var > default (100ms)
if network_latency_seconds is not None:
self._default_network_latency = network_latency_seconds
else:
env_latency = os.environ.get("LOOPTALK_NETWORK_LATENCY_MS")
if env_latency:
try:
self._default_network_latency = (
float(env_latency) / 1000.0
) # Convert ms to seconds
except ValueError:
logger.warning(
f"Invalid LOOPTALK_NETWORK_LATENCY_MS value: {env_latency}, using default 100ms"
)
self._default_network_latency = 0.1
else:
self._default_network_latency = 0.1 # 100ms default
async def start_test_session(
self,
test_session_id: int,
organization_id: int,
network_latency_seconds: Optional[float] = None,
) -> Dict[str, Any]:
"""Start a LoopTalk test session."""
# Get test session details
test_session = await self.db_client.get_test_session(
test_session_id=test_session_id, organization_id=organization_id
)
if not test_session:
raise ValueError(f"Test session {test_session_id} not found")
if test_session.status != "pending":
raise ValueError(f"Test session {test_session_id} is not in pending state")
try:
# Update status to running
await self.db_client.update_test_session_status(
test_session_id=test_session_id, status="running"
)
# Create conversation record
conversation = await self.db_client.create_conversation(
test_session_id=test_session_id
)
# Create audio configuration for LoopTalk
from api.services.pipecat.audio_config import AudioConfig
audio_config = AudioConfig(
transport_in_sample_rate=16000,
transport_out_sample_rate=16000,
pipeline_sample_rate=16000,
)
# Use provided latency or fall back to default
latency = (
network_latency_seconds
if network_latency_seconds is not None
else self._default_network_latency
)
logger.info(
f"Using network latency of {latency}s for test session {test_session_id}"
)
# Generate unique workflow run IDs for each agent
actor_workflow_run_id = int(str(test_session_id) + "1")
adversary_workflow_run_id = int(str(test_session_id) + "2")
# Create transports using the new method with turn analyzer
actor_transport = create_internal_transport(
workflow_run_id=actor_workflow_run_id,
audio_config=audio_config,
latency_seconds=latency,
)
adversary_transport = create_internal_transport(
workflow_run_id=adversary_workflow_run_id,
audio_config=audio_config,
latency_seconds=latency,
)
# Connect the transports
actor_transport.connect_partner(adversary_transport)
# Store the transport pair in the manager
self.transport_manager._transport_pairs[str(test_session_id)] = (
actor_transport,
adversary_transport,
)
# Generate unique identifiers for actor and adversary
actor_id = f"actor_{test_session_id}_{str(uuid.uuid4())[:8]}"
adversary_id = f"adversary_{test_session_id}_{str(uuid.uuid4())[:8]}"
# Create pipelines for both agents
actor_pipeline_info = await self.pipeline_builder.create_agent_pipeline(
transport=actor_transport,
workflow=test_session.actor_workflow,
test_session_id=test_session_id,
agent_id=actor_id,
role="actor",
)
actor_pipeline_task = actor_pipeline_info["task"]
adversary_pipeline_info = await self.pipeline_builder.create_agent_pipeline(
transport=adversary_transport,
workflow=test_session.adversary_workflow,
test_session_id=test_session_id,
agent_id=adversary_id,
role="adversary",
)
adversary_pipeline_task = adversary_pipeline_info["task"]
# Register event handlers for both pipelines
await self._register_transport_handlers(
actor_transport, actor_pipeline_info, test_session_id, "actor"
)
await self._register_transport_handlers(
adversary_transport,
adversary_pipeline_info,
test_session_id,
"adversary",
)
# Store session info
session_info = {
"test_session": test_session,
"conversation": conversation,
"actor_task": actor_pipeline_task,
"adversary_task": adversary_pipeline_task,
"actor_transport": actor_transport,
"adversary_transport": adversary_transport,
"start_time": datetime.now(UTC),
}
self.session_manager.add_session(test_session_id, session_info)
# Start both pipelines in background tasks
from pipecat.pipeline.base_task import PipelineTaskParams
params = PipelineTaskParams(loop=asyncio.get_event_loop())
# Start the pipelines - this will trigger initialization through the normal pipeline start process
# The workflow engines will be initialized when the pipeline starts
# Create conversation IDs for tracing
actor_conversation_id = f"{test_session_id}-actor-{actor_id}"
adversary_conversation_id = f"{test_session_id}-adversary-{adversary_id}"
# Create tasks but don't await them - they'll run in the background
logger.debug(f"Running actor task with ID: {actor_id}")
actor_task_future = asyncio.create_task(
self._run_pipeline_with_context(
actor_pipeline_task,
params,
actor_id,
actor_conversation_id,
"actor",
)
)
logger.debug(f"Running adversary task with ID: {adversary_id}")
adversary_task_future = asyncio.create_task(
self._run_pipeline_with_context(
adversary_pipeline_task,
params,
adversary_id,
adversary_conversation_id,
"adversary",
)
)
# Store the futures so we can monitor them
session_info["actor_task_future"] = actor_task_future
session_info["adversary_task_future"] = adversary_task_future
logger.info(f"Started LoopTalk test session {test_session_id}")
return {
"test_session_id": test_session_id,
"conversation_id": conversation.id,
"status": "running",
}
except Exception as e:
logger.error(f"Failed to start test session {test_session_id}: {e}")
await self.db_client.update_test_session_status(
test_session_id=test_session_id, status="failed", error=str(e)
)
raise
async def _register_transport_handlers(
self,
transport: InternalTransport,
pipeline_info: Dict[str, Any],
test_session_id: int,
role: str,
):
"""Register transport event handlers for a pipeline.
Args:
transport: The transport to register handlers on
pipeline_info: Dictionary containing pipeline components
test_session_id: ID of the test session
role: Either "actor" or "adversary"
"""
engine = pipeline_info["engine"]
task = pipeline_info["task"]
audio_buffer = pipeline_info["audio_buffer"]
audio_synchronizer = pipeline_info["audio_synchronizer"]
transcript = pipeline_info["transcript"]
assistant_context_aggregator = pipeline_info["assistant_context_aggregator"]
# Register transport event handlers
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, participant):
logger.debug(f"LoopTalk {role} client connected - initializing workflow")
# Start audio recording
await audio_buffer.start_recording()
await audio_synchronizer.start_recording()
await engine.initialize()
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, participant):
logger.debug(f"LoopTalk {role} client disconnected")
# Stop audio recording
await audio_buffer.stop_recording()
await audio_synchronizer.stop_recording()
# Handle disconnect propagation - stop the other agent too
await self.session_manager.handle_agent_disconnect(
test_session_id, role, self.stop_test_session
)
await task.cancel()
# Connect the context aggregator events to engine
@assistant_context_aggregator.event_handler("on_push_aggregation")
async def on_assistant_aggregator_push_context(_aggregator):
logger.debug(
"Assistant aggregator push context flushing pending transitions"
)
await engine.flush_pending_transitions()
# Register custom audio and transcript handlers for LoopTalk
await self._register_looptalk_handlers(
audio_synchronizer, transcript, test_session_id, role
)
async def _register_looptalk_handlers(
self, audio_synchronizer, transcript, test_session_id: int, role: str
):
"""Register LoopTalk-specific handlers for audio and transcript recording"""
paths = self.recording_manager.get_recording_paths(test_session_id, role)
# Store audio metadata for later WAV conversion
audio_metadata = {"sample_rate": None, "num_channels": None}
# Audio handler - writes directly to PCM file
@audio_synchronizer.event_handler("on_merged_audio")
async def on_merged_audio(_, pcm, sample_rate, num_channels):
if not pcm:
return
# Store metadata on first write
if audio_metadata["sample_rate"] is None:
audio_metadata["sample_rate"] = sample_rate
audio_metadata["num_channels"] = num_channels
# Append PCM data to temporary file
try:
with open(paths["temp_audio"], "ab") as f:
f.write(pcm)
except Exception as e:
logger.error(
f"Failed to write audio for {role} in session {test_session_id}: {e}"
)
# Transcript handler - writes directly to text file
@transcript.event_handler("on_transcript_update")
async def on_transcript_update(processor, frame):
transcript_text = ""
for msg in frame.messages:
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
line = f"{timestamp}{msg.role}: {msg.content}\n"
transcript_text += line
# Append transcript to file
try:
with open(paths["transcript"], "a") as f:
f.write(transcript_text)
except Exception as e:
logger.error(
f"Failed to write transcript for {role} in session {test_session_id}: {e}"
)
# Store metadata in session info for later WAV conversion
# Set default values if not yet captured
if audio_metadata["sample_rate"] is None:
audio_metadata["sample_rate"] = 16000 # Default sample rate
audio_metadata["num_channels"] = 1 # Default channels
self.session_manager.update_audio_metadata(
test_session_id,
role,
sample_rate=audio_metadata["sample_rate"],
num_channels=audio_metadata["num_channels"],
)
async def _run_pipeline_with_context(
self,
pipeline_task: PipelineTask,
params,
agent_id: str,
conversation_id: str,
role: str,
):
"""Run a pipeline task with the agent_id set in context"""
set_current_run_id(agent_id)
return await pipeline_task.run(params)
async def stop_test_session(self, test_session_id: int) -> Dict[str, Any]:
"""Stop a running test session."""
session_info = self.session_manager.get_session(test_session_id)
if not session_info:
raise ValueError(f"Test session {test_session_id} is not running")
try:
# Cancel both pipeline tasks
await session_info["actor_task"].cancel()
await session_info["adversary_task"].cancel()
# Also cancel the task futures if they exist
if "actor_task_future" in session_info:
session_info["actor_task_future"].cancel()
if "adversary_task_future" in session_info:
session_info["adversary_task_future"].cancel()
# Calculate duration
duration_seconds = int(
(datetime.now(UTC) - session_info["start_time"]).total_seconds()
)
# Update conversation
await self.db_client.update_conversation(
conversation_id=session_info["conversation"].id,
duration_seconds=duration_seconds,
ended_at=datetime.now(UTC),
)
# Update test session status
await self.db_client.update_test_session_status(
test_session_id=test_session_id,
status="completed",
results={
"duration_seconds": duration_seconds,
"conversation_id": session_info["conversation"].id,
},
)
# Finalize recordings for both actor and adversary
# Convert PCM files to WAV
actor_metadata = self.session_manager.get_audio_metadata(
test_session_id, "actor"
)
adversary_metadata = self.session_manager.get_audio_metadata(
test_session_id, "adversary"
)
self.recording_manager.convert_pcm_to_wav(
test_session_id,
"actor",
sample_rate=actor_metadata["sample_rate"],
num_channels=actor_metadata["num_channels"],
)
self.recording_manager.convert_pcm_to_wav(
test_session_id,
"adversary",
sample_rate=adversary_metadata["sample_rate"],
num_channels=adversary_metadata["num_channels"],
)
# Upload recordings to S3 (synchronously for load testing)
(
actor_audio_url,
actor_transcript_url,
) = await self.recording_manager.upload_recording_to_s3(
test_session_id, "actor"
)
(
adversary_audio_url,
adversary_transcript_url,
) = await self.recording_manager.upload_recording_to_s3(
test_session_id, "adversary"
)
# Update conversation with recording URLs
await self.db_client.update_conversation(
conversation_id=session_info["conversation"].id,
actor_recording_url=actor_audio_url,
adversary_recording_url=adversary_audio_url,
transcript={
"actor_transcript_url": actor_transcript_url,
"adversary_transcript_url": adversary_transcript_url,
},
)
# Log recording locations
logger.info(f"LoopTalk recordings uploaded to S3:")
if actor_audio_url:
logger.info(f" - Actor audio: {actor_audio_url}")
if actor_transcript_url:
logger.info(f" - Actor transcript: {actor_transcript_url}")
if adversary_audio_url:
logger.info(f" - Adversary audio: {adversary_audio_url}")
if adversary_transcript_url:
logger.info(f" - Adversary transcript: {adversary_transcript_url}")
# Clean up local files after successful upload
self.recording_manager.cleanup_session_files(test_session_id)
# Clean up
self.transport_manager.remove_transport_pair(str(test_session_id))
self.session_manager.remove_session(test_session_id)
# Clean up audio streamers
from api.services.looptalk.audio_streamer import cleanup_audio_streamers
cleanup_audio_streamers(str(test_session_id))
logger.info(f"Stopped LoopTalk test session {test_session_id}")
return {
"test_session_id": test_session_id,
"status": "completed",
"duration_seconds": duration_seconds,
}
except Exception as e:
logger.error(f"Failed to stop test session {test_session_id}: {e}")
await self.db_client.update_test_session_status(
test_session_id=test_session_id, status="failed", error=str(e)
)
raise
async def start_load_test(
self,
organization_id: int,
name_prefix: str,
actor_workflow_id: int,
adversary_workflow_id: int,
config: Dict[str, Any],
test_count: int,
) -> Dict[str, Any]:
"""Start a load test with multiple concurrent test sessions."""
# Validate test count
if test_count < 1 or test_count > 10:
raise ValueError("Test count must be between 1 and 10")
# Create test sessions
test_sessions = await self.db_client.create_load_test_group(
organization_id=organization_id,
name_prefix=name_prefix,
actor_workflow_id=actor_workflow_id,
adversary_workflow_id=adversary_workflow_id,
config=config,
test_count=test_count,
)
# Start all test sessions concurrently
tasks = []
for test_session in test_sessions:
task = asyncio.create_task(
self.start_test_session(
test_session_id=test_session.id, organization_id=organization_id
)
)
tasks.append(task)
# Wait for all to start
results = await asyncio.gather(*tasks, return_exceptions=True)
# Count successes and failures
started = sum(1 for r in results if not isinstance(r, Exception))
failed = sum(1 for r in results if isinstance(r, Exception))
load_test_group_id = test_sessions[0].load_test_group_id
logger.info(
f"Started load test {load_test_group_id}: "
f"{started} started, {failed} failed out of {test_count}"
)
return {
"load_test_group_id": load_test_group_id,
"total": test_count,
"started": started,
"failed": failed,
"test_session_ids": [ts.id for ts in test_sessions],
}
def get_active_test_count(self) -> int:
"""Get the number of currently active test sessions."""
return self.session_manager.get_active_count()
def get_active_test_info(self) -> Dict[str, Any]:
"""Get information about all active test sessions."""
return self.session_manager.get_active_info()
def get_recording_info(self, test_session_id: int) -> Dict[str, Any]:
"""Get information about recordings for a test session"""
return self.recording_manager.get_recording_info(test_session_id)

View file

@ -0,0 +1,283 @@
"""
MPS Service Key HTTP Client
This client communicates with the Model Proxy Service (MPS) for service key management.
Service keys are stored and managed entirely in MPS, not in the local database.
"""
from typing import List, Optional
import httpx
from loguru import logger
from api.constants import DEPLOYMENT_MODE, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
class MPSServiceKeyClient:
"""HTTP client for managing service keys via MPS API."""
def __init__(self):
self.base_url = MPS_API_URL
self.timeout = httpx.Timeout(10.0)
def _get_headers(self) -> dict:
"""Get headers for MPS API requests."""
headers = {"Content-Type": "application/json"}
# Add authentication for non-OSS mode
if DEPLOYMENT_MODE != "oss" and DOGRAH_MPS_SECRET_KEY:
headers["X-Secret-Key"] = DOGRAH_MPS_SECRET_KEY
return headers
async def create_service_key(
self,
name: str,
organization_id: Optional[int] = None,
created_by: str = None,
expires_in_days: int = 90,
description: Optional[str] = None,
) -> dict:
"""
Create a new service key via MPS API.
For OSS mode: organization_id should be None
For authenticated mode: organization_id should be provided
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
request_body = {
"name": name,
"description": description or f"Service key: {name}",
"expires_in_days": expires_in_days,
"created_by": created_by,
}
# Only add organization_id for non-OSS mode
if DEPLOYMENT_MODE != "oss" and organization_id:
request_body["organization_id"] = organization_id
response = await client.post(
f"{self.base_url}/api/v1/service-keys/",
json=request_body,
headers=self._get_headers(),
)
if response.status_code == 200:
data = response.json()
# Transform the response to match our expected format
return {
"id": data.get("id"),
"name": data.get("name"),
"service_key": data.get("service_key"), # Only returned on creation
"key_prefix": data.get("service_key", "")[:8]
if data.get("service_key")
else "",
"expires_at": data.get("expires_at"),
"created_at": data.get("created_at"),
"is_active": data.get("is_active", True),
"created_by": data.get("created_by"),
}
else:
raise httpx.HTTPStatusError(
f"Failed to create service key: {response.text}",
request=response.request,
response=response,
)
async def get_service_keys(
self,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
include_archived: bool = False,
) -> List[dict]:
"""
Get service keys from MPS.
For OSS mode: Use created_by to filter keys
For authenticated mode: Use organization_id to filter keys
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
params = {}
if DEPLOYMENT_MODE == "oss":
# In OSS mode, filter by created_by
if created_by:
params["created_by"] = created_by
else:
# In authenticated mode, filter by organization_id
if organization_id:
params["organization_id"] = organization_id
if include_archived:
params["include_archived"] = "true"
response = await client.get(
f"{self.base_url}/api/v1/service-keys/",
params=params,
headers=self._get_headers(),
)
if response.status_code == 200:
keys = response.json()
# Transform the response to match our expected format
return [
{
"id": key.get("id"),
"name": key.get("name"),
"key_prefix": key.get("key_prefix", ""),
"is_active": key.get("is_active", True),
"created_at": key.get("created_at"),
"last_used_at": key.get("last_used_at"),
"expires_at": key.get("expires_at"),
"archived_at": key.get("archived_at"),
"created_by": key.get("created_by"),
}
for key in keys
]
else:
logger.error(
f"Failed to get service keys: {response.status_code} - {response.text}"
)
return []
async def get_service_key_by_id(
self,
key_id: int,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Get a specific service key by ID."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/service-keys/{key_id}",
headers=self._get_headers(),
)
if response.status_code == 200:
key = response.json()
# Validate ownership for OSS mode
if DEPLOYMENT_MODE == "oss" and created_by:
if key.get("created_by") != created_by:
logger.warning(
f"Access denied: User {created_by} tried to access key created by {key.get('created_by')}"
)
return None
# Validate organization for authenticated mode
if DEPLOYMENT_MODE != "oss" and organization_id:
if key.get("organization_id") != organization_id:
logger.warning(
f"Access denied: Org {organization_id} tried to access key for org {key.get('organization_id')}"
)
return None
return {
"id": key.get("id"),
"name": key.get("name"),
"key_prefix": key.get("key_prefix", ""),
"is_active": key.get("is_active", True),
"created_at": key.get("created_at"),
"last_used_at": key.get("last_used_at"),
"expires_at": key.get("expires_at"),
"archived_at": key.get("archived_at"),
"created_by": key.get("created_by"),
}
else:
return None
async def archive_service_key(
self,
key_id: int,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> bool:
"""
Archive (soft delete) a service key.
For OSS mode: Validates that created_by matches the key creator
For authenticated mode: Validates organization_id matches
"""
# First, verify ownership
key = await self.get_service_key_by_id(key_id, organization_id, created_by)
if not key:
logger.error(f"Service key {key_id} not found or access denied")
return False
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.delete(
f"{self.base_url}/api/v1/service-keys/{key_id}",
headers=self._get_headers(),
)
if response.status_code in [200, 204]:
return True
else:
logger.error(
f"Failed to archive service key: {response.status_code} - {response.text}"
)
return False
async def call_workflow_api(
self,
call_type: str,
use_case: str,
activity_description: str,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> dict:
"""
Call the MPS workflow creation API using secret key authentication.
For OSS mode: Pass created_by in headers
For authenticated mode: Pass organization_id in headers
Args:
call_type: INBOUND or OUTBOUND
use_case: Description of the use case
activity_description: Description of what the agent should do
organization_id: Organization ID (for authenticated mode)
created_by: User provider ID (for OSS mode)
Returns:
Workflow data from MPS API
Raises:
HTTPException: If the API call fails
"""
headers = {"Content-Type": "application/json"}
# Add secret key authentication
if DEPLOYMENT_MODE != "oss" and DOGRAH_MPS_SECRET_KEY:
headers["X-Secret-Key"] = DOGRAH_MPS_SECRET_KEY
if organization_id:
headers["X-Organization-Id"] = str(organization_id)
elif DEPLOYMENT_MODE == "oss":
if created_by:
headers["X-Created-By"] = created_by
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
response = await client.post(
f"{self.base_url}/api/v1/workflow/create-workflow",
json={
"call_type": call_type,
"use_case": use_case,
"activity_description": activity_description,
},
headers=headers,
)
if response.status_code == 200:
return response.json()
else:
logger.error(
f"Failed to create workflow: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create workflow: {response.text}",
request=response.request,
response=response,
)
# Create a singleton instance
mps_service_key_client = MPSServiceKeyClient()

View file

View file

@ -0,0 +1,120 @@
"""
Audio configuration for pipeline components.
This module provides centralized audio configuration to ensure consistent
sample rates across all pipeline components and proper coordination between
transport serializers, VAD, and audio buffers.
"""
from dataclasses import dataclass
from typing import Optional
from loguru import logger
from api.enums import WorkflowRunMode
@dataclass
class AudioConfig:
"""Centralized audio configuration for the pipeline.
Note: Pipeline is limited to 16kHz maximum to support VAD.
Transports handle resampling from/to higher rates (24kHz, 48kHz).
Attributes:
transport_in_sample_rate: Sample rate of incoming audio from transport (after resampling)
transport_out_sample_rate: Sample rate of outgoing audio to transport (before resampling)
vad_sample_rate: Sample rate for VAD processing (8000 or 16000)
pipeline_sample_rate: Internal pipeline processing sample rate (max 16000)
buffer_size_seconds: Audio buffer size in seconds
"""
transport_in_sample_rate: int
transport_out_sample_rate: int
vad_sample_rate: int = 16000 # VAD typically resamples internally
pipeline_sample_rate: Optional[int] = None # If None, uses transport rates
buffer_size_seconds: float = 1.0 # This is how frequenly we will call merge_auido
def __post_init__(self):
# Validate VAD sample rate
if self.vad_sample_rate not in [8000, 16000]:
raise ValueError(
f"VAD sample rate must be 8000 or 16000, got {self.vad_sample_rate}"
)
# Set pipeline sample rate to transport out rate if not specified
if self.pipeline_sample_rate is None:
self.pipeline_sample_rate = min(self.transport_out_sample_rate, 16000)
# Ensure pipeline sample rate doesn't exceed 16kHz (VAD limitation)
if self.pipeline_sample_rate > 16000:
logger.warning(
f"Pipeline sample rate {self.pipeline_sample_rate} exceeds 16kHz limit, "
f"capping at 16kHz. Transport will handle resampling."
)
self.pipeline_sample_rate = 16000
# Log configuration for auditing
logger.info(
f"AudioConfig initialized: "
f"transport_in={self.transport_in_sample_rate}Hz, "
f"transport_out={self.transport_out_sample_rate}Hz, "
f"vad={self.vad_sample_rate}Hz, "
f"pipeline={self.pipeline_sample_rate}Hz, "
f"buffer={self.buffer_size_seconds}s"
)
@property
def buffer_size_bytes(self) -> int:
"""Calculate buffer size in bytes based on pipeline sample rate."""
# 2 bytes per sample (16-bit PCM)
return int(self.pipeline_sample_rate * 2 * self.buffer_size_seconds)
@property
def buffer_size_samples(self) -> int:
"""Calculate buffer size in samples based on pipeline sample rate."""
return int(self.pipeline_sample_rate * self.buffer_size_seconds)
def create_audio_config(transport_type: str) -> AudioConfig:
"""Create audio configuration based on transport type.
Args:
transport_type: Type of transport ("webrtc", "twilio", "stasis")
Returns:
AudioConfig instance with appropriate settings
"""
if transport_type in (WorkflowRunMode.STASIS.value, WorkflowRunMode.TWILIO.value):
return AudioConfig(
transport_in_sample_rate=8000,
transport_out_sample_rate=8000,
vad_sample_rate=8000, # Use matching VAD rate
pipeline_sample_rate=8000, # Keep at 8kHz to avoid resampling
buffer_size_seconds=1.0,
)
elif transport_type in [
WorkflowRunMode.WEBRTC.value,
WorkflowRunMode.SMALLWEBRTC.value,
]:
# WebRTC typically uses 24kHz or 48kHz, but we limit pipeline to 16kHz
# The transport will handle resampling between 24kHz and 16kHz
return AudioConfig(
transport_in_sample_rate=16000, # Transport will resample from 24kHz
transport_out_sample_rate=16000, # Transport will resample to 24kHz
vad_sample_rate=16000, # VAD native rate
pipeline_sample_rate=16000, # Keep pipeline at 16kHz
buffer_size_seconds=1.0,
)
else:
# Default configuration
logger.warning(
f"Unknown transport type: {transport_type}, using default config"
)
return AudioConfig(
transport_in_sample_rate=16000,
transport_out_sample_rate=16000,
vad_sample_rate=16000,
pipeline_sample_rate=16000,
buffer_size_seconds=1.0,
)

View file

@ -0,0 +1,122 @@
import asyncio
import re
import tempfile
import wave
from typing import List
from loguru import logger
class InMemoryAudioBuffer:
"""Buffer audio data in memory during a call, then write to temp file on disconnect."""
def __init__(self, workflow_run_id: int, sample_rate: int, num_channels: int = 1):
self._workflow_run_id = workflow_run_id
self._sample_rate = sample_rate
self._num_channels = num_channels
self._chunks: List[bytes] = []
self._lock = asyncio.Lock()
self._total_size = 0
self._max_size = 100 * 1024 * 1024 # 100MB limit
async def append(self, pcm_data: bytes):
"""Append PCM audio data to the buffer."""
async with self._lock:
if self._total_size + len(pcm_data) > self._max_size:
logger.error(
f"Audio buffer size limit exceeded for workflow {self._workflow_run_id}. "
f"Current: {self._total_size}, Attempted to add: {len(pcm_data)}"
)
raise MemoryError("Audio buffer size limit exceeded")
self._chunks.append(pcm_data)
self._total_size += len(pcm_data)
logger.trace(
f"Appended {len(pcm_data)} bytes to audio buffer. Total size: {self._total_size}"
)
async def write_to_temp_file(self) -> str:
"""Write audio data to a temporary WAV file and return the path."""
async with self._lock:
temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
logger.debug(
f"Writing audio buffer to temp file {temp_file.name} for workflow {self._workflow_run_id}"
)
# Write WAV header and PCM data
with wave.open(temp_file.name, "wb") as wf:
wf.setnchannels(self._num_channels)
wf.setsampwidth(2) # 16-bit audio
wf.setframerate(self._sample_rate)
# Concatenate all chunks
for chunk in self._chunks:
wf.writeframes(chunk)
logger.info(
f"Successfully wrote {self._total_size} bytes of audio to {temp_file.name}"
)
return temp_file.name
@property
def is_empty(self) -> bool:
"""Check if the buffer is empty."""
return len(self._chunks) == 0
@property
def size(self) -> int:
"""Get the total size of buffered data."""
return self._total_size
class InMemoryTranscriptBuffer:
"""Buffer transcript data in memory during a call, then write to temp file on disconnect."""
# Compiled regex to identify user speech lines, e.g.
# [2025-06-29T12:34:56.789+00:00] user: hello
_USER_SPEECH_RE: re.Pattern[str] = re.compile(
r"^\[\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}\+\d{2}:\d{2}\] user: .+"
)
def __init__(self, workflow_run_id: int):
self._workflow_run_id = workflow_run_id
self._lines: List[str] = []
self._lock = asyncio.Lock()
async def append(self, transcript: str):
"""Append transcript text to the buffer."""
async with self._lock:
self._lines.append(transcript)
logger.trace(
f"Appended transcript line to buffer for workflow {self._workflow_run_id}"
)
async def write_to_temp_file(self) -> str:
"""Write transcript to a temporary text file and return the path."""
async with self._lock:
temp_file = tempfile.NamedTemporaryFile(
mode="w", suffix=".txt", delete=False
)
logger.debug(
f"Writing transcript buffer to temp file {temp_file.name} for workflow {self._workflow_run_id}"
)
content = "".join(self._lines)
temp_file.write(content)
temp_file.close()
logger.info(
f"Successfully wrote {len(content)} chars of transcript to {temp_file.name}"
)
return temp_file.name
@property
def is_empty(self) -> bool:
"""Check if the buffer is empty."""
return len(self._lines) == 0
def contains_user_speech(self) -> bool:
"""Return True if any buffered transcript line matches the user speech pattern."""
for line in self._lines:
if self._USER_SPEECH_RE.match(line):
return True
return False

View file

@ -0,0 +1,69 @@
"""Engine Pre-Aggregator Processor
This processor sits before the user context aggregator in the pipeline and handles
engine-specific callbacks for frames that need to be processed before aggregation.
This ensures the engine can update context before the aggregator generates LLM frames.
"""
from typing import Awaitable, Callable, Optional
from loguru import logger
from api.services.pipecat.exceptions import VoicemailDetectedException
from pipecat.frames.frames import (
Frame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class EnginePreAggregatorProcessor(FrameProcessor):
"""
Processor that handles engine callbacks before user context aggregation.
This processor is positioned before the user context aggregator to ensure
the engine can update LLM context before aggregation occurs.
"""
def __init__(
self,
user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self._user_started_speaking_callback = user_started_speaking_callback
self._user_stopped_speaking_callback = user_stopped_speaking_callback
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# Handle frames that need engine processing before aggregation
if isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking()
elif isinstance(frame, UserStoppedSpeakingFrame):
try:
await self._handle_user_stopped_speaking()
except VoicemailDetectedException:
# We have detected voicemail, lets not
# forward the UserStoppedSpeakingFrame, so that
# we don't issue an llm call from user context
# aggregator
logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame")
return
# Always push the frame downstream
await self.push_frame(frame, direction)
async def _handle_user_started_speaking(self):
"""Handle UserStartedSpeakingFrame before aggregation."""
if self._user_started_speaking_callback:
# logger.debug("Engine pre-aggregator: User started speaking")
await self._user_started_speaking_callback()
async def _handle_user_stopped_speaking(self):
"""Handle UserStoppedSpeakingFrame before aggregation."""
if self._user_stopped_speaking_callback:
# logger.debug("Engine pre-aggregator: User stopped speaking")
await self._user_stopped_speaking_callback()

View file

@ -0,0 +1,249 @@
from typing import Optional
from loguru import logger
from api.db import db_client
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
from api.services.pipecat.audio_transcript_buffers import (
InMemoryAudioBuffer,
InMemoryTranscriptBuffer,
)
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
from pipecat.pipeline.task import PipelineTask
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.enums import EndTaskReason
def register_transport_event_handlers(
transport,
workflow_run_id,
audio_buffer,
task: PipelineTask,
engine: PipecatEngine,
usage_metrics_aggregator: PipelineMetricsAggregator,
audio_synchronizer=None,
audio_config=None,
):
"""Register event handlers for transport events"""
# Initialize in-memory buffers with proper audio configuration
sample_rate = audio_config.pipeline_sample_rate if audio_config else 16000
num_channels = 1 # Pipeline audio is always mono
logger.debug(
f"Initializing audio buffer for workflow {workflow_run_id} "
f"with sample_rate={sample_rate}Hz, channels={num_channels}"
)
in_memory_audio_buffer = InMemoryAudioBuffer(
workflow_run_id=workflow_run_id,
sample_rate=sample_rate,
num_channels=num_channels,
)
in_memory_transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, participant):
logger.debug("In on_client_connected callback handler - initializing workflow")
await audio_buffer.start_recording()
if audio_synchronizer:
await audio_synchronizer.start_recording()
await engine.initialize()
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(
transport: BaseTransport,
participant,
transport_disconnect_reason: Optional[str] = None,
):
logger.debug(
f"In on_client_disconnected callback handler, disconnect_reason: {transport_disconnect_reason}"
)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
# First priority: Check if engine has a disconnect reason (local disconnect)
engine_call_disposition = engine.get_call_disposition()
gathered_context = engine.get_gathered_context()
# also consider existing gathered context in workflow_run
gathered_context = {**gathered_context, **workflow_run.gathered_context}
if engine_call_disposition:
# Engine has set a disconnect reason - this takes priority
call_disposition = engine_call_disposition
logger.debug(f"Engine disposition detected, code: {call_disposition}")
elif transport_disconnect_reason:
# TODO: Make this more generic using some DSL or equivalent. This is currently
# configured to work for Kapil's bot
call_duration = usage_metrics_aggregator.get_call_duration()
if transport_disconnect_reason == EndTaskReason.USER_HANGUP.value:
if call_duration < 10:
call_disposition = "HU"
else:
call_disposition = "NIBP"
else:
# Transport provided a disconnect reason (remote hangup)
call_disposition = transport_disconnect_reason
logger.debug(
f"Remote disconnect detected, reason: {call_disposition} duration: {call_duration}"
)
else:
# No reason provided - assume user hangup
call_disposition = EndTaskReason.UNKNOWN.value
logger.debug("No disposition found from either engine or transport")
# Cancel task only when no engine disconnect reason (remote disconnect)
if not engine_call_disposition:
await task.cancel()
organization_id = await get_organization_id_from_workflow_run(workflow_run_id)
mapped_call_disposition = await apply_disposition_mapping(
call_disposition, organization_id
)
gathered_context.update({"mapped_call_disposition": mapped_call_disposition})
if in_memory_transcript_buffer:
call_tags = gathered_context.get("call_tags", [])
try:
has_user_speech = in_memory_transcript_buffer.contains_user_speech()
except Exception:
has_user_speech = False
if has_user_speech and "user_speech" not in call_tags:
call_tags.append("user_speech")
# Append any keys from gathered_context that start with 'tag_' to call_tags
for key in gathered_context:
if key.startswith("tag_") and key not in call_tags:
call_tags.append(gathered_context[key])
gathered_context["call_tags"] = call_tags
# Clean up engine resources (including voicemail detector)
await engine.cleanup()
await audio_buffer.stop_recording()
if audio_synchronizer:
await audio_synchronizer.stop_recording()
# ------------------------------------------------------------------
# Close Smart-Turn WebSocket if the transport's analyzer supports it
# ------------------------------------------------------------------
try:
turn_analyzer = None
# Most transports store their params (with turn_analyzer) directly.
if hasattr(transport, "_params") and transport._params:
turn_analyzer = getattr(transport._params, "turn_analyzer", None)
# Fallback: some transports expose params through input() instance.
if turn_analyzer is None and hasattr(transport, "input"):
try:
input_transport = transport.input()
if input_transport and hasattr(input_transport, "_params"):
turn_analyzer = getattr(
input_transport._params, "turn_analyzer", None
)
except Exception:
pass
if turn_analyzer and hasattr(turn_analyzer, "close"):
await turn_analyzer.close()
logger.debug("Closed turn analyzer websocket")
except Exception as exc:
logger.warning(f"Failed to close Smart-Turn analyzer gracefully: {exc}")
usage_info = usage_metrics_aggregator.get_all_usage_metrics_serialized()
logger.debug(f"Usage metrics: {usage_info}")
await db_client.update_workflow_run(
run_id=workflow_run_id,
usage_info=usage_info,
gathered_context=gathered_context,
is_completed=True,
)
# Release concurrent slot for campaign calls
if workflow_run and workflow_run.campaign_id:
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
# Write buffers to temp files and enqueue S3 upload
try:
# Only upload if buffers have content
if not in_memory_audio_buffer.is_empty:
audio_temp_path = await in_memory_audio_buffer.write_to_temp_file()
await enqueue_job(
FunctionNames.UPLOAD_AUDIO_TO_S3, workflow_run_id, audio_temp_path
)
else:
logger.debug("Audio buffer is empty, skipping upload")
if not in_memory_transcript_buffer.is_empty:
transcript_temp_path = (
await in_memory_transcript_buffer.write_to_temp_file()
)
await enqueue_job(
FunctionNames.UPLOAD_TRANSCRIPT_TO_S3,
workflow_run_id,
transcript_temp_path,
)
else:
logger.debug("Transcript buffer is empty, skipping upload")
except Exception as e:
logger.error(f"Error preparing buffers for S3 upload: {e}", exc_info=True)
await enqueue_job(FunctionNames.CALCULATE_WORKFLOW_RUN_COST, workflow_run_id)
await enqueue_job(
FunctionNames.RUN_INTEGRATIONS_POST_WORKFLOW_RUN, workflow_run_id
)
# Return the buffers so they can be passed to other handlers
return in_memory_audio_buffer, in_memory_transcript_buffer
def register_audio_data_handler(
audio_synchronizer, workflow_run_id, in_memory_buffer: InMemoryAudioBuffer
):
"""Register event handler for audio data"""
logger.info(f"Registering audio data handler for workflow run {workflow_run_id}")
@audio_synchronizer.event_handler("on_merged_audio")
async def on_merged_audio(_, pcm, sample_rate, num_channels):
if not pcm:
return
# Use in-memory buffer
try:
await in_memory_buffer.append(pcm)
except MemoryError as e:
logger.error(f"Memory buffer full: {e}")
# Could implement overflow to disk here if needed
def register_transcript_handler(
transcript, workflow_run_id, in_memory_buffer: InMemoryTranscriptBuffer
):
"""Register event handler for transcript updates"""
@transcript.event_handler("on_transcript_update")
async def on_transcript_update(processor, frame):
transcript_text = ""
for msg in frame.messages:
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
line = f"{timestamp}{msg.role}: {msg.content}\n"
transcript_text += line
# Use in-memory buffer
await in_memory_buffer.append(transcript_text)

View file

@ -0,0 +1,6 @@
class VoicemailDetectedException(Exception):
"""
Exception raised when voicemail is detected.
"""
pass

View file

@ -0,0 +1,147 @@
import os
from typing import TYPE_CHECKING
from loguru import logger
from api.constants import (
ENABLE_TRACING,
)
from api.services.pipecat.audio_config import AudioConfig
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
from pipecat.processors.transcript_processor import TranscriptProcessor
from pipecat.utils.context import turn_var
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine"):
"""Create and return the main pipeline components with proper audio configuration"""
logger.info(f"Creating pipeline components with audio config: {audio_config}")
# Use new split audio buffer for better performance
audio_buffer = AudioBuffer(
sample_rate=audio_config.pipeline_sample_rate,
buffer_size=audio_config.buffer_size_bytes,
)
# Create synchronizer for merged audio (outside pipeline)
audio_synchronizer = AudioSynchronizer(
sample_rate=audio_config.pipeline_sample_rate,
buffer_size=audio_config.buffer_size_bytes,
)
transcript = TranscriptProcessor(
assistant_correct_aggregation_callback=engine.create_aggregation_correction_callback()
)
context = OpenAILLMContext()
return audio_buffer, audio_synchronizer, transcript, context
def build_pipeline(
transport,
stt,
transcript,
audio_buffer,
audio_synchronizer,
llm,
tts,
user_context_aggregator,
assistant_context_aggregator,
pipeline_engine_callback_processor,
stt_mute_filter,
pipeline_metrics_aggregator,
user_idle_disconnect,
engine_pre_aggregator_processor=None,
):
"""Build the main pipeline with all components"""
# Register processors with synchronizer for merged audio
logger.info("Registering audio buffer processors with synchronizer")
audio_synchronizer.register_processors(audio_buffer.input(), audio_buffer.output())
# Build processors list with optional context controller
processors = [
transport.input(), # Transport user input
audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame)
stt_mute_filter,
stt, # STT can now have audio_passthrough=False
user_idle_disconnect,
transcript.user(),
]
# Insert engine pre-aggregator processor if provided (before user aggregator)
if engine_pre_aggregator_processor:
processors.append(engine_pre_aggregator_processor)
processors.extend(
[
user_context_aggregator,
llm, # LLM
pipeline_engine_callback_processor,
tts, # TTS
transport.output(), # Transport bot output
audio_buffer.output(), # Record output audio (only processes OutputAudioRawFrame)
transcript.assistant(),
assistant_context_aggregator, # Assistant spoken responses
pipeline_metrics_aggregator,
]
)
return Pipeline(processors)
def create_pipeline_task(pipeline, workflow_run_id, audio_config: AudioConfig = None):
"""Create a pipeline task with appropriate parameters"""
# Set up pipeline params with audio configuration if provided
pipeline_params = PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
send_initial_empty_metrics=False,
enable_heartbeats=True,
)
# If audio_config is provided, set the audio sample rates
if audio_config:
pipeline_params.audio_in_sample_rate = audio_config.transport_in_sample_rate
pipeline_params.audio_out_sample_rate = audio_config.transport_out_sample_rate
logger.debug(
f"Setting pipeline audio params - in: {audio_config.transport_in_sample_rate}Hz, "
f"out: {audio_config.transport_out_sample_rate}Hz"
)
task = PipelineTask(
pipeline,
params=pipeline_params,
enable_tracing=ENABLE_TRACING,
conversation_id=f"{workflow_run_id}",
)
# Check if turn logging is enabled
enable_turn_logging = os.getenv("ENABLE_TURN_LOGGING", "false").lower() == "true"
if enable_turn_logging:
# Attach event handlers to propagate turn information into the logging context
turn_observer = task.turn_tracking_observer
if turn_observer is not None:
# Import turn context manager only if needed
from api.services.pipecat.turn_context import get_turn_context_manager
async def _on_turn_started(observer, turn_number: int):
"""Set the current turn number into the context variable."""
# Set in both contextvar and turn context manager
turn_var.set(turn_number)
turn_manager = get_turn_context_manager()
turn_manager.set_turn(turn_number)
# Register the handlers with the observer
turn_observer.add_event_handler("on_turn_started", _on_turn_started)
return task

View file

@ -0,0 +1,84 @@
import time
from typing import Awaitable, Callable, Optional
from loguru import logger
from pipecat.frames.frames import (
Frame,
HeartbeatFrame,
LLMFullResponseStartFrame,
LLMGeneratedTextFrame,
LLMTextFrame,
StartFrame,
TTSSpeakFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class PipelineEngineCallbacksProcessor(FrameProcessor):
"""
Custom PipelineEngineCallbacksProcessor that accepts callbacks for various
use cases, like ending tasks when max call duration is exceeded, or informing
the engine that the bot is done speaking.
"""
def __init__(
self,
max_call_duration_seconds: int = 300,
max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None,
llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None,
generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None,
llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None,
):
super().__init__()
self._start_time = None
self._max_call_duration_seconds = max_call_duration_seconds
self._max_duration_end_task_callback = max_duration_end_task_callback
self._llm_generated_text_callback = llm_generated_text_callback
self._generation_started_callback = generation_started_callback
self._llm_text_frame_callback = llm_text_frame_callback
self._end_task_frame_pushed = False
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self._start(frame)
elif isinstance(frame, HeartbeatFrame):
await self._check_call_duration()
elif isinstance(frame, LLMGeneratedTextFrame):
await self._generated_text_frame(frame)
elif isinstance(frame, LLMFullResponseStartFrame):
await self._generation_started()
elif (
isinstance(frame, (LLMTextFrame, TTSSpeakFrame))
and self._llm_text_frame_callback
):
# Include TTSSpeakFrame here since for static nodes, we send TTSSpeakFrame
# which can act as reference while fixing the aggregated trascript
await self._llm_text_frame_callback(frame.text)
await self.push_frame(frame, direction)
async def _start(self, _: StartFrame):
self._start_time = time.time()
async def _check_call_duration(self):
if self._start_time is not None:
if time.time() - self._start_time > self._max_call_duration_seconds:
if not self._end_task_frame_pushed:
if self._max_duration_end_task_callback:
await self._max_duration_end_task_callback()
self._end_task_frame_pushed = True
else:
logger.debug(
"Max call duration exceeded. Skipping EndTaskFrame since already sent"
)
async def _generated_text_frame(self, _: LLMGeneratedTextFrame):
"""Handle LLMGeneratedTextFrame."""
if self._llm_generated_text_callback is not None:
await self._llm_generated_text_callback()
async def _generation_started(self):
if self._generation_started_callback:
await self._generation_started_callback()

View file

@ -0,0 +1,162 @@
import time
from collections import defaultdict
from typing import Dict, Optional
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
MetricsFrame,
StartFrame,
)
from pipecat.metrics.metrics import (
LLMTokenUsage,
LLMUsageMetricsData,
STTUsageMetricsData,
TTSUsageMetricsData,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class PipelineMetricsAggregator(FrameProcessor):
def __init__(self):
super().__init__()
# Structure: {f"{processor}|||{model}": aggregated_metrics}
# For LLM: aggregated_metrics is LLMTokenUsage
# For TTS: aggregated_metrics is int (total characters)
# For STT: aggregated_metrics is float (total seconds)
self._start_time: Optional[float] = None
self._stop_time: Optional[float] = None
self._llm_usage_metrics: Dict[str, LLMTokenUsage] = {}
self._tts_usage_metrics: Dict[str, int] = defaultdict(int)
self._stt_usage_metrics: Dict[str, float] = defaultdict(float)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self._start(frame)
elif isinstance(frame, EndFrame):
await self._stop(frame)
elif isinstance(frame, CancelFrame):
await self._cancel(frame)
elif isinstance(frame, MetricsFrame):
for data in frame.data:
if isinstance(data, LLMUsageMetricsData):
await self._handle_llm_usage_metrics(data)
elif isinstance(data, TTSUsageMetricsData):
await self._handle_tts_usage_metrics(data)
elif isinstance(data, STTUsageMetricsData):
await self._handle_stt_usage_metrics(data)
await self.push_frame(frame, direction)
async def _start(self, _: StartFrame):
"""Start tracking call duration."""
self._start_time = time.time()
self._stop_time = None
async def _stop(self, _: EndFrame):
"""Stop tracking call duration."""
if self._start_time is not None and self._stop_time is None:
self._stop_time = time.time()
async def _cancel(self, _: CancelFrame):
"""Handle call cancellation - also stop tracking duration."""
if self._start_time is not None and self._stop_time is None:
self._stop_time = time.time()
async def _handle_llm_usage_metrics(self, data: LLMUsageMetricsData):
key = f"{data.processor}|||{data.model}"
new_usage = data.value
if key in self._llm_usage_metrics:
# Aggregate with existing metrics
existing = self._llm_usage_metrics[key]
aggregated = LLMTokenUsage(
prompt_tokens=existing.prompt_tokens + new_usage.prompt_tokens,
completion_tokens=existing.completion_tokens
+ new_usage.completion_tokens,
total_tokens=existing.total_tokens + new_usage.total_tokens,
cache_read_input_tokens=(existing.cache_read_input_tokens or 0)
+ (new_usage.cache_read_input_tokens or 0),
cache_creation_input_tokens=(existing.cache_creation_input_tokens or 0)
+ (new_usage.cache_creation_input_tokens or 0),
)
self._llm_usage_metrics[key] = aggregated
else:
# First occurrence for this processor+model combination
self._llm_usage_metrics[key] = LLMTokenUsage(
prompt_tokens=new_usage.prompt_tokens,
completion_tokens=new_usage.completion_tokens,
total_tokens=new_usage.total_tokens,
cache_read_input_tokens=new_usage.cache_read_input_tokens,
cache_creation_input_tokens=new_usage.cache_creation_input_tokens,
)
logger.debug(f"LLM usage metrics: {self._llm_usage_metrics}")
async def _handle_tts_usage_metrics(self, data: TTSUsageMetricsData):
key = f"{data.processor}|||{data.model}"
self._tts_usage_metrics[key] += data.value
# logger.debug(f"TTS usage metrics: {self._tts_usage_metrics}")
async def _handle_stt_usage_metrics(self, data: STTUsageMetricsData):
key = f"{data.processor}|||{data.model}"
self._stt_usage_metrics[key] += data.value
logger.debug(f"STT usage metrics: {self._stt_usage_metrics}")
def get_llm_usage_metrics(self) -> Dict[str, LLMTokenUsage]:
"""Get the aggregated LLM usage metrics grouped by processor|||model."""
return self._llm_usage_metrics
def get_tts_usage_metrics(self) -> Dict[str, int]:
"""Get the aggregated TTS usage metrics grouped by processor|||model."""
return self._tts_usage_metrics
def get_stt_usage_metrics(self) -> Dict[str, float]:
"""Get the aggregated STT usage metrics grouped by processor|||model."""
return self._stt_usage_metrics
def get_call_duration(self) -> float:
"""Get call duration"""
if self._start_time is None:
return 0.0
if self._stop_time is None:
call_duration = time.time() - self._start_time
else:
call_duration = self._stop_time - self._start_time
# Lets return a rounded integer
return int(round(call_duration))
def get_all_usage_metrics_serialized(self) -> Dict[str, Dict[str, any]]:
"""Get all aggregated usage metrics in JSON-serializable format."""
serialized_llm = {}
for key, usage in self._llm_usage_metrics.items():
serialized_llm[key] = {
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
"total_tokens": usage.total_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
return {
"llm": serialized_llm,
"tts": dict(self._tts_usage_metrics),
"stt": dict(self._stt_usage_metrics),
"call_duration_seconds": self.get_call_duration(),
}
def reset_metrics(self):
"""Reset all aggregated metrics."""
self._llm_usage_metrics.clear()
self._tts_usage_metrics.clear()
self._stt_usage_metrics.clear()
self._start_time = None
self._stop_time = None

View file

@ -0,0 +1,388 @@
from typing import Optional
from fastapi import HTTPException, WebSocket
from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunMode
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
from api.services.pipecat.engine_pre_aggregator_processor import (
EnginePreAggregatorProcessor,
)
from api.services.pipecat.event_handlers import (
register_audio_data_handler,
register_transcript_handler,
register_transport_event_handlers,
)
from api.services.pipecat.pipeline_builder import (
build_pipeline,
create_pipeline_components,
create_pipeline_task,
)
from api.services.pipecat.pipeline_engine_callbacks_processor import (
PipelineEngineCallbacksProcessor,
)
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
from api.services.pipecat.service_factory import (
create_llm_service,
create_stt_service,
create_tts_service,
)
from api.services.pipecat.tracing_config import setup_pipeline_tracing
from api.services.pipecat.transport_setup import (
create_stasis_transport,
create_twilio_transport,
create_webrtc_transport,
)
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
from api.services.workflow.dto import ReactFlowDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.pipeline.runner import PipelineRunner
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.filters.stt_mute_filter import (
STTMuteConfig,
STTMuteFilter,
STTMuteStrategy,
)
from pipecat.processors.user_idle_processor import UserIdleProcessor
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
from pipecat.utils.context import set_current_run_id
from pipecat.utils.tracing.context_registry import ContextProviderRegistry
# Setup tracing if enabled
setup_pipeline_tracing()
async def run_pipeline_twilio(
websocket_client: WebSocket,
stream_sid: str,
call_sid: str,
workflow_id: int,
workflow_run_id: int,
user_id: int,
) -> None:
"""Run pipeline for Twilio connections"""
logger.debug(
f"Running pipeline for Twilio connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
)
set_current_run_id(workflow_run_id)
# Store Twilio call SID in cost_info for later cost calculation
cost_info = {"twilio_call_sid": call_sid}
await db_client.update_workflow_run(workflow_run_id, cost_info=cost_info)
# Get workflow to extract all pipeline configurations
workflow = await db_client.get_workflow(workflow_id, user_id)
vad_config = None
ambient_noise_config = None
if workflow and workflow.workflow_configurations:
if "vad_configuration" in workflow.workflow_configurations:
vad_config = workflow.workflow_configurations["vad_configuration"]
if "ambient_noise_configuration" in workflow.workflow_configurations:
ambient_noise_config = workflow.workflow_configurations[
"ambient_noise_configuration"
]
# Create audio configuration for Twilio
audio_config = create_audio_config(WorkflowRunMode.TWILIO.value)
transport = create_twilio_transport(
websocket_client,
stream_sid,
call_sid,
workflow_run_id,
audio_config,
vad_config,
ambient_noise_config,
)
await _run_pipeline(
transport,
workflow_id,
workflow_run_id,
user_id,
audio_config=audio_config,
)
async def run_pipeline_smallwebrtc(
webrtc_connection: SmallWebRTCConnection,
workflow_id: int,
workflow_run_id: int,
user_id: int,
call_context_vars: dict = {},
) -> None:
"""Run pipeline for WebRTC connections"""
logger.debug(
f"Running pipeline for WebRTC connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
)
set_current_run_id(workflow_run_id)
# Get workflow to extract all pipeline configurations
workflow = await db_client.get_workflow(workflow_id, user_id)
vad_config = None
ambient_noise_config = None
if workflow and workflow.workflow_configurations:
if "vad_configuration" in workflow.workflow_configurations:
vad_config = workflow.workflow_configurations["vad_configuration"]
if "ambient_noise_configuration" in workflow.workflow_configurations:
ambient_noise_config = workflow.workflow_configurations[
"ambient_noise_configuration"
]
# Create audio configuration for WebRTC
audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
transport = create_webrtc_transport(
webrtc_connection,
workflow_run_id,
audio_config,
vad_config,
ambient_noise_config,
)
await _run_pipeline(
transport,
workflow_id,
workflow_run_id,
user_id,
call_context_vars=call_context_vars,
audio_config=audio_config,
)
async def run_pipeline_ari_stasis(
stasis_connection: StasisRTPConnection,
workflow_id: int,
workflow_run_id: int,
user_id: int,
call_context_vars: dict,
) -> None:
"""Run pipeline for ARI connections"""
logger.debug(
f"Running pipeline for ARI connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
)
set_current_run_id(workflow_run_id)
# Get workflow to extract all pipeline configurations
workflow = await db_client.get_workflow(workflow_id, user_id)
vad_config = None
ambient_noise_config = None
if workflow and workflow.workflow_configurations:
if "vad_configuration" in workflow.workflow_configurations:
vad_config = workflow.workflow_configurations["vad_configuration"]
if "ambient_noise_configuration" in workflow.workflow_configurations:
ambient_noise_config = workflow.workflow_configurations[
"ambient_noise_configuration"
]
# Create audio configuration for Stasis
audio_config = create_audio_config(WorkflowRunMode.STASIS.value)
transport = create_stasis_transport(
stasis_connection,
workflow_run_id,
audio_config,
vad_config,
ambient_noise_config,
)
await _run_pipeline(
transport,
workflow_id,
workflow_run_id,
user_id,
call_context_vars=call_context_vars,
audio_config=audio_config,
stasis_connection=stasis_connection, # Pass connection for immediate transfers
)
async def _run_pipeline(
transport,
workflow_id: int,
workflow_run_id: int,
user_id: int,
call_context_vars: dict = {},
audio_config: AudioConfig = None,
stasis_connection: Optional[StasisRTPConnection] = None,
) -> None:
"""
Run the pipeline with the given transport and configuration
Args:
transport: The transport to use for the pipeline
workflow_id: The ID of the workflow
workflow_run_id: The ID of the workflow run
user_id: The ID of the user
mode: The mode of the pipeline (twilio or smallwebrtc)
"""
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
# If the workflow run is already completed, we don't need to run it again
if workflow_run.is_completed:
raise HTTPException(status_code=400, detail="Workflow run already completed")
merged_call_context_vars = workflow_run.initial_context
# If there is some extra call_context_vars, update them
if call_context_vars:
merged_call_context_vars = {**merged_call_context_vars, **call_context_vars}
await db_client.update_workflow_run(
workflow_run_id, initial_context=merged_call_context_vars
)
# Get user configuration
user_config = await db_client.get_user_configurations(user_id)
# Create services based on user configuration
stt = create_stt_service(user_config)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
# Get workflow first so we can create engine before pipeline components
workflow = await db_client.get_workflow(workflow_id, user_id)
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
# Extract configurations from workflow configurations
max_call_duration_seconds = 300 # Default 5 minutes
max_user_idle_timeout = 10.0 # Default 10 seconds
if workflow.workflow_configurations:
# Use workflow-specific max call duration if provided
if "max_call_duration" in workflow.workflow_configurations:
max_call_duration_seconds = workflow.workflow_configurations[
"max_call_duration"
]
# Use workflow-specific max user idle timeout if provided
if "max_user_idle_timeout" in workflow.workflow_configurations:
max_user_idle_timeout = workflow.workflow_configurations[
"max_user_idle_timeout"
]
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
)
engine = PipecatEngine(
llm=llm,
tts=tts,
workflow=workflow_graph,
call_context_vars=merged_call_context_vars,
workflow_run_id=workflow_run_id,
)
# Create pipeline components with audio configuration and engine
audio_buffer, audio_synchronizer, transcript, context = create_pipeline_components(
audio_config, engine
)
# Set the context and audio_buffer after creation
engine.set_context(context)
engine.set_audio_buffer(audio_buffer)
# Set Stasis connection for immediate transfers (if available)
if stasis_connection:
engine.set_stasis_connection(stasis_connection)
assistant_params = LLMAssistantAggregatorParams(
expect_stripped_words=True,
correct_aggregation_callback=engine.create_aggregation_correction_callback(),
)
context_aggregator = llm.create_context_aggregator(
context, assistant_params=assistant_params
)
# Create engine pre-aggregator processor for speaking events
engine_pre_aggregator_processor = EnginePreAggregatorProcessor(
user_started_speaking_callback=engine.create_user_started_speaking_callback(),
user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(),
)
# Create usage metrics aggregator with engine's callback
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
max_call_duration_seconds=max_call_duration_seconds,
max_duration_end_task_callback=engine.create_max_duration_callback(),
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
generation_started_callback=engine.create_generation_started_callback(),
llm_text_frame_callback=engine.handle_llm_text_frame,
# Note: speaking event callbacks are now handled by pre-aggregator processor
)
pipeline_metrics_aggregator = PipelineMetricsAggregator()
# Create STT mute filter using the selected strategies and the engine's callback
stt_mute_filter = STTMuteFilter(
config=STTMuteConfig(
strategies={
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
STTMuteStrategy.CUSTOM,
},
should_mute_callback=engine.create_should_mute_callback(),
)
)
# Use engine's user idle callback with configured timeout
user_idle_disconnect = UserIdleProcessor(
callback=engine.create_user_idle_callback(), timeout=max_user_idle_timeout
)
user_context_aggregator = context_aggregator.user()
assistant_context_aggregator = context_aggregator.assistant()
@assistant_context_aggregator.event_handler("on_push_aggregation")
async def on_assistant_aggregator_push_context(_aggregator):
logger.debug("Assistant aggregator push context flushing pending transitions")
await engine.flush_pending_transitions(source="context_push")
# Build the pipeline with the STT mute filter and context controller
pipeline = build_pipeline(
transport,
stt,
transcript,
audio_buffer,
audio_synchronizer,
llm,
tts,
user_context_aggregator,
assistant_context_aggregator,
pipeline_engine_callback_processor,
stt_mute_filter,
pipeline_metrics_aggregator,
user_idle_disconnect,
engine_pre_aggregator_processor=engine_pre_aggregator_processor,
)
# Create pipeline task with audio configuration
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
# Now set the task on the engine
engine.set_task(task)
# Register event handlers
in_memory_audio_buffer, in_memory_transcript_buffer = (
register_transport_event_handlers(
transport,
workflow_run_id,
audio_buffer,
task,
engine=engine,
usage_metrics_aggregator=pipeline_metrics_aggregator,
audio_synchronizer=audio_synchronizer,
audio_config=audio_config,
)
)
register_audio_data_handler(
audio_synchronizer, workflow_run_id, in_memory_audio_buffer
)
register_transcript_handler(
transcript, workflow_run_id, in_memory_transcript_buffer
)
try:
# Run the pipeline
runner = PipelineRunner()
await runner.run(task)
logger.info(f"Pipeline runner completed for run {workflow_run_id}")
finally:
ContextProviderRegistry.remove_providers(str(workflow_run_id))
logger.debug(f"Cleaned up context providers for workflow run {workflow_run_id}")

View file

@ -0,0 +1,150 @@
from typing import TYPE_CHECKING
from fastapi import HTTPException
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from pipecat.services.azure.llm import AzureLLMService
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.deepgram.tts import DeepgramTTSService
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.groq.llm import GroqLLMService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import OpenAITTSService
if TYPE_CHECKING:
from api.services.pipecat.audio_config import AudioConfig
def create_stt_service(user_config):
"""Create and return appropriate STT service based on user configuration"""
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
return DeepgramSTTService(
api_key=user_config.stt.api_key,
audio_passthrough=False, # Disable passthrough since audio is buffered separately
)
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
return OpenAISTTService(
api_key=user_config.stt.api_key,
model=user_config.stt.model.value,
audio_passthrough=False, # Disable passthrough since audio is buffered separately
)
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
return CartesiaSTTService(
api_key=user_config.stt.api_key,
audio_passthrough=False, # Disable passthrough since audio is buffered separately
)
elif user_config.stt.provider == ServiceProviders.DOGRAH.value:
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
model=user_config.stt.model.value,
audio_passthrough=False, # Disable passthrough since audio is buffered separately
)
else:
raise HTTPException(
status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}"
)
def create_tts_service(user_config, audio_config: "AudioConfig"):
"""Create and return appropriate TTS service based on user configuration
Args:
user_config: User configuration containing TTS settings
transport_type: Type of transport (e.g., 'stasis', 'twilio', 'webrtc')
"""
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
return DeepgramTTSService(
api_key=user_config.tts.api_key,
voice=user_config.tts.voice.value,
sample_rate=24000,
)
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
return OpenAITTSService(
api_key=user_config.tts.api_key, model=user_config.tts.model.value
)
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
voice_id = user_config.tts.voice.split(" - ")[1]
return ElevenLabsTTSService(
reconnect_on_error=False,
api_key=user_config.tts.api_key,
voice_id=voice_id,
model=user_config.tts.model.value,
params=ElevenLabsTTSService.InputParams(
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
),
)
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
# Convert HTTP URL to WebSocket URL for TTS
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
# Handle both enum and string values for model and voice
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
model=user_config.tts.model.value,
voice=user_config.tts.voice.value,
sample_rate=24000,
)
else:
raise HTTPException(
status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}"
)
def create_llm_service(user_config):
"""Create and return appropriate LLM service based on user configuration"""
if user_config.llm.provider == ServiceProviders.OPENAI.value:
if "gpt-5" in user_config.llm.model.value:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=user_config.llm.model.value,
params=OpenAILLMService.InputParams(
reasoning_effort="minimal", verbosity="low"
),
)
else:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=user_config.llm.model.value,
params=OpenAILLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GROQ.value:
print(
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {user_config.llm.model.value}"
)
return GroqLLMService(
api_key=user_config.llm.api_key,
model=user_config.llm.model.value,
params=OpenAILLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
# Use the correct InputParams class for Google to avoid propagating OpenAI-specific
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
return GoogleLLMService(
api_key=user_config.llm.api_key,
model=user_config.llm.model.value,
params=GoogleLLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.AZURE.value:
return AzureLLMService(
api_key=user_config.llm.api_key,
endpoint=user_config.llm.endpoint,
model=user_config.llm.model.value, # Azure uses deployment name as model
params=AzureLLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=user_config.llm.api_key,
model=user_config.llm.model.value,
)
else:
raise HTTPException(status_code=400, detail="Invalid LLM provider")

View file

@ -0,0 +1,44 @@
import base64
import os
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from api.constants import ENABLE_TRACING
from pipecat.utils.tracing.setup import setup_tracing
def is_tracing_enabled():
"""Check if tracing should be enabled based on ENABLE_TRACING flag."""
# Tracing is only enabled when ENABLE_TRACING is explicitly set to true
# This makes the system OSS-friendly by default (no external dependencies required)
return ENABLE_TRACING
def setup_pipeline_tracing():
"""Setup tracing for the pipeline if enabled"""
if is_tracing_enabled():
# Only set up Langfuse if credentials are provided
langfuse_host = os.environ.get("LANGFUSE_HOST")
langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
if not all([langfuse_host, langfuse_public_key, langfuse_secret_key]):
print(
"Warning: ENABLE_TRACING is true but Langfuse credentials are not configured. Tracing disabled."
)
return
LANGFUSE_AUTH = base64.b64encode(
f"{langfuse_public_key}:{langfuse_secret_key}".encode()
).decode()
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{langfuse_host}/api/public/otel"
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = (
f"Authorization=Basic {LANGFUSE_AUTH}"
)
otlp_exporter = OTLPSpanExporter()
setup_tracing(service_name="dograh-pipeline", exporter=otlp_exporter)
print("Langfuse tracing enabled")
else:
print("Tracing disabled (ENABLE_TRACING=false)")

View file

@ -0,0 +1,299 @@
import os
from fastapi import WebSocket
from api.constants import APP_ROOT_DIR, ENABLE_RNNOISE, ENABLE_SMART_TURN
from api.services.pipecat.audio_config import AudioConfig
from api.services.smart_turn.websocket_smart_turn import (
WebSocketSmartTurnAnalyzer,
)
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
from api.services.telephony.stasis_rtp_serializer import StasisRTPFrameSerializer
from api.services.telephony.stasis_rtp_transport import (
StasisRTPTransport,
StasisRTPTransportParams,
)
from pipecat.audio.filters.rnnoise_filter import RNNoiseFilter
from pipecat.audio.mixers.silence_audio_mixer import SilenceAudioMixer
from pipecat.audio.mixers.soundfile_mixer import SoundfileMixer
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.vad.silero import SileroVADAnalyzer, VADParams
from pipecat.serializers.twilio import TwilioFrameSerializer
from pipecat.transports import InternalTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.fastapi_websocket import (
FastAPIWebsocketParams,
FastAPIWebsocketTransport,
)
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
librnnoise_path = os.path.normpath(
str(APP_ROOT_DIR / "native" / "rnnoise" / "librnnoise.so")
)
def create_turn_analyzer(workflow_run_id: int, audio_config: AudioConfig):
"""Create a turn analyzer backed by the local Smart Turn HTTP service.
Args:
workflow_run_id: ID of the workflow run for turn analyzer context
audio_config: Audio configuration containing pipeline sample rate
"""
if ENABLE_SMART_TURN:
service_url = os.getenv(
"SMART_TURN_WS_SERVICE_ENDPOINT", "ws://localhost:8010/ws"
)
# Prepare optional authentication headers for Smart Turn service
secret_key = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
headers = {"X-API-Key": secret_key} if secret_key else None
return WebSocketSmartTurnAnalyzer(
url=service_url,
headers=headers,
sample_rate=audio_config.pipeline_sample_rate,
params=SmartTurnParams(
stop_secs=1.5, # send turn complete if silent for stop_secs seconds
pre_speech_ms=0, # send speech segments before speech was detected by VAD
max_duration_secs=5, # max duration of speech to be sent to the end of turn analyzer
# we don't want to _clear except when we have end of turn prediction as 1 from last run
# else if we have speaking -> queit -> trigger end of turn -> clear() and then
# we have speak -> queit, we may end up sending a very small segment of speech
# to end of turn model, which is not good
use_only_last_vad_segment=False,
),
service_context=workflow_run_id,
)
return None
def create_twilio_transport(
websocket_client: WebSocket,
stream_sid: str,
call_sid: str,
workflow_run_id: int,
audio_config: AudioConfig,
vad_config: dict | None = None,
ambient_noise_config: dict | None = None,
):
"""Create a transport for Twilio connections"""
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
serializer = TwilioFrameSerializer(
stream_sid=stream_sid,
call_sid=call_sid,
account_sid=os.environ["TWILIO_ACCOUNT_SID"],
auth_token=os.environ["TWILIO_AUTH_TOKEN"],
)
return FastAPIWebsocketTransport(
websocket=websocket_client,
params=FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
vad_analyzer=(
SileroVADAnalyzer(
params=VADParams(
confidence=vad_config.get("confidence", 0.7),
start_secs=vad_config.get("start_seconds", 0.4),
stop_secs=vad_config.get("stop_seconds", 0.8),
min_volume=vad_config.get("minimum_volume", 0.6),
)
)
if vad_config
else SileroVADAnalyzer()
), # Sample rate will be set by transport
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
turn_analyzer=turn_analyzer,
serializer=serializer,
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
if ENABLE_RNNOISE
else None,
),
)
def create_webrtc_transport(
webrtc_connection: SmallWebRTCConnection,
workflow_run_id: int,
audio_config: AudioConfig,
vad_config: dict | None = None,
ambient_noise_config: dict | None = None,
):
"""Create a transport for WebRTC connections"""
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
return SmallWebRTCTransport(
webrtc_connection=webrtc_connection,
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
vad_analyzer=(
SileroVADAnalyzer(
params=VADParams(
confidence=vad_config.get("confidence", 0.7),
start_secs=vad_config.get("start_seconds", 0.4),
stop_secs=vad_config.get("stop_seconds", 0.8),
min_volume=vad_config.get("minimum_volume", 0.6),
)
)
if vad_config
else SileroVADAnalyzer()
), # Sample rate will be set by transport
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
turn_analyzer=turn_analyzer,
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
if ENABLE_RNNOISE
else None,
),
)
def create_stasis_transport(
stasis_connection: StasisRTPConnection,
workflow_run_id: int,
audio_config: AudioConfig,
vad_config: dict | None = None,
ambient_noise_config: dict | None = None,
):
"""Create a transport for ARI connections"""
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
serializer = StasisRTPFrameSerializer(
StasisRTPFrameSerializer.InputParams(
sample_rate=audio_config.transport_in_sample_rate
)
)
return StasisRTPTransport(
stasis_connection,
params=StasisRTPTransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_out_10ms_chunks=2, # Send 20ms packets
vad_analyzer=(
SileroVADAnalyzer(
params=VADParams(
confidence=vad_config.get("confidence", 0.7),
start_secs=vad_config.get("start_seconds", 0.4),
stop_secs=vad_config.get("stop_seconds", 0.8),
min_volume=vad_config.get("minimum_volume", 0.6),
)
)
if vad_config
else SileroVADAnalyzer()
), # Sample rate will be set by transport
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
turn_analyzer=turn_analyzer,
serializer=serializer,
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
if ENABLE_RNNOISE
else None,
),
)
def create_internal_transport(
workflow_run_id: int,
audio_config: AudioConfig,
latency_seconds: float = 0.0,
vad_config: dict | None = None,
ambient_noise_config: dict | None = None,
):
"""Create an internal transport for agent-to-agent connections (LoopTalk).
Args:
workflow_run_id: ID of the workflow run for turn analyzer context
audio_config: Audio configuration for the transport
latency_seconds: Network latency to simulate
Returns:
InternalTransport instance configured with turn analyzer
"""
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
# Create and return the internal transport with latency
return InternalTransport(
params=TransportParams(
audio_out_enabled=True,
audio_out_sample_rate=audio_config.transport_out_sample_rate,
audio_out_channels=1,
audio_in_enabled=True,
audio_in_sample_rate=audio_config.transport_in_sample_rate,
audio_in_channels=1,
vad_analyzer=(
SileroVADAnalyzer(
params=VADParams(
confidence=vad_config.get("confidence", 0.7),
start_secs=vad_config.get("start_seconds", 0.4),
stop_secs=vad_config.get("stop_seconds", 0.8),
min_volume=vad_config.get("minimum_volume", 0.6),
)
)
if vad_config
else SileroVADAnalyzer()
),
audio_out_mixer=(
SoundfileMixer(
sound_files={
"office": APP_ROOT_DIR
/ "assets"
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
},
default_sound="office",
volume=ambient_noise_config.get("volume", 0.3),
)
if ambient_noise_config and ambient_noise_config.get("enabled", False)
else SilenceAudioMixer()
),
turn_analyzer=turn_analyzer,
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
if ENABLE_RNNOISE
else None,
),
latency_seconds=latency_seconds,
)

View file

@ -0,0 +1,76 @@
"""Turn context management for logging across async boundaries.
This module provides a mechanism to track turn numbers across different
async contexts, working around the limitation that contextvars don't
propagate through asyncio.create_task() calls.
"""
import asyncio
from typing import Dict, Optional
from pipecat.utils.context import turn_var
class TurnContextManager:
"""Manages turn context across async task boundaries.
This class provides a workaround for the issue where contextvars
don't propagate through asyncio.create_task() calls in the pipecat
library's event system.
"""
def __init__(self):
# Map from task to turn number
self._task_turns: Dict[asyncio.Task, int] = {}
# Store the pipeline task reference
self._pipeline_task: Optional[asyncio.Task] = None
self._current_turn: int = 0
def set_pipeline_task(self, task: asyncio.Task):
"""Set the main pipeline task reference."""
self._pipeline_task = task
def set_turn(self, turn_number: int):
"""Set the turn number for the current context."""
self._current_turn = turn_number
# Set in contextvar for direct access
turn_var.set(turn_number)
# Also store for the current task
try:
current_task = asyncio.current_task()
if current_task:
self._task_turns[current_task] = turn_number
except RuntimeError:
pass
def get_turn(self) -> int:
"""Get the turn number, trying multiple sources."""
# First try contextvar
turn = turn_var.get()
if turn > 0:
return turn
# Try current task mapping
try:
current_task = asyncio.current_task()
if current_task and current_task in self._task_turns:
return self._task_turns[current_task]
except RuntimeError:
pass
# Fall back to stored current turn
return self._current_turn
def cleanup_task(self, task: asyncio.Task):
"""Clean up turn mapping for completed tasks."""
self._task_turns.pop(task, None)
# Global instance
_turn_context_manager = TurnContextManager()
def get_turn_context_manager() -> TurnContextManager:
"""Get the global turn context manager instance."""
return _turn_context_manager

View file

@ -0,0 +1,76 @@
# Pricing Module
This module contains pricing models and registries for different AI services used in workflow cost calculations.
## Structure
```
pricing/
├── __init__.py # Main module exports
├── models.py # Base pricing model classes
├── llm.py # LLM pricing configurations
├── tts.py # TTS pricing configurations
├── stt.py # STT pricing configurations
├── registry.py # Combined pricing registry
└── README.md # This file
```
## Pricing Models
### TokenPricingModel
Used for LLM services that charge based on tokens:
- `prompt_token_price`: Cost per prompt token
- `completion_token_price`: Cost per completion token
- `cache_read_discount`: Discount for cache read tokens (default 50%)
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
### CharacterPricingModel
Used for TTS services that charge based on character count:
- `character_price`: Cost per character
### TimePricingModel
Used for STT services that charge based on time:
- `second_price`: Cost per second
## Adding New Pricing
### Adding a New LLM Model
Edit `llm.py` and add the model to the appropriate provider:
```python
ServiceProviders.OPENAI: {
"new-model": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000,
completion_token_price=Decimal("8.00") / 1000000,
),
# ... existing models
}
```
### Adding a New Provider
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
2. The registry will automatically include them
### Adding a New Service Type
1. Create a new pricing file (e.g., `image.py`)
2. Define the pricing models
3. Import and add to `registry.py`
## Usage
The pricing registry is automatically imported and used by the cost calculator:
```python
from api.services.pricing import PRICING_REGISTRY
from api.services.workflow.cost_calculator import cost_calculator
# The cost calculator uses the pricing registry automatically
result = cost_calculator.calculate_total_cost(usage_info)
```
## Maintenance
- Update pricing when providers change their rates
- All prices should use `Decimal` for precision
- Include comments with current pricing from provider documentation
- Test changes with existing test suite

View file

@ -0,0 +1,9 @@
"""
Pricing module for workflow cost calculation.
This module contains pricing models and registries for different AI services.
"""
from .registry import PRICING_REGISTRY
__all__ = ["PRICING_REGISTRY"]

View file

@ -0,0 +1,228 @@
"""
Cost Calculator for Workflow Runs
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
Features:
- Token-based pricing for LLM services with cache optimization support
- Character-based pricing for TTS services
- Time-based pricing for STT services
- Configurable pricing models that can be updated
- Support for multiple providers and models
- Automatic provider inference from model names
- JSON serialization support for database storage
Usage:
from api.tasks.cost_calculator import cost_calculator
usage_info = {
"llm": {
"processor_name|||gpt-4o": {
"prompt_tokens": 1000,
"completion_tokens": 500,
"total_tokens": 1500,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0
}
},
"tts": {
"processor_name|||aura-2-helena-en": 2000 # character count
}
}
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
print(f"Total cost: ${cost_breakdown['total']:.6f}")
"""
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple
from api.services.configuration.registry import ServiceProviders
from api.services.pricing import PRICING_REGISTRY
from api.services.pricing.models import (
PricingModel,
)
class CostCalculator:
"""Main cost calculator class"""
def __init__(self, pricing_registry: Dict = None):
self.pricing_registry = pricing_registry or PRICING_REGISTRY
def get_pricing_model(
self, service_type: str, provider: str, model: str
) -> Optional[PricingModel]:
"""Get pricing model for a specific service, provider, and model"""
try:
service_pricing = self.pricing_registry.get(service_type, {})
# Try to get pricing for the specific provider
provider_pricing = service_pricing.get(provider, {})
pricing_model = provider_pricing.get(model) or provider_pricing.get(
"default"
)
if pricing_model:
return pricing_model
# If not found, try the "default" provider for this service type
default_provider_pricing = service_pricing.get("default", {})
return default_provider_pricing.get(model) or default_provider_pricing.get(
"default"
)
except (KeyError, AttributeError):
return None
def calculate_llm_cost(
self, provider: str, model: str, usage: Dict[str, int]
) -> Decimal:
"""Calculate cost for LLM usage"""
pricing_model = self.get_pricing_model("llm", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(usage)
def calculate_tts_cost(
self, provider: str, model: str, character_count: int
) -> Decimal:
"""Calculate cost for TTS usage"""
pricing_model = self.get_pricing_model("tts", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(character_count)
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
"""Calculate cost for STT usage"""
pricing_model = self.get_pricing_model("stt", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(seconds)
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
llm_cost_total = Decimal("0")
tts_cost_total = Decimal("0")
stt_cost_total = Decimal("0")
# Calculate LLM costs
llm_usage = usage_info.get("llm", {})
for key, usage in llm_usage.items():
processor, model = self._parse_key(key)
# Try to determine provider from processor name or model
provider = self._infer_provider_from_model(model, "llm")
cost = self.calculate_llm_cost(provider, model, usage)
llm_cost_total += cost
# Calculate TTS costs
tts_usage = usage_info.get("tts", {})
for key, character_count in tts_usage.items():
processor, model = self._parse_key(key)
# Handle the case where model is "None" - infer from processor
if model.lower() in ["none", "null", ""]:
provider = self._infer_provider_from_processor(processor, "tts")
model = "default" # Use default model for the provider
else:
provider = self._infer_provider_from_model(model, "tts")
cost = self.calculate_tts_cost(provider, model, character_count)
tts_cost_total += cost
# Calculate STT costs from explicit stt usage
stt_usage = usage_info.get("stt", {})
for key, seconds in stt_usage.items():
processor, model = self._parse_key(key)
provider = self._infer_provider_from_model(model, "stt")
cost = self.calculate_stt_cost(provider, model, seconds)
stt_cost_total += cost
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
return {
"llm_cost": float(llm_cost_total),
"tts_cost": float(tts_cost_total),
"stt_cost": float(stt_cost_total),
"total": float(total_cost),
}
def _parse_key(self, key) -> Tuple[str, str]:
"""Parse key which is in format 'processor|||model'"""
if isinstance(key, str) and "|||" in key:
parts = key.split("|||", 1)
return parts[0], parts[1]
else:
# Fallback for backwards compatibility or malformed keys
return str(key), "unknown"
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
"""Infer provider from model name"""
if not model:
return "unknown"
model_lower = model.lower()
# OpenAI models
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
return ServiceProviders.OPENAI
# Groq models
if any(keyword in model_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Elevenlabs models
if any(keyword in model_lower for keyword in ["eleven"]):
return ServiceProviders.ELEVENLABS
# Deepgram models
if any(
keyword in model_lower
for keyword in ["deepgram", "nova", "phonecall", "general"]
):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
"""Infer provider from processor name"""
if not processor:
return "unknown"
processor_lower = processor.lower()
# OpenAI processors
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
return ServiceProviders.OPENAI
# Groq processors
if any(keyword in processor_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Deepgram processors
if any(keyword in processor_lower for keyword in ["deepgram"]):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def update_pricing(
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
):
"""Update pricing for a specific service/provider/model combination"""
if service_type not in self.pricing_registry:
self.pricing_registry[service_type] = {}
if provider not in self.pricing_registry[service_type]:
self.pricing_registry[service_type][provider] = {}
self.pricing_registry[service_type][provider][model] = pricing_model
# Global cost calculator instance
cost_calculator = CostCalculator()

143
api/services/pricing/llm.py Normal file
View file

@ -0,0 +1,143 @@
"""
LLM pricing models for different providers.
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TokenPricingModel
# LLM pricing registry
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-3.5-turbo": TokenPricingModel(
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
),
"gpt-4": TokenPricingModel(
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
),
"gpt-4.1": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
),
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
),
"gpt-4.1-nano": TokenPricingModel(
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
),
"gpt-4.5-preview": TokenPricingModel(
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
),
"gpt-4o": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
completion_token_price=Decimal("10.00")
/ 1000000, # $10.00 per 1M tokens - FIXED
),
"gpt-4o-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
),
"gpt-4o-mini": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
),
"gpt-4o-search-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-search-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"o1": TokenPricingModel(
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
),
"o1-pro": TokenPricingModel(
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
),
"o1-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o3": TokenPricingModel(
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
),
"o3-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o4-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"computer-use-preview": TokenPricingModel(
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
),
"gpt-image-1": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
),
"codex-mini-latest": TokenPricingModel(
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
),
# Transcription models
"gpt-4o-transcribe": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-transcribe": TokenPricingModel(
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
),
# TTS models with token-based pricing
"gpt-4o-mini-tts": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("0")
/ 1000000, # No completion tokens for TTS
),
},
ServiceProviders.GROQ: {
"llama-3.3-70b-versatile": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
),
"deepseek-r1-distill-llama-70b": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
completion_token_price=Decimal("0.00079") / 1000,
),
},
ServiceProviders.AZURE: {
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("8.80")
/ 1000000, # $1.60 per 1M tokens if using data zone
)
},
}

View file

@ -0,0 +1,89 @@
"""
Base pricing models for different service types.
"""
from decimal import Decimal
from enum import Enum
from typing import Any, Dict
class CostType(Enum):
LLM_TOKENS = "llm_tokens"
TTS_CHARACTERS = "tts_characters"
STT_SECONDS = "stt_seconds"
class PricingModel:
"""Base class for pricing models"""
def calculate_cost(self, usage: Any) -> Decimal:
"""Calculate cost based on usage"""
raise NotImplementedError
class TokenPricingModel(PricingModel):
"""Pricing model for token-based services (LLM)"""
def __init__(
self,
prompt_token_price: Decimal,
completion_token_price: Decimal,
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
cache_creation_multiplier: Decimal = Decimal(
"1.25"
), # 25% premium for cache creation
):
self.prompt_token_price = prompt_token_price
self.completion_token_price = completion_token_price
self.cache_read_discount = cache_read_discount
self.cache_creation_multiplier = cache_creation_multiplier
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
"""Calculate cost for LLM token usage"""
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
# Base cost
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
completion_cost = Decimal(completion_tokens) * self.completion_token_price
# Cache adjustments
cache_read_savings = (
Decimal(cache_read_tokens)
* self.prompt_token_price
* self.cache_read_discount
)
cache_creation_premium = (
Decimal(cache_creation_tokens)
* self.prompt_token_price
* (self.cache_creation_multiplier - 1)
)
total_cost = (
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
)
return max(total_cost, Decimal("0")) # Ensure non-negative
class CharacterPricingModel(PricingModel):
"""Pricing model for character-based services (TTS)"""
def __init__(self, character_price: Decimal):
self.character_price = character_price
def calculate_cost(self, character_count: int) -> Decimal:
"""Calculate cost for TTS character usage"""
return Decimal(character_count) * self.character_price
class TimePricingModel(PricingModel):
"""Pricing model for time-based services (STT)"""
def __init__(self, second_price: Decimal):
self.second_price = second_price
def calculate_cost(self, seconds: float) -> Decimal:
"""Calculate cost for STT time usage"""
return Decimal(str(seconds)) * self.second_price

View file

@ -0,0 +1,16 @@
"""
Main pricing registry that combines all service type pricing models.
"""
from typing import Dict
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
# Combined pricing registry for all service types
PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
}

View file

@ -0,0 +1,26 @@
"""
STT (Speech-to-Text) pricing models for different providers.
Prices are per second for STT services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TimePricingModel
# STT pricing registry
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
ServiceProviders.DEEPGRAM: {
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
"default": TimePricingModel(Decimal("0.0077") / 60),
},
ServiceProviders.OPENAI: {
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
"default": TimePricingModel(Decimal("0.015") / 60),
},
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
}

View file

@ -0,0 +1,30 @@
"""
TTS (Text-to-Speech) pricing models for different providers.
Prices are per character for TTS services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import CharacterPricingModel
# TTS pricing registry
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
},
ServiceProviders.DEEPGRAM: {
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
},
ServiceProviders.ELEVENLABS: {
# 6400 usd per 250*1e6 characters
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
},
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
}

View file

@ -0,0 +1,3 @@
from .daily_report import DailyReportService
__all__ = ["DailyReportService"]

View file

@ -0,0 +1,237 @@
from datetime import datetime, time
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
from api.db import db_client
class DailyReportService:
async def get_daily_report(
self,
organization_id: int,
date: str,
timezone: str,
workflow_id: Optional[int] = None,
) -> Dict[str, Any]:
"""
Get daily report for a specific date and timezone.
Args:
organization_id: The organization ID to filter by
date: Date in YYYY-MM-DD format
timezone: IANA timezone string (e.g., "America/New_York")
workflow_id: Optional workflow ID to filter by (None means all workflows)
"""
# Parse date and timezone
tz = ZoneInfo(timezone)
date_obj = datetime.strptime(date, "%Y-%m-%d")
# Create start and end datetime in the specified timezone
start_dt = datetime.combine(date_obj, time.min, tzinfo=tz)
end_dt = datetime.combine(date_obj, time.max, tzinfo=tz)
# Convert to UTC for database queries
start_utc = start_dt.astimezone(ZoneInfo("UTC"))
end_utc = end_dt.astimezone(ZoneInfo("UTC"))
# Get workflow runs from database (optimized - only required fields)
runs = await db_client.get_workflow_runs_for_daily_report(
organization_id=organization_id,
start_utc=start_utc,
end_utc=end_utc,
workflow_id=workflow_id,
)
# Calculate metrics
total_runs = len(runs)
xfer_count = sum(
1
for run in runs
if run["gathered_context"]
and run["gathered_context"].get("mapped_call_disposition") == "XFER"
)
# Calculate disposition distribution
disposition_counts = {}
for run in runs:
if run["gathered_context"]:
disposition = run["gathered_context"].get(
"mapped_call_disposition", "UNKNOWN"
)
disposition_counts[disposition] = (
disposition_counts.get(disposition, 0) + 1
)
# Sort dispositions by count and get top 5
sorted_dispositions = sorted(
disposition_counts.items(), key=lambda x: x[1], reverse=True
)
disposition_distribution = []
other_count = 0
for i, (disposition, count) in enumerate(sorted_dispositions):
if i < 5:
disposition_distribution.append(
{
"disposition": disposition,
"count": count,
"percentage": round(
(count / total_runs * 100) if total_runs > 0 else 0, 2
),
}
)
else:
other_count += count
# Add "Other" category if there are more than 5 dispositions
if other_count > 0:
disposition_distribution.append(
{
"disposition": "Other",
"count": other_count,
"percentage": round(
(other_count / total_runs * 100) if total_runs > 0 else 0, 2
),
}
)
# Calculate call duration distribution
duration_buckets = {
"0-10": {"range_start": 0, "range_end": 10, "count": 0},
"10-30": {"range_start": 10, "range_end": 30, "count": 0},
"30-60": {"range_start": 30, "range_end": 60, "count": 0},
"60-120": {"range_start": 60, "range_end": 120, "count": 0},
"120-180": {"range_start": 120, "range_end": 180, "count": 0},
">180": {"range_start": 180, "range_end": None, "count": 0},
}
for run in runs:
if run["usage_info"]:
duration_str = run["usage_info"].get("call_duration_seconds")
if duration_str:
try:
duration = float(duration_str)
if duration < 10:
duration_buckets["0-10"]["count"] += 1
elif duration < 30:
duration_buckets["10-30"]["count"] += 1
elif duration < 60:
duration_buckets["30-60"]["count"] += 1
elif duration < 120:
duration_buckets["60-120"]["count"] += 1
elif duration < 180:
duration_buckets["120-180"]["count"] += 1
else:
duration_buckets[">180"]["count"] += 1
except (ValueError, TypeError):
pass
# Format duration distribution
call_duration_distribution = []
total_calls_with_duration = sum(b["count"] for b in duration_buckets.values())
for bucket_name, bucket_data in duration_buckets.items():
call_duration_distribution.append(
{
"bucket": bucket_name,
"range_start": bucket_data["range_start"],
"range_end": bucket_data["range_end"],
"count": bucket_data["count"],
"percentage": round(
(bucket_data["count"] / total_calls_with_duration * 100)
if total_calls_with_duration > 0
else 0,
2,
),
}
)
return {
"date": date,
"timezone": timezone,
"workflow_id": workflow_id,
"metrics": {"total_runs": total_runs, "xfer_count": xfer_count},
"disposition_distribution": disposition_distribution,
"call_duration_distribution": call_duration_distribution,
}
async def get_workflows_for_organization(
self, organization_id: int
) -> List[Dict[str, Any]]:
"""
Get all workflows for an organization.
"""
workflows = await db_client.get_workflows_for_organization(organization_id)
return [{"id": workflow.id, "name": workflow.name} for workflow in workflows]
async def get_daily_runs_detail(
self,
organization_id: int,
date: str,
timezone: str,
workflow_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
Get detailed workflow runs for CSV export.
Args:
organization_id: The organization ID to filter by
date: Date in YYYY-MM-DD format
timezone: IANA timezone string (e.g., "America/New_York")
workflow_id: Optional workflow ID to filter by
"""
# Parse date and timezone
tz = ZoneInfo(timezone)
date_obj = datetime.strptime(date, "%Y-%m-%d")
# Create start and end datetime in the specified timezone
start_dt = datetime.combine(date_obj, time.min, tzinfo=tz)
end_dt = datetime.combine(date_obj, time.max, tzinfo=tz)
# Convert to UTC for database queries
start_utc = start_dt.astimezone(ZoneInfo("UTC"))
end_utc = end_dt.astimezone(ZoneInfo("UTC"))
# Get workflow runs from database (optimized - only required fields)
runs = await db_client.get_workflow_runs_for_daily_report(
organization_id=organization_id,
start_utc=start_utc,
end_utc=end_utc,
workflow_id=workflow_id,
)
# Format runs for CSV export
detailed_runs = []
for run in runs:
# Phone number is already extracted at the database level
# Try customer_phone_number first, then fall back to initial_context
phone_number = run["gathered_context"].get(
"customer_phone_number", ""
) or run["initial_context"].get("phone_number", "")
# Disposition is already extracted at the database level
disposition = run["gathered_context"].get("mapped_call_disposition", "")
# Duration is already extracted at the database level
duration_seconds = 0
duration_str = run["usage_info"].get("call_duration_seconds", "0")
try:
duration_seconds = float(duration_str)
except (ValueError, TypeError):
duration_seconds = 0
detailed_runs.append(
{
"phone_number": phone_number,
"disposition": disposition,
"duration_seconds": duration_seconds,
"workflow_id": run["workflow_id"],
"run_id": run["id"],
"workflow_name": run["workflow_name"],
"created_at": run["created_at"].isoformat(),
}
)
return detailed_runs

View file

@ -0,0 +1,3 @@
from .websocket_smart_turn import WebSocketSmartTurnAnalyzer
__all__ = ["WebSocketSmartTurnAnalyzer"]

View file

@ -0,0 +1,478 @@
import asyncio
import io
import json
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
import numpy as np
from fastapi import (
BackgroundTasks,
FastAPI,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
WebSocketException,
status,
)
from fastapi.websockets import WebSocketState
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
from scipy.io import wavfile
LOG_LEVEL = (
logging.DEBUG
if os.environ.get("LOG_LEVEL", "DEBUG").lower() == "debug"
else logging.INFO
)
logger = logging.getLogger("smart_turn")
logger.setLevel(LOG_LEVEL)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)
# ----------------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------------
MODEL_PATH = os.getenv("LOCAL_SMART_TURN_MODEL_PATH", "pipecat-ai/smart-turn-v2")
# ----------------------------------------------------------------------------
# Analyzer Pool
# ----------------------------------------------------------------------------
class _AnalyzerWrapper:
"""Wraps a LocalSmartTurnAnalyzer with a lock so only one request can use it at a time."""
def __init__(self, analyzer: LocalSmartTurnAnalyzerV2):
self.analyzer = analyzer
self.lock = asyncio.Lock()
_analyzer_wrapper: _AnalyzerWrapper | None = None # Will be initialised in the lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage the application lifespan - startup and shutdown logic."""
# Startup logic
global _analyzer_wrapper
if _analyzer_wrapper is None:
logger.debug("Initializing LocalSmartTurnAnalyzer")
analyzer = LocalSmartTurnAnalyzerV2(smart_turn_model_path=MODEL_PATH)
_analyzer_wrapper = _AnalyzerWrapper(analyzer)
logger.debug("LocalSmartTurnAnalyzer initialized")
yield # Application runs here
# Shutdown logic (if needed in the future)
# Any cleanup code would go here
app = FastAPI(
title="Smart Turn API",
description="A FastAPI application exposing LocalSmartTurnAnalyzer via HTTP",
lifespan=lifespan,
)
# ----------------------------------------------------------------------------
# API Endpoints
# ----------------------------------------------------------------------------
async def save_wav_file(
audio_array: np.ndarray,
prediction: int,
probability: float,
service_id: str | None = None,
sample_rate: int = 16000,
) -> None:
"""Save audio data as a WAV file in the background.
Runs the blocking ``wavfile.write`` call in a thread so that the event loop
is not blocked. This function is now ``async`` so it can be scheduled with
``asyncio.create_task`` from the WebSocket endpoint, while still being
compatible with ``BackgroundTasks`` (which will ``await`` coroutine
functions).
Args:
audio_array: The audio data as a numpy array
prediction: The prediction result (0 or 1)
probability: The probability of the prediction
service_id: Optional service identifier
sample_rate: The sample rate of the audio (default: 16000 Hz)
"""
def _blocking_save() -> None:
try:
# Generate filename with current timestamp and prediction
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Include ms
# Include service_id in filename if available
service_prefix = f"{service_id}_" if service_id else ""
root_dir = (
Path(__file__).resolve().parents[3]
) # dograh/api/services/smart_turn/app.py
filename = (
root_dir
/ f"smart_turn_pipeline/{service_prefix}{timestamp}_{prediction}_{probability}.wav"
)
# Convert float32 [-1, 1] back to int16 PCM for WAV file
audio_int16 = np.clip(audio_array * 32767, -32768, 32767).astype(np.int16)
# Use provided sample rate
wavfile.write(filename, sample_rate, audio_int16)
length_seconds = len(audio_array) / sample_rate
log_message = f"Saved audio to {filename} (length: {length_seconds:.2f}s, prediction: {prediction}"
if service_id:
log_message += f", service_id: {service_id}"
log_message += ")"
logger.info(log_message)
except Exception as exc: # pragma: no cover best-effort logging only
log_message = f"Failed to save WAV file: {exc}"
if service_id:
log_message += f" (service_id: {service_id})"
logger.error(log_message)
# Offload the blocking I/O to a thread to avoid blocking the event loop
await asyncio.to_thread(_blocking_save)
@app.post("/raw", status_code=status.HTTP_200_OK)
async def handle_raw(request: Request, background_tasks: BackgroundTasks):
"""
Accept a NumPy-serialized float32 array (written via ``np.save``) in the body and
return a JSON prediction compatible with ``HttpSmartTurnAnalyzer``.
"""
# ------------------------------------------------------------------
# Secret key validation
# ------------------------------------------------------------------
expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
if expected_secret: # If a secret is configured, enforce validation
provided_secret = request.headers.get("X-API-Key")
if provided_secret != expected_secret:
logger.warning(
"Unauthorized access attempt to /raw endpoint with invalid or missing secret key"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
# ------------------------------------------------------------------
# Start total-time measurement as early as possible
# ------------------------------------------------------------------
request_start_time = time.perf_counter()
# ------------------------------------------------------------------
# Log that we received a request (before doing any heavy work)
# ------------------------------------------------------------------
logger.debug("Received /raw request")
body = await request.body()
if not body:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Empty request body"
)
# Extract service context and sample rate from headers
service_id = request.headers.get("X-Service-Context")
sample_rate_str = request.headers.get("X-Sample-Rate")
sample_rate = int(sample_rate_str) if sample_rate_str else 16000
# Deserialize NumPy array
try:
audio_array = np.load(io.BytesIO(body))
except Exception as exc:
error_msg = f"Invalid NumPy payload: {exc}"
if service_id:
error_msg += f" (service_id: {service_id})"
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg,
)
wrapper = _analyzer_wrapper
if wrapper is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Analyzer not initialized",
)
# Run inference guarded by the wrapper lock so the model isn't used concurrently
log_msg = "Going to acquire lock for model inference"
if service_id:
log_msg += f" (service_id: {service_id})"
logger.debug(log_msg)
async with wrapper.lock:
log_msg = "Acquired lock for model inference"
if service_id:
log_msg += f" (service_id: {service_id})"
logger.debug(log_msg)
# Measure inference-only latency
inference_start_time = time.perf_counter()
result = await wrapper.analyzer._predict_endpoint(audio_array)
inference_time = time.perf_counter() - inference_start_time
# Calculate total processing time (from request receipt to response preparation)
total_time = time.perf_counter() - request_start_time
log_msg = (
f"Inference done result: {result['prediction']} "
f"probability: {result['probability']} time taken: {inference_time:.2f}s total: {total_time:.2f}s"
)
if service_id:
log_msg += f" (service_id: {service_id})"
logger.debug(log_msg)
# Ensure metrics section exists so client code can parse it consistently
metrics = result.get("metrics", {})
# Overwrite / set the timing metrics explicitly
metrics["inference_time"] = inference_time
metrics["total_time"] = total_time
result["metrics"] = metrics
logger.debug(f"Result for service_id: {service_id} is: {result}")
# Add service_id to result for potential client use
if service_id:
result["service_id"] = service_id
# Persist audio in background so it doesn't block the response.
background_tasks.add_task(
save_wav_file,
audio_array,
result.get("prediction", 0),
result.get("probability", 0),
service_id,
sample_rate,
)
return result
@app.get("/")
async def root():
"""Health-check endpoint."""
return {"message": "Smart Turn API is running"}
# ----------------------------------------------------------------------------
# WebSocket endpoint
# ----------------------------------------------------------------------------
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
"""Handle streaming Smart Turn requests over WebSocket.
Each incoming binary message must be a NumPy-serialized float32 array (as
produced by ``np.save``). A JSON-formatted prediction (identical to the
``/raw`` HTTP endpoint) is sent back as a text message.
"""
# Extract optional secret key from headers (during handshake)
expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
if expected_secret:
provided_secret = ws.headers.get("X-API-Key")
if provided_secret != expected_secret:
await ws.close(code=4401, reason="Unauthorized")
return
# Accept the websocket connection and log it
await ws.accept()
service_id = ws.headers.get("X-Service-Context")
sample_rate_str = ws.headers.get("X-Sample-Rate")
sample_rate = int(sample_rate_str) if sample_rate_str else 16000
logger.debug(
f"WebSocket connection accepted from service_id: {service_id}, sample_rate: {sample_rate}"
)
# ------------------------------------------------------------------
# Tunables consider moving to env vars for ops control
# ------------------------------------------------------------------
connection_timeout = 120.0 # Seconds of inactivity before timing out
MAX_BINARY_SIZE = int(
os.getenv("SMART_TURN_MAX_PAYLOAD", 10 * 1024 * 1024) # 10MB max message size
)
# Track background tasks so we can cancel them on disconnect
background_tasks = set() # Track background tasks for cleanup
try:
logger.debug("Entering WebSocket message loop")
while True:
data = None # Initialize data for each iteration
try:
logger.debug("Waiting for WebSocket message…")
# Create receive task to handle timeout properly
receive_task = asyncio.create_task(ws.receive())
try:
msg = await asyncio.wait_for(
receive_task, timeout=connection_timeout
)
except asyncio.TimeoutError:
# Cancel the receive task to prevent it from running in background
receive_task.cancel()
try:
await receive_task
except asyncio.CancelledError:
pass
logger.warning(
f"WebSocket connection timeout for service_id: {service_id}"
)
try:
await ws.close(code=1001, reason="Connection timeout")
except Exception as e:
logger.debug(f"Error closing WebSocket after timeout: {e}")
break
except WebSocketDisconnect as e:
logger.debug(f"WebSocket client disconnected: {e}")
break
# Validate message structure
if not isinstance(msg, dict):
logger.error(f"Unexpected message type: {type(msg)}")
break
# Handle disconnect message explicitly
if msg.get("type") == "websocket.disconnect":
logger.debug("Client sent disconnect frame")
break
data = None
# Binary frame
if "bytes" in msg and msg["bytes"] is not None:
data = msg["bytes"]
logger.debug(
"Received WebSocket audio payload (%d bytes)", len(data)
)
except WebSocketDisconnect as e:
logger.debug(f"WebSocket client disconnected: {e}")
break
except Exception as e:
logger.error(f"Error in WebSocket loop: {e}")
break
if data is None:
continue
request_start_time = time.perf_counter()
# --------------------------------------------------------------
# Basic validation & secure deserialisation
# --------------------------------------------------------------
if len(data) > MAX_BINARY_SIZE:
logger.warning("Received payload exceeding maximum allowed size")
await ws.send_text('{"error": "Payload too large"}')
continue
# Deserialize NumPy array (pickle disabled for security)
try:
audio_array = np.load(io.BytesIO(data), allow_pickle=False)
except Exception as exc:
error_msg = f"Invalid NumPy payload: {exc}"
if service_id:
error_msg += f" (service_id: {service_id})"
# Send error response with proper error handling
if ws.application_state == WebSocketState.CONNECTED:
try:
await ws.send_text(f'{{"error": "{error_msg}"}}')
except Exception as e:
logger.error(f"Failed to send error message: {e}")
continue
wrapper = _analyzer_wrapper
if wrapper is None:
logger.error("Analyzer not initialized; closing connection")
if ws.application_state == WebSocketState.CONNECTED:
await ws.close(code=1011, reason="Analyzer not ready")
break
async with wrapper.lock:
inference_start_time = time.perf_counter()
result = await wrapper.analyzer._predict_endpoint(audio_array)
inference_time = time.perf_counter() - inference_start_time
# Timing metrics
total_time = time.perf_counter() - request_start_time
metrics = result.get("metrics", {})
metrics["inference_time"] = inference_time
metrics["total_time"] = total_time
result["metrics"] = metrics
logger.debug(f"Result for service_id: {service_id} is: {result}")
if service_id:
result["service_id"] = service_id
# Send result with proper error handling
try:
if ws.application_state == WebSocketState.CONNECTED:
await ws.send_text(json.dumps(result))
else:
logger.warning(
f"Cannot send result - WebSocket not connected for service_id: {service_id}"
)
break
except WebSocketDisconnect:
logger.debug(
f"Client disconnected while sending result for service_id: {service_id}"
)
break
except Exception as e:
logger.error(f"Failed to send result: {e}")
break
# Save audio in the background so that it doesn't block streaming
task = asyncio.create_task(
save_wav_file(
audio_array,
result.get("prediction", 0),
result.get("probability", 0),
service_id,
sample_rate,
)
)
# Track task and remove when done
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
except WebSocketException as exc:
logger.error(f"WebSocket error: {exc}")
finally:
# Cancel any remaining background tasks
for task in background_tasks:
if not task.done():
task.cancel()
# Wait for all background tasks to complete or be cancelled
if background_tasks:
await asyncio.gather(*background_tasks, return_exceptions=True)
# Attempt a graceful close if it's not already closed
if ws.application_state == WebSocketState.CONNECTED:
try:
await ws.close()
except Exception as exc:
# Socket is probably already closed; log and ignore
logger.debug(f"WebSocket already closed: {exc}")

View file

@ -0,0 +1,314 @@
"""Smart-Turn analyzer that talks to a FastAPI WebSocket endpoint.
This analyzer keeps a persistent WebSocket connection alive so that the TCP/TLS
handshake and HTTP upgrade happen only once per call session. Each speech
segment is sent as a single binary message containing the NumPy-serialized
float32 array, and a JSON reply is expected in return.
Rewritten to use the websockets library for simplified connection management.
"""
from __future__ import annotations
import asyncio
import io
import json
import random
import time
from typing import Any, Dict, Optional
import numpy as np
import websockets
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import (
BaseSmartTurn,
SmartTurnTimeoutException,
)
class WebSocketSmartTurnAnalyzer(BaseSmartTurn):
"""End-of-turn analyzer that sends audio via a persistent WebSocket."""
def __init__(
self,
*,
url: str,
headers: Optional[Dict[str, str]] = None,
service_context: Optional[Any] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self._url = url.rstrip("/")
self._headers = headers or {}
self._service_context = service_context
# WebSocket connection
self._ws: Optional[websockets.WebSocketClientProtocol] = None
self._ws_lock = asyncio.Lock()
# Connection management
self._connection_task: Optional[asyncio.Task] = None
self._reconnect_delay = 1.0
self._max_reconnect_delay = 30.0
self._closing = False
self._connection_closed_event = asyncio.Event()
# Connection health monitoring
self._last_successful_request = 0.0
self._connection_attempts = 0
# Start connection manager in background
try:
loop = asyncio.get_event_loop()
if loop.is_running():
self._connection_task = loop.create_task(self._connection_manager())
except RuntimeError:
logger.debug(
"No running loop at object creation time. Connection will be opened lazily on first use."
)
def _serialize_array(self, audio_array: np.ndarray) -> bytes:
"""Serialize numpy array to bytes."""
buffer = io.BytesIO()
np.save(buffer, audio_array)
return buffer.getvalue()
async def _connection_manager(self) -> None:
"""Manages WebSocket connection lifecycle with automatic reconnection."""
while not self._closing:
try:
# Establish connection
await self._establish_connection()
# Reset reconnect delay on successful connection
self._reconnect_delay = 1.0
self._connection_attempts = 0
# Wait for connection close event
self._connection_closed_event.clear()
await self._connection_closed_event.wait()
logger.debug("WebSocket connection closed")
except Exception as e:
logger.error(f"Connection manager error: {e}")
finally:
# Clean up connection
if self._ws:
try:
await self._ws.close()
except:
pass
self._ws = None
if not self._closing:
# Exponential backoff for reconnection
self._connection_attempts += 1
delay = min(
self._reconnect_delay
* (2 ** min(self._connection_attempts - 1, 5)),
self._max_reconnect_delay,
)
# Add jitter to avoid thundering herd
delay += random.uniform(0, 0.5)
logger.info(
f"Reconnecting in {delay:.1f} seconds (attempt {self._connection_attempts})"
)
await asyncio.sleep(delay)
async def _establish_connection(self) -> None:
"""Establish a new WebSocket connection with retry logic."""
logger.debug("Establishing new WebSocket connection to Smart-Turn service...")
# Prepare headers
extra_headers = dict(self._headers)
if self._service_context is not None:
extra_headers["X-Service-Context"] = str(self._service_context)
# _init_sample_rate is being set in the constructor, which we should
# use in case self._sample_rate is not set yet. The actual _sample_rate
# is being set in the set_sample_rate() method
# but in case of WebSocketSmartTurnAnalyzer, we establish the websocket connection
# during __init__() and won't see the set_sample_rate until later. So, lets
# user the _init_sample_rate instead
_sample_rate = self._sample_rate or self._init_sample_rate
if _sample_rate > 0:
extra_headers["X-Sample-Rate"] = str(_sample_rate)
max_attempts = 3
for attempt in range(max_attempts):
try:
# Add jitter to prevent thundering herd
if attempt > 0:
jitter = 0.1 * attempt
await asyncio.sleep(jitter)
# Connect with websockets library
self._ws = await websockets.connect(
self._url,
extra_headers=extra_headers,
ping_interval=5.0, # let websockets send pings every 5s
ping_timeout=3.0, # fail fast if no pong in 3s
close_timeout=10,
max_size=10 * 1024 * 1024, # 10MB max message size
)
logger.info("WebSocket connection established successfully")
return
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning(
f"Failed to establish WebSocket (attempt {attempt + 1}/{max_attempts}): {exc}"
)
if attempt == max_attempts - 1:
raise
await asyncio.sleep(0.5 * (attempt + 1))
async def _ensure_ws(self) -> websockets.WebSocketClientProtocol:
"""Return a connected WebSocket, waiting for connection if necessary."""
async with self._ws_lock:
# If connection manager isn't running, start it
if not self._connection_task or self._connection_task.done():
self._connection_task = asyncio.create_task(self._connection_manager())
# Wait for connection with timeout
start_time = time.time()
max_wait_time = 10.0
while not self._closing:
if self._ws:
return self._ws
elapsed = time.time() - start_time
if elapsed > max_wait_time:
raise Exception(
f"Timeout waiting for WebSocket connection after {max_wait_time}s"
)
await asyncio.sleep(0.1)
if self._closing:
raise Exception("Analyzer is closing")
raise Exception("Failed to establish WebSocket connection")
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
"""Send audio and await JSON response via WebSocket."""
data_bytes = self._serialize_array(audio_array)
try:
# Ensure we have a connection
ws = await self._ensure_ws()
# Send data
try:
await ws.send(data_bytes)
except Exception as e:
logger.error(f"Failed to send data: {e}")
self._connection_closed_event.set()
return {
"prediction": 0,
"probability": 0.0,
"metrics": {"inference_time": 0.0, "total_time": 0.0},
}
# Wait for response
start_time = time.time()
while True:
remaining_timeout = self._params.stop_secs - (time.time() - start_time)
if remaining_timeout <= 0:
raise SmartTurnTimeoutException(
f"Request exceeded {self._params.stop_secs} seconds."
)
try:
# Receive message with timeout
message = await asyncio.wait_for(
ws.recv(), timeout=min(remaining_timeout, 0.5)
)
# Handle text messages (JSON responses)
if isinstance(message, str):
try:
result = json.loads(message)
# Skip ping/pong messages
if result.get("type") in ["ping", "pong"]:
continue
# Validate prediction response
if "prediction" not in result:
if "type" in result:
continue
else:
logger.error(
"Invalid response format from Smart-Turn service"
)
return {
"prediction": 0,
"probability": 0.0,
"metrics": {
"inference_time": 0.0,
"total_time": 0.0,
},
}
self._last_successful_request = time.time()
return result
except json.JSONDecodeError as exc:
logger.error(
f"Smart turn service returned invalid JSON: {exc}"
)
raise
else:
logger.error(f"Unexpected message type: {type(message)}")
except asyncio.TimeoutError:
continue
except websockets.exceptions.ConnectionClosed:
logger.warning("WebSocket connection closed during prediction")
self._connection_closed_event.set()
return {
"prediction": 0,
"probability": 0.0,
"metrics": {"inference_time": 0.0, "total_time": 0.0},
}
except SmartTurnTimeoutException:
raise
except Exception as exc:
logger.error(f"Smart turn prediction failed over WebSocket: {exc}")
self._connection_closed_event.set()
return {
"prediction": 0,
"probability": 0.0,
"metrics": {"inference_time": 0.0, "total_time": 0.0},
}
async def close(self):
"""Asynchronously close the WebSocket."""
self._closing = True
self._connection_closed_event.set()
async with self._ws_lock:
# Cancel tasks
if self._connection_task and not self._connection_task.done():
self._connection_task.cancel()
try:
await self._connection_task
except asyncio.CancelledError:
pass
# Close WebSocket
if self._ws:
try:
await self._ws.close()
except:
pass
finally:
self._ws = None

112
api/services/storage.py Normal file
View file

@ -0,0 +1,112 @@
from loguru import logger
from api.constants import (
ENABLE_AWS_S3,
MINIO_ACCESS_KEY,
MINIO_BUCKET,
MINIO_ENDPOINT,
MINIO_PUBLIC_ENDPOINT,
MINIO_SECRET_KEY,
MINIO_SECURE,
S3_BUCKET,
S3_REGION,
)
from api.enums import StorageBackend
from .filesystem import BaseFileSystem, MinioFileSystem, S3FileSystem
def get_storage_for_backend(backend: str) -> BaseFileSystem:
"""Get storage instance for a specific backend enum.
Maps StorageBackend enum codes to actual storage implementations:
- Code 1 (S3): AWS S3 via S3FileSystem
- Code 2 (MINIO): MinIO via MinioFileSystem
"""
# Code 2: MinIO implementation (local/OSS deployments)
if backend == StorageBackend.MINIO.value:
endpoint = MINIO_ENDPOINT
# Auto-detect public endpoint:
# - If MINIO_PUBLIC_ENDPOINT is set, use it (for custom domains/IPs)
# - If running in Docker and endpoint is "minio:9000", use "localhost:9000" for local dev
# - Otherwise, use the endpoint as-is (both containers and browser can reach it)
public_endpoint = MINIO_PUBLIC_ENDPOINT
if not public_endpoint:
# Auto-detect based on endpoint
if endpoint.startswith("minio:"):
# Docker internal endpoint detected, assume local development
public_endpoint = endpoint.replace("minio:", "localhost:")
logger.info(
f"Auto-detected local development: using {public_endpoint} for public access"
)
elif endpoint.startswith("host.docker.internal:"):
# Docker Desktop special DNS detected, use localhost for browser access
public_endpoint = endpoint.replace(
"host.docker.internal:", "localhost:"
)
logger.info(
f"Auto-detected Docker Desktop: using {public_endpoint} for public access"
)
else:
# Already using a public endpoint (localhost:9000 or domain:9000)
public_endpoint = endpoint
access_key = MINIO_ACCESS_KEY
secret_key = MINIO_SECRET_KEY
bucket = MINIO_BUCKET
secure = MINIO_SECURE
logger.info(
f"Initializing {backend} storage at {endpoint} (public: {public_endpoint}) with bucket '{bucket}'"
)
return MinioFileSystem(
endpoint=endpoint,
access_key=access_key,
secret_key=secret_key,
bucket_name=bucket,
secure=secure,
public_endpoint=public_endpoint,
)
# Code 1: AWS S3 implementation (cloud deployments)
elif backend == StorageBackend.S3.value:
if not S3_BUCKET:
raise ValueError(
"S3_BUCKET environment variable is required when using S3 storage"
)
bucket = S3_BUCKET
region = S3_REGION
logger.info(
f"Initializing {backend} storage with bucket '{bucket}' in region '{region}'"
)
return S3FileSystem(bucket, region)
# Future backend implementations can be added here:
# elif backend == StorageBackend.GCS: # Code 3
# return GoogleCloudFileSystem(...)
# elif backend == StorageBackend.AZURE: # Code 4
# return AzureBlobFileSystem(...)
else:
raise ValueError(f"Unknown storage backend: {backend}")
def get_current_storage_backend() -> StorageBackend:
"""Get the current storage backend enum."""
return StorageBackend.get_current_backend()
# Create a single storage instance at module load time
_backend = StorageBackend.get_current_backend()
logger.info(
f"Initializing storage backend: {_backend.name} (value: {_backend.value}, ENABLE_AWS_S3={ENABLE_AWS_S3})"
)
storage_fs = get_storage_for_backend(_backend.value)
# For backward compatibility, keep get_storage() function
def get_storage() -> BaseFileSystem:
"""Get the module-level storage instance.
Deprecated: Use 'from api.services.storage import storage_fs' instead.
"""
return storage_fs

View file

View file

@ -0,0 +1,765 @@
"""
Dynamic ARI client that generates methods from Swagger/OpenAPI specification.
Pure asyncio implementation without anyio dependencies.
"""
import asyncio
import json
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional
from urllib.parse import urljoin
import aiohttp
from loguru import logger
class SwaggerMethod:
"""Represents a Swagger API method."""
def __init__(
self, client: "AsyncARIClient", path: str, method: str, operation: dict
):
self.client = client
self.path = path
self.http_method = method.upper()
self.operation = operation
self.operation_id = operation.get("operationId", "")
self.parameters = operation.get("parameters", [])
self.description = operation.get("description", "")
def _build_path(self, **kwargs) -> str:
"""Build the actual path by substituting path parameters."""
path = self.path
# Replace path parameters like {channelId} with actual values
for param in self.parameters:
# Swagger spec uses 'paramType' not 'in'
if param.get("paramType", param.get("in")) == "path":
param_name = param["name"]
if param_name in kwargs:
path = path.replace(f"{{{param_name}}}", str(kwargs[param_name]))
return path
def _build_params(self, **kwargs) -> dict:
"""Extract query parameters from kwargs."""
params = {}
for param in self.parameters:
# Swagger spec uses 'paramType' not 'in'
if param.get("paramType", param.get("in")) == "query":
param_name = param["name"]
if param_name in kwargs:
params[param_name] = kwargs[param_name]
return params
def _build_body(self, **kwargs) -> dict:
"""Extract body parameters from kwargs."""
body = {}
for param in self.parameters:
# Swagger 1.2 uses 'paramType' = 'body' for body parameters
if param.get("paramType", param.get("in")) == "body":
param_name = param["name"]
if param_name in kwargs:
# In Swagger 1.2, body param is usually the whole body
return (
kwargs[param_name]
if isinstance(kwargs[param_name], dict)
else {param_name: kwargs[param_name]}
)
return body
async def __call__(self, **kwargs):
"""Execute the API method."""
path = self._build_path(**kwargs)
params = self._build_params(**kwargs)
# Check if there's a body parameter defined in the spec
body_data = self._build_body(**kwargs)
# If no body param in spec, use remaining kwargs for body (backward compat)
if not body_data:
# Remove path and query parameters from kwargs (leaving body params)
# Swagger spec uses 'paramType' not 'in'
path_param_names = {
p["name"]
for p in self.parameters
if p.get("paramType", p.get("in")) == "path"
}
query_param_names = {
p["name"]
for p in self.parameters
if p.get("paramType", p.get("in")) == "query"
}
body_param_names = {
p["name"]
for p in self.parameters
if p.get("paramType", p.get("in")) == "body"
}
body_data = {
k: v
for k, v in kwargs.items()
if k not in path_param_names
and k not in query_param_names
and k not in body_param_names
}
# Debug logging for externalMedia
if "externalMedia" in path:
logger.debug(
f"externalMedia call - method: {self.http_method}, path: {path}, params: {params}"
)
if self.http_method == "GET":
return await self.client.api_get(path, **params)
elif self.http_method == "POST":
return await self.client.api_post(
path, json_data=body_data if body_data else None, **params
)
elif self.http_method == "PUT":
return await self.client.api_put(
path, json_data=body_data if body_data else None, **params
)
elif self.http_method == "DELETE":
return await self.client.api_delete(path, **params)
else:
raise ValueError(f"Unsupported HTTP method: {self.http_method}")
class ResourceAPI:
"""Represents a resource API (like channels, bridges, etc.)."""
def __init__(self, client: "AsyncARIClient", resource_name: str):
self.client = client
self.resource_name = resource_name
self._methods = {}
def add_method(self, method_name: str, swagger_method: SwaggerMethod):
"""Add a method to this resource."""
self._methods[method_name] = swagger_method
def __getattr__(self, name):
"""Dynamically return methods."""
if name in self._methods:
return self._methods[name]
raise AttributeError(f"'{self.resource_name}' has no method '{name}'")
@dataclass
class Channel:
"""Channel model with dynamic method support."""
id: str
name: str = ""
state: str = ""
caller: Dict[str, str] = field(default_factory=dict)
connected: Dict[str, str] = field(default_factory=dict)
accountcode: str = ""
dialplan: Dict[str, str] = field(default_factory=dict)
creationtime: str = ""
language: str = "en"
# Store reference to client for method calls
_client: Optional["AsyncARIClient"] = field(default=None, repr=False)
@classmethod
def from_dict(cls, data: dict, client=None) -> "Channel":
"""Create Channel from API response."""
channel = cls(
id=data.get("id", ""),
name=data.get("name", ""),
state=data.get("state", ""),
caller=data.get("caller", {}),
connected=data.get("connected", {}),
accountcode=data.get("accountcode", ""),
dialplan=data.get("dialplan", {}),
creationtime=data.get("creationtime", ""),
language=data.get("language", "en"),
_client=client,
)
return channel
async def continueInDialplan(
self,
context: str = None,
extension: str = None,
priority: int = None,
label: str = None,
):
"""Continue channel in dialplan."""
if not self._client:
raise RuntimeError("Channel not associated with a client")
params = {"channelId": self.id}
if context:
params["context"] = context
if extension:
params["extension"] = extension
if priority is not None:
params["priority"] = priority
if label:
params["label"] = label
# The ARI API method is named 'continueInDialplan'
channels_api = self._client.channels
if hasattr(channels_api, "continueInDialplan"):
await channels_api.continueInDialplan(**params)
else:
# Fallback to direct API call
await self._client.api_post(f"/channels/{self.id}/continue", **params)
async def hangup(self, reason: str = "normal"):
"""Hangup the channel."""
if not self._client:
raise RuntimeError("Channel not associated with a client")
await self._client.channels.hangup(channelId=self.id, reason=reason)
async def answer(self):
"""Answer the channel."""
if not self._client:
raise RuntimeError("Channel not associated with a client")
await self._client.channels.answer(channelId=self.id)
async def getChannelVar(self, variable: str):
"""Get a channel variable."""
if not self._client:
raise RuntimeError("Channel not associated with a client")
return await self._client.channels.getChannelVar(
channelId=self.id, variable=variable
)
@dataclass
class Bridge:
"""Bridge model with dynamic method support."""
id: str
technology: str = ""
bridge_type: str = ""
bridge_class: str = ""
creator: str = ""
name: str = ""
channels: List[str] = field(default_factory=list)
_client: Optional["AsyncARIClient"] = field(default=None, repr=False)
@classmethod
def from_dict(cls, data: dict, client=None) -> "Bridge":
"""Create Bridge from API response."""
return cls(
id=data.get("id", ""),
technology=data.get("technology", ""),
bridge_type=data.get("bridge_type", ""),
bridge_class=data.get("bridge_class", ""),
creator=data.get("creator", ""),
name=data.get("name", ""),
channels=data.get("channels", []),
_client=client,
)
async def addChannel(self, channel: str):
"""Add channel to bridge."""
if not self._client:
raise RuntimeError("Bridge not associated with a client")
await self._client.bridges.addChannel(bridgeId=self.id, channel=channel)
async def removeChannel(self, channel: str):
"""Remove channel from bridge."""
if not self._client:
raise RuntimeError("Bridge not associated with a client")
await self._client.bridges.removeChannel(bridgeId=self.id, channel=channel)
async def destroy(self):
"""Destroy the bridge."""
if not self._client:
raise RuntimeError("Bridge not associated with a client")
await self._client.bridges.destroy(bridgeId=self.id)
class AsyncARIClient:
"""ARI client that dynamically generates methods from Swagger spec."""
def __init__(self, base_url: str, username: str, password: str, app: str):
self.base_url = base_url.rstrip("/")
self.username = username
self.password = password
self.app = app
# REST API URL
self.api_url = self.base_url.replace("ws://", "http://").replace(
"wss://", "https://"
)
# WebSocket URL
self.ws_url = (
f"{self.base_url}/ari/events?app={app}&api_key={username}:{password}"
)
# Session and WebSocket
self._session: Optional[aiohttp.ClientSession] = None
self._websocket: Optional[aiohttp.ClientWebSocketResponse] = None
self._running = False
# Event handling
self._event_handlers: Dict[str, List[Callable]] = {}
self._event_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
# Resource APIs (will be populated from Swagger)
self.channels: Optional[ResourceAPI] = None
self.bridges: Optional[ResourceAPI] = None
self.endpoints: Optional[ResourceAPI] = None
self.recordings: Optional[ResourceAPI] = None
self.sounds: Optional[ResourceAPI] = None
self.playbacks: Optional[ResourceAPI] = None
self.asterisk: Optional[ResourceAPI] = None
self.applications: Optional[ResourceAPI] = None
self.deviceStates: Optional[ResourceAPI] = None
self.mailboxes: Optional[ResourceAPI] = None
# Swagger spec cache
self._swagger_spec: Optional[dict] = None
async def connect(self):
"""Connect to ARI and load Swagger spec."""
# Create HTTP session
auth = aiohttp.BasicAuth(self.username, self.password)
self._session = aiohttp.ClientSession(auth=auth)
try:
# Load Swagger spec and generate methods
await self._load_swagger_spec()
# Connect WebSocket
self._websocket = await self._session.ws_connect(
self.ws_url, heartbeat=30, autoping=True
)
self._running = True
logger.info(f"Connected to ARI at {self.ws_url}")
except Exception as e:
await self._session.close()
raise Exception(f"Failed to connect to ARI: {e}")
async def _load_swagger_spec(self):
"""Load Swagger spec and generate dynamic methods."""
spec_loaded = False
try:
# Get Swagger spec from ARI
url = f"{self.api_url}/ari/api-docs/resources.json"
async with self._session.get(url) as resp:
resp.raise_for_status()
resources = await resp.json()
# Store the spec
self._swagger_spec = resources
# Create resource APIs
for api_info in resources.get("apis", []):
resource_path = api_info["path"]
# Fix the path - remove .{format} and add proper prefix
resource_path = resource_path.replace(".{format}", ".json")
# Load detailed spec for this resource
# The resource_path already contains /api-docs/, so we just need the base URL
url = f"{self.api_url}/ari{resource_path}"
try:
async with self._session.get(url) as resp:
resp.raise_for_status()
spec = await resp.json()
self._process_swagger_spec(spec)
spec_loaded = True
except Exception as e:
logger.warning(f"Failed to load spec for {resource_path}: {e}")
if spec_loaded:
logger.info("Loaded Swagger spec and generated dynamic methods")
else:
raise Exception("No individual specs could be loaded")
except Exception as e:
logger.warning(f"Failed to load Swagger spec, using fallback methods: {e}")
self._create_fallback_methods()
def _process_swagger_spec(self, spec: dict):
"""Process a Swagger spec and create dynamic methods."""
# basePath is available in spec but not currently used
for api in spec.get("apis", []):
path = api["path"]
for operation in api.get("operations", []):
self._create_method_from_operation(path, operation)
def _create_method_from_operation(self, path: str, operation: dict):
"""Create a method from a Swagger operation."""
# Swagger spec uses 'httpMethod' not 'method'
method = operation.get("httpMethod", operation.get("method", "GET"))
operation_id = operation.get("nickname", "")
if not operation_id:
return
# Determine resource from path (e.g., /channels/{channelId} -> channels)
path_parts = path.strip("/").split("/")
if path_parts:
resource_name = path_parts[0]
# Create resource API if it doesn't exist
if not hasattr(self, resource_name) or getattr(self, resource_name) is None:
setattr(self, resource_name, ResourceAPI(self, resource_name))
resource_api = getattr(self, resource_name)
# Extract method name from operation ID
# e.g., "channels_continue" -> "continue_"
# or "channels_get" -> "get"
method_name = operation_id
if method_name.startswith(resource_name + "_"):
method_name = method_name[len(resource_name) + 1 :]
# Handle special cases
if method_name == "continue":
method_name = "continue_" # Avoid Python keyword
# Create and add the method
swagger_method = SwaggerMethod(self, path, method, operation)
resource_api.add_method(method_name, swagger_method)
def _create_fallback_methods(self):
"""Create fallback methods if Swagger spec is not available."""
# Create basic resource APIs
self.channels = ResourceAPI(self, "channels")
self.bridges = ResourceAPI(self, "bridges")
# Add essential channel methods
self.channels.add_method(
"get",
SwaggerMethod(
self,
"/channels/{channelId}",
"GET",
{
"operationId": "get",
"parameters": [{"name": "channelId", "in": "path"}],
},
),
)
self.channels.add_method(
"hangup",
SwaggerMethod(
self,
"/channels/{channelId}",
"DELETE",
{
"operationId": "hangup",
"parameters": [
{"name": "channelId", "in": "path"},
{"name": "reason", "in": "query"},
],
},
),
)
self.channels.add_method(
"answer",
SwaggerMethod(
self,
"/channels/{channelId}/answer",
"POST",
{
"operationId": "answer",
"parameters": [{"name": "channelId", "in": "path"}],
},
),
)
self.channels.add_method(
"continueInDialplan",
SwaggerMethod(
self,
"/channels/{channelId}/continue",
"POST",
{
"operationId": "continueInDialplan",
"parameters": [
{"name": "channelId", "in": "path"},
{"name": "context", "in": "query"},
{"name": "extension", "in": "query"},
{"name": "priority", "in": "query"},
{"name": "label", "in": "query"},
],
},
),
)
self.channels.add_method(
"externalMedia",
SwaggerMethod(
self,
"/channels/externalMedia",
"POST",
{
"operationId": "externalMedia",
"parameters": [
{"name": "channelId", "in": "query"}, # Add channelId parameter
{"name": "app", "in": "query"},
{"name": "external_host", "in": "query"},
{"name": "format", "in": "query"},
{"name": "encapsulation", "in": "query"},
{"name": "transport", "in": "query"},
{"name": "connection_type", "in": "query"},
{"name": "direction", "in": "query"},
],
},
),
)
self.channels.add_method(
"getChannelVar",
SwaggerMethod(
self,
"/channels/{channelId}/variable",
"GET",
{
"operationId": "getChannelVar",
"parameters": [
{"name": "channelId", "in": "path"},
{"name": "variable", "in": "query"},
],
},
),
)
# Add essential bridge methods
self.bridges.add_method(
"get",
SwaggerMethod(
self,
"/bridges/{bridgeId}",
"GET",
{
"operationId": "get",
"parameters": [{"name": "bridgeId", "in": "path"}],
},
),
)
self.bridges.add_method(
"create",
SwaggerMethod(
self,
"/bridges",
"POST",
{
"operationId": "create",
"parameters": [
{"name": "type", "in": "query"},
{"name": "name", "in": "query"},
],
},
),
)
self.bridges.add_method(
"addChannel",
SwaggerMethod(
self,
"/bridges/{bridgeId}/addChannel",
"POST",
{
"operationId": "addChannel",
"parameters": [
{"name": "bridgeId", "in": "path"},
{"name": "channel", "in": "query"},
],
},
),
)
self.bridges.add_method(
"removeChannel",
SwaggerMethod(
self,
"/bridges/{bridgeId}/removeChannel",
"POST",
{
"operationId": "removeChannel",
"parameters": [
{"name": "bridgeId", "in": "path"},
{"name": "channel", "in": "query"},
],
},
),
)
self.bridges.add_method(
"destroy",
SwaggerMethod(
self,
"/bridges/{bridgeId}",
"DELETE",
{
"operationId": "destroy",
"parameters": [{"name": "bridgeId", "in": "path"}],
},
),
)
async def disconnect(self):
"""Disconnect from ARI."""
self._running = False
if self._websocket:
await self._websocket.close()
if self._session:
await self._session.close()
async def run(self):
"""Main event loop."""
if not self._websocket:
raise RuntimeError("Not connected")
processor_task = asyncio.create_task(self._process_events())
try:
async for msg in self._websocket:
if msg.type == aiohttp.WSMsgType.TEXT:
try:
event = json.loads(msg.data)
# Wrap channel/bridge objects
if "channel" in event and isinstance(event["channel"], dict):
event["channel"] = Channel.from_dict(event["channel"], self)
if "bridge" in event and isinstance(event["bridge"], dict):
event["bridge"] = Bridge.from_dict(event["bridge"], self)
await self._event_queue.put(event)
except json.JSONDecodeError:
logger.error(f"Invalid JSON: {msg.data}")
elif msg.type == aiohttp.WSMsgType.ERROR:
logger.error(f"WebSocket error: {self._websocket.exception()}")
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
logger.info("WebSocket closed")
break
finally:
self._running = False
processor_task.cancel()
await asyncio.gather(processor_task, return_exceptions=True)
async def _process_events(self):
"""Process events from queue."""
while self._running:
try:
event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0)
event_type = event.get("type")
if event_type:
await self._dispatch_event(event_type, event)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error processing event: {e}")
async def _dispatch_event(self, event_type: str, event: dict):
"""Dispatch event to handlers."""
handlers = self._event_handlers.get(event_type, [])
if handlers:
logger.debug(
f"AsyncARIClient: Dispatching {event_type} to {len(handlers)} handlers"
)
for i, handler in enumerate(handlers):
try:
logger.debug(
f" AsyncARIClient: Calling {event_type} handler {i + 1}/{len(handlers)}"
)
await handler(event)
except Exception as e:
logger.error(f"Handler {i + 1} error for {event_type}: {e}")
def on_event(self, event_type: str, handler: Callable):
"""Register event handler."""
if event_type not in self._event_handlers:
self._event_handlers[event_type] = []
logger.debug(
f"AsyncARIClient: Registering handler for {event_type}. Current count: {len(self._event_handlers.get(event_type, []))}"
)
self._event_handlers[event_type].append(handler)
logger.debug(
f"AsyncARIClient: After registration, {event_type} handler count: {len(self._event_handlers[event_type])}"
)
# REST API methods
async def api_get(self, path: str, **params) -> dict:
"""GET request."""
# Ensure path starts with /ari if not already
if not path.startswith("/ari"):
path = f"/ari{path}" if path.startswith("/") else f"/ari/{path}"
url = urljoin(self.api_url, path.lstrip("/"))
async with self._session.get(url, params=params) as resp:
resp.raise_for_status()
data = await resp.json()
# Wrap known objects
if isinstance(data, list):
# Handle lists of channels/bridges
if "/channels" in path:
return [
Channel.from_dict(item, self)
if isinstance(item, dict)
else item
for item in data
]
elif "/bridges" in path:
return [
Bridge.from_dict(item, self) if isinstance(item, dict) else item
for item in data
]
return data
elif isinstance(data, dict):
if "/channels/" in path and "id" in data:
return Channel.from_dict(data, self)
elif "/bridges/" in path and "id" in data:
return Bridge.from_dict(data, self)
return data
async def api_post(self, path: str, json_data: dict = None, **params) -> dict:
"""POST request."""
# Ensure path starts with /ari if not already
if not path.startswith("/ari"):
path = f"/ari{path}" if path.startswith("/") else f"/ari/{path}"
url = urljoin(self.api_url, path.lstrip("/"))
async with self._session.post(url, json=json_data, params=params) as resp:
resp.raise_for_status()
if resp.content_length and resp.content_length > 0:
data = await resp.json()
# Wrap known objects
if "id" in data and "state" in data:
return Channel.from_dict(data, self)
elif "id" in data and "bridge_type" in data:
return Bridge.from_dict(data, self)
return data
return {}
async def api_put(self, path: str, json_data: dict = None, **params) -> dict:
"""PUT request."""
# Ensure path starts with /ari if not already
if not path.startswith("/ari"):
path = f"/ari{path}" if path.startswith("/") else f"/ari/{path}"
url = urljoin(self.api_url, path.lstrip("/"))
async with self._session.put(url, json=json_data, params=params) as resp:
resp.raise_for_status()
if resp.content_length and resp.content_length > 0:
return await resp.json()
return {}
async def api_delete(self, path: str, **params) -> dict:
"""DELETE request."""
# Ensure path starts with /ari if not already
if not path.startswith("/ari"):
path = f"/ari{path}" if path.startswith("/") else f"/ari/{path}"
url = urljoin(self.api_url, path.lstrip("/"))
async with self._session.delete(url, params=params) as resp:
resp.raise_for_status()
if resp.content_length and resp.content_length > 0:
return await resp.json()
return {}

View file

@ -0,0 +1,446 @@
"""
ARI Client Manager using the new Async ARI Client.
Drop-in replacement for the existing ari_client_manager.py.
"""
import asyncio
import json
import os
import random
import time
from typing import Awaitable, Callable, Optional
import httpx
from loguru import logger
from api.services.telephony.ari_client import AsyncARIClient, Channel
from api.services.telephony.ari_client_singleton import ari_client_singleton
class ARIClientManager:
"""Manages ARI client connection and event handling.
This is a compatibility wrapper around AsyncARIClient.
"""
def __init__(
self,
ari_client: AsyncARIClient,
app_endpoint: str,
_conn_ctx=None, # Not used with AsyncARIClient
):
"""Initialize the ARI client manager.
Parameters
----------
ari_client: AsyncARIClient
The connected ARI client.
app_endpoint: str
The app endpoint for external media.
_conn_ctx:
Not used, kept for compatibility.
"""
self._ari_client = ari_client
self._app_endpoint = app_endpoint
self._conn_ctx = _conn_ctx # Not used but kept for compatibility
self._start_handlers = []
self._end_handlers = []
self._running = False
self._handlers_registered = False # Track if handlers are registered
def register_start_handler(
self, handler: Callable[[Channel, dict], Awaitable[None]]
):
"""Register a handler for StasisStart events."""
logger.debug(
f"Registering start handler. Current count: {len(self._start_handlers)}"
)
self._start_handlers.append(handler)
logger.debug(f"After registration, handler count: {len(self._start_handlers)}")
def register_end_handler(self, handler: Callable[[str], Awaitable[None]]):
"""Register a handler for StasisEnd events."""
self._end_handlers.append(handler)
async def update_client(self, new_client: AsyncARIClient, new_conn_ctx=None):
"""Update to a new client (for reconnection)."""
logger.info("Updating ARI client for reconnection")
self._ari_client = new_client
self._conn_ctx = new_conn_ctx
# Clear old event handlers from the client before re-registering
# to prevent duplicate handler registrations
if hasattr(new_client, "_event_handlers"):
new_client._event_handlers.clear()
# Re-register event handlers
self._register_handlers()
def _register_handlers(self):
"""Register event handlers with the client."""
logger.debug(
f"_register_handlers called. Start handlers count: {len(self._start_handlers)}, End handlers count: {len(self._end_handlers)}"
)
async def on_stasis_start(event):
"""Handle StasisStart events."""
channel = event.get("channel")
# Only handle PJSIP and SIP channels
if channel and hasattr(channel, "name"):
if not (
channel.name.startswith("PJSIP") or channel.name.startswith("SIP")
):
logger.debug(
f"Ignoring StasisStart for non-SIP channel: {channel.name}"
)
return
# Log the event
logger.info(
f"StasisStart event for channel: {channel.id if channel else 'unknown'}"
)
# Extract call context variables
call_context_vars = {}
try:
# Get channel variables
var_result = await channel.getChannelVar(
variable="LOCAL_ARI_CALL_VARIABLES"
)
call_context_vars = json.loads(var_result.get("value", "{}"))
# Try to get phone number and fetch additional data
phone_number = call_context_vars.get("phone")
ari_data_uri = os.getenv("ARI_DATA_FETCHING_URI")
if phone_number and ari_data_uri:
try:
start_time = time.time()
fetch_url = f"{ari_data_uri}{phone_number}"
async with httpx.AsyncClient() as client:
response = await client.get(fetch_url, timeout=10.0)
response.raise_for_status()
# Parse the response - get the latest line if multiple lines
response_text = response.text.strip()
if response_text:
lines = response_text.split("\n")
latest_line = lines[-1].strip()
if latest_line:
# Parse the pipe-delimited data
fields = latest_line.split("|")
field_names = [
"status",
"user",
"vendor_lead_code",
"source_id",
"list_id",
"gmt_offset_now",
"phone_code",
"phone_number",
"title",
"first_name",
"middle_initial",
"last_name",
"address1",
"address2",
"address3",
"city",
"state",
"province",
"postal_code",
"country_code",
"gender",
"date_of_birth",
"alt_phone",
"email",
"security_phrase",
"comments",
"called_count",
"last_local_call_time",
"rank",
"owner",
"entry_list_id",
"lead_id",
]
# Map fields to call_context_vars
for i, field_name in enumerate(field_names):
try:
call_context_vars[field_name] = fields[i]
except IndexError:
logger.error(
f"channelID: {channel.id} IndexError while accessing fields {i}"
)
elapsed_time = time.time() - start_time
logger.info(
f"channelID: {channel.id} Successfully fetched user details for phone: {phone_number} in {elapsed_time:.3f} seconds"
)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(
f"channelID: {channel.id} Failed to fetch user details from ARI_DATA_FETCHING_URI after {elapsed_time:.3f} seconds: {e}"
)
logger.debug(
f"channelID: {channel.id} call context variables: {call_context_vars}"
)
except (
KeyError,
AttributeError,
httpx.HTTPStatusError,
json.JSONDecodeError,
) as e:
logger.debug(f"could not find variable LOCAL_ARI_CALL_VARIABLES: {e}")
# Call all registered handlers with call_context_vars
logger.debug(
f"Calling {len(self._start_handlers)} start handlers for channel {channel.id}"
)
for i, handler in enumerate(self._start_handlers):
try:
logger.debug(
f" Calling start handler {i + 1}/{len(self._start_handlers)}"
)
await handler(channel, call_context_vars)
except Exception as e:
logger.error(f"Error in StasisStart handler {i + 1}: {e}")
async def on_stasis_end(event):
"""Handle StasisEnd events."""
channel = event.get("channel", {})
channel_id = channel.id if hasattr(channel, "id") else channel.get("id", "")
# # Only handle PJSIP and SIP channels
# if channel:
# channel_name = channel.name if hasattr(channel, 'name') else channel.get("name", "")
# if channel_name and not (channel_name.startswith("PJSIP") or channel_name.startswith("SIP")):
# logger.debug(f"Ignoring StasisEnd for non-SIP channel: {channel_name}")
# return
logger.info(f"StasisEnd event for channel: {channel_id}")
# Call all registered handlers
for handler in self._end_handlers:
try:
await handler(channel_id)
except Exception as e:
logger.error(f"Error in StasisEnd handler: {e}")
# Register with the AsyncARIClient
logger.debug(f"Registering StasisStart and StasisEnd with AsyncARIClient")
self._ari_client.on_event("StasisStart", on_stasis_start)
self._ari_client.on_event("StasisEnd", on_stasis_end)
logger.debug(f"Event handlers registered with client")
async def run(self):
"""Run the event loop.
The actual WebSocket handling is done by AsyncARIClient.
This just registers handlers and waits.
"""
logger.debug("Running ARIClientManager")
self._running = True
# Register handlers only once, on first run
if not self._handlers_registered:
self._register_handlers()
self._handlers_registered = True
try:
# The AsyncARIClient.run() method handles WebSocket
# We don't call it here as it's called by the supervisor
while self._running:
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.debug(f"ARIClientManager run cancelled")
self._running = False
raise
finally:
self._running = False
class _ARIClientManagerSupervisor:
"""Supervisor that maintains ARI connection with automatic reconnection.
This replaces the asyncari-based supervisor with AsyncARIClient.
"""
# Reconnection parameters
_INITIAL_BACKOFF = 1 # Start with 1 second
_MAX_BACKOFF = 60 # Max 60 seconds between retries
def __init__(
self,
on_channel_start: Callable[[Channel, dict], Awaitable[None]],
on_channel_end: Optional[Callable[[str], Awaitable[None]]] = None,
):
self._on_channel_start = on_channel_start
self._on_channel_end = on_channel_end
self._shutting_down = False
async def start(self):
"""Start the supervisor and maintain connection."""
await self._runner()
async def stop(self):
"""Stop the supervisor."""
logger.info("Stopping ARI Client Manager Supervisor")
self._shutting_down = True
async def __aenter__(self):
"""Async context manager entry."""
asyncio.create_task(self.start())
return self
async def __aexit__(self, *args):
"""Async context manager exit."""
await self.stop()
async def _runner(self):
"""Main reconnection loop using AsyncARIClient."""
backoff = self._INITIAL_BACKOFF
ari_client_manager: Optional[ARIClientManager] = None
while not self._shutting_down:
client = None
try:
logger.debug("Going to connect with ARI")
# Get configuration from environment
base_url = os.getenv("ARI_STASIS_ENDPOINT")
username = os.getenv("ARI_STASIS_USER")
password = os.getenv("ARI_STASIS_USER_PASSWORD")
app = os.getenv("ARI_STASIS_APP_NAME")
# Convert HTTP to WebSocket URL
ws_url = base_url.replace("http://", "ws://").replace(
"https://", "wss://"
)
# Create and connect the AsyncARIClient
client = AsyncARIClient(ws_url, username, password, app)
await client.connect()
# Update the singleton with the new client
ari_client_singleton.set_client(client)
if ari_client_manager is None:
# First connection - create new manager
logger.debug("Creating new ARIClientManager (first connection)")
ari_client_manager = ARIClientManager(
client,
os.getenv("ARI_STASIS_APP_ENDPOINT"),
_conn_ctx=None, # Not needed with AsyncARIClient
)
logger.debug(f"Registering handlers with new manager")
ari_client_manager.register_start_handler(self._on_channel_start)
if self._on_channel_end:
ari_client_manager.register_end_handler(self._on_channel_end)
else:
# Reconnection - update existing manager
logger.debug("Updating existing ARIClientManager (reconnection)")
# Don't re-register start and end handlers as they're already registered
await ari_client_manager.update_client(client, None)
logger.info("Connected to ARI — supervisor entering event loop")
# Reset backoff after successful connection
backoff = self._INITIAL_BACKOFF
# Create tasks for both the client and manager
client_task = asyncio.create_task(client.run())
manager_task = asyncio.create_task(ari_client_manager.run())
# Wait for either to complete (likely due to disconnection)
done, pending = await asyncio.wait(
{client_task, manager_task}, return_when=asyncio.FIRST_COMPLETED
)
# Cancel the other task
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
# Check if we're shutting down
if self._shutting_down or asyncio.current_task().cancelled():
logger.debug("ARI supervisor task cancelled — shutting down")
break
# Otherwise it's a transient connection error
logger.warning("ARI connection lost due to CancelledError — will retry")
# Force a context switch to reset event loop state
await asyncio.sleep(0)
except Exception as exc:
# Check if we're shutting down
if self._shutting_down or asyncio.current_task().cancelled():
logger.warning("Exiting due to shutdown during exception handling")
break
# Log and retry
logger.warning(f"ARI connection failed or lost: {exc!r} - will retry")
finally:
# Disconnect client if connected
if client:
try:
await client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting client: {e}")
# Clear the singleton when disconnecting
ari_client_singleton.clear()
# Check if we're shutting down before sleeping
if self._shutting_down:
logger.debug("Exiting reconnection loop due to shutdown")
break
# Exponential back-off with jitter before the next attempt
jitter = random.uniform(0.1, backoff)
logger.debug(f"Waiting {jitter:.1f} seconds before reconnecting...")
# Sleep with proper event loop handling
await asyncio.sleep(0) # Yield control first
await asyncio.sleep(jitter)
logger.debug(f"Finished sleeping for {jitter} seconds")
backoff = min(backoff * 2, self._MAX_BACKOFF)
logger.debug(f"New backoff value: {backoff}, continuing loop...")
async def setup_ari_client_supervisor(
on_channel_start: Callable[[Channel, dict], Awaitable[None]],
on_channel_end: Callable[[str], Awaitable[None]] | None = None,
) -> "_ARIClientManagerSupervisor | None":
"""Start a background supervisor that keeps the ARI connection alive.
This is a drop-in replacement for the asyncari-based function.
Uses AsyncARIClient instead of asyncari.
If the *ENABLE_ARI_STASIS* environment variable is not set to ``"true"``
(case-insensitive) the function returns ``None`` and no supervisor is
launched.
"""
if os.getenv("ENABLE_ARI_STASIS", "false").lower() != "true":
logger.info("ARI Stasis integration disabled via environment variable")
return None
logger.info("Starting ARI Client Supervisor with AsyncARIClient")
supervisor = _ARIClientManagerSupervisor(on_channel_start, on_channel_end)
# Start the supervisor in the background
asyncio.create_task(supervisor.start())
return supervisor

View file

@ -0,0 +1,50 @@
"""Singleton holder for the current ARI client instance.
This module provides a thread-safe singleton that holds the current
ARI client instance, which can be updated during reconnections.
"""
from typing import Optional
from loguru import logger
from api.services.telephony.ari_client import AsyncARIClient
class ARIClientSingleton:
"""Singleton holder for the current ARI client instance."""
_instance: Optional["ARIClientSingleton"] = None
_client: Optional[AsyncARIClient] = None
def __new__(cls):
"""Ensure only one instance exists."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def set_client(self, client: AsyncARIClient) -> None:
"""Update the ARI client instance.
Args:
client: The new ARI client instance.
"""
self._client = client
logger.info("ARI client singleton updated with new client instance")
def get_client(self) -> Optional[AsyncARIClient]:
"""Get the current ARI client instance.
Returns:
The current ARI client, or None if not set.
"""
return self._client
def clear(self) -> None:
"""Clear the current client instance."""
self._client = None
logger.info("ARI client singleton cleared")
# Global singleton instance
ari_client_singleton = ARIClientSingleton()

View file

@ -0,0 +1,748 @@
"""Standalone ARI Manager Service for distributed architecture.
This service maintains the single WebSocket connection to Asterisk ARI
and distributes events to multiple FastAPI workers via Redis pub/sub.
ARIManager creates an instance of ARIClientSupervisor and registers the callbacks
on_channel_start and on_channel_end. It is responsible to take in caller_channel
and setup ARIManagerConnection, i.e create bridge for externalMedia.
"""
import asyncio
import json
import os
import signal
import time
from typing import Dict, Optional
from api.constants import REDIS_URL
# --- Add logging setup before importing loguru ---
from api.logging_config import setup_logging
from api.services.telephony.stasis_event_protocol import (
BaseWorkerToARIManagerCommand,
DisconnectCommand,
RedisChannels,
RedisKeys,
SocketClosedCommand,
TransferCommand,
parse_command,
)
logging_queue_listener = setup_logging()
import redis.asyncio as aioredis
import redis.exceptions
from loguru import logger
from pipecat.utils.enums import EndTaskReason
from api.services.telephony.ari_client import Channel
from api.services.telephony.ari_client_manager import (
ARIClientManager,
setup_ari_client_supervisor,
)
from api.services.telephony.ari_manager_connection import ARIManagerConnection
class ARIManager:
"""Manages ARI connection and distributes events to workers via Redis."""
def __init__(self, redis_client: aioredis.Redis):
self.redis = redis_client
self.stasis_manager: Optional[ARIClientManager] = None
self._running = False
self._ari_client_supervisor = None
self._tasks: Dict[str, asyncio.Task] = {}
self._pubsubs: Dict[
str, aioredis.client.PubSub
] = {} # Track pubsub connections
self._active_channels: set[str] = (
set()
) # Track channels managed by this instance
self._port_range = range(4000, 5000, 2) # Even ports only
self._channel_connections: Dict[
str, ARIManagerConnection
] = {} # Track connections by channel ID
self._channel_disposed: Dict[str, bool] = {} # Track channel disposed state
self._socket_closed: Dict[str, bool] = {} # Track socket closed state
self._active_workers: list[str] = [] # Cached list of active workers
self._worker_discovery_task: Optional[asyncio.Task] = None
self._channel_to_worker: Dict[str, str] = {} # Map channel to worker
async def on_channel_start(self, caller_channel: Channel, call_context_vars: dict):
"""Handle new channel from ARIClientManager with atomically allocated port."""
try:
# Atomically allocate port for this channel (prevents race conditions)
port = await self._get_and_allocate_port_atomic(caller_channel.id)
# Create connection with allocated port
connection = ARIManagerConnection(
caller_channel=caller_channel,
host=os.getenv("ARI_STASIS_APP_ENDPOINT"),
port=port,
)
# Track the connection
self._channel_connections[caller_channel.id] = connection
# Initialize channel state flags
self._channel_disposed[caller_channel.id] = False
self._socket_closed[caller_channel.id] = False
# Handle the connection
await self._on_stasis_call(connection, call_context_vars)
except Exception as e:
logger.exception(f"Error handling new channel {caller_channel.id}: {e}")
# Release port if allocation failed
await self._release_port_for_channel(caller_channel.id)
async def on_channel_end(self, channel_id: str):
"""Handle channel end notification from ARIClientManager."""
logger.info(f"channelID: {channel_id} Received channel end notification")
# Find the connection for this channel
connection = None
caller_channel_id = None
# Check if it's a caller channel
if channel_id in self._channel_connections:
connection = self._channel_connections[channel_id]
caller_channel_id = channel_id
else:
# TODO: We are currently not handling StasisEnd on ExternalMedia
for conn_channel_id, conn in self._channel_connections.items():
if conn.em_channel_id and conn.em_channel_id == channel_id:
logger.debug(
f"channelID: {channel_id} ExternalMedia StasisEnd - Ignoring"
)
# connection = conn
# caller_channel_id = conn_channel_id
break
# Publish StasisEnd event to worker immediately
if connection and caller_channel_id:
worker_id = self._get_worker_for_channel(caller_channel_id)
event = {
"type": "stasis_end",
"channel_id": caller_channel_id,
"reason": EndTaskReason.USER_HANGUP.value,
}
await self.redis.publish(
RedisChannels.worker_events(worker_id), json.dumps(event)
)
logger.info(f"channelID: {channel_id} Published StasisEnd event")
# Notify the connection about channel end
await connection.notify_channel_end()
# Mark channel as disposed
if caller_channel_id in self._channel_disposed:
self._channel_disposed[caller_channel_id] = True
# Check if both flags are set to cleanup
await self._check_and_cleanup_channel(caller_channel_id)
async def _on_stasis_call(
self, connection: ARIManagerConnection, call_context_vars: dict
):
"""Handle new Stasis call by setting up the connection and publishing to Redis."""
try:
# Setup the connection (create bridge and external media)
await connection.setup_call()
if not connection.is_connected():
logger.warning("Connection is not connected, skipping")
return
# Extract all necessary information after bridge is created
channel_id = connection.caller_channel_id
em_channel_id = connection.em_channel_id
bridge_id = connection.bridge_id
# Track this channel as active
self._active_channels.add(channel_id)
# Create event with all connection details
event = {
"type": "stasis_start",
"channel_id": channel_id,
"caller_channel_id": channel_id,
"em_channel_id": em_channel_id,
"bridge_id": bridge_id,
"local_addr": list(connection.local_addr),
"remote_addr": list(connection.remote_addr),
"call_context_vars": call_context_vars,
}
# Select worker using round-robin
worker_id = await self._select_worker()
if worker_id is None:
logger.error(f"channelID: {channel_id} No active workers available")
await connection.disconnect()
return
# Track channel to worker mapping
self._channel_to_worker[channel_id] = worker_id
channel = RedisChannels.worker_events(worker_id)
# Publish event to specific worker
await self.redis.publish(channel, json.dumps(event))
logger.info(
f"channelID: {channel_id} Published stasis_start event to worker {worker_id}"
)
# Start monitoring for commands from workers
self._tasks[channel_id] = asyncio.create_task(
self._monitor_channel_commands(channel_id, connection)
)
except Exception as e:
logger.exception(f"Error handling stasis call: {e}")
async def _get_and_allocate_port_atomic(self, channel_id: str) -> int:
"""Atomically find and allocate an available port using Redis Lua script.
This method prevents race conditions by using a Lua script that executes
atomically in Redis, ensuring that two concurrent calls cannot allocate
the same port.
"""
# Lua script for atomic port allocation
lua_script = """
local port_range_start = tonumber(ARGV[1])
local port_range_end = tonumber(ARGV[2])
local port_range_step = tonumber(ARGV[3])
local channel_id = KEYS[1]
local timestamp = ARGV[4]
-- Check if channel already has a port allocated
local existing_port = redis.call('HGET', 'channel_ports', channel_id)
if existing_port then
return tonumber(existing_port)
end
-- Find first available port
for port = port_range_start, port_range_end, port_range_step do
local port_str = tostring(port)
local exists = redis.call('HEXISTS', 'port_channels', port_str)
if exists == 0 then
-- Atomically allocate the port
redis.call('HSET', 'channel_ports', channel_id, port)
redis.call('HSET', 'port_channels', port_str, channel_id)
redis.call('HSET', 'channel_allocation_time', channel_id, timestamp)
return port
end
end
return -1 -- No ports available
"""
# Execute the Lua script with port range parameters
port_start = min(self._port_range)
port_end = max(self._port_range)
port_step = self._port_range.step
timestamp = int(time.time())
port = await self.redis.eval(
lua_script,
1, # Number of keys
channel_id, # KEYS[1]
port_start, # ARGV[1]
port_end, # ARGV[2]
port_step, # ARGV[3]
timestamp, # ARGV[4]
)
if port == -1:
# If all ports exhausted, clean up orphaned ports and retry
await self._cleanup_orphaned_ports()
# Retry after cleanup
port = await self.redis.eval(
lua_script, 1, channel_id, port_start, port_end, port_step, timestamp
)
if port == -1:
raise RuntimeError(
"No available ports in configured range after cleanup"
)
logger.debug(f"Atomically allocated port {port} for channel {channel_id}")
return port
async def _release_port_for_channel(self, channel_id: str):
"""Atomically release port when channel ends.
Uses a Lua script to ensure all cleanup operations happen atomically,
preventing partial cleanup or race conditions during release.
"""
lua_script = """
local channel_id = KEYS[1]
-- Get the port allocated to this channel
local port = redis.call('HGET', 'channel_ports', channel_id)
if port then
-- Atomically clean up all related entries
redis.call('HDEL', 'channel_ports', channel_id)
redis.call('HDEL', 'port_channels', port)
redis.call('HDEL', 'channel_allocation_time', channel_id)
return port
end
return nil
"""
port = await self.redis.eval(lua_script, 1, channel_id)
if port:
logger.debug(f"Atomically released port {port} for channel {channel_id}")
else:
logger.debug(f"No port was allocated for channel {channel_id}")
async def _discover_workers(self):
"""Periodically discover active workers from Redis."""
try:
while self._running:
try:
# Get all worker IDs from the set
worker_ids = await self.redis.smembers(RedisKeys.workers_set())
# Filter to only active workers
active_workers = []
for worker_id in worker_ids:
worker_id = (
worker_id.decode()
if isinstance(worker_id, bytes)
else worker_id
)
worker_key = RedisKeys.worker_active(worker_id)
worker_data = await self.redis.get(worker_key)
if worker_data:
try:
data = json.loads(worker_data)
# Only include workers that are ready (not draining)
if data.get("status") == "ready":
active_workers.append(worker_id)
except json.JSONDecodeError:
logger.warning(f"Invalid worker data for {worker_id}")
# Update the cached list atomically
self._active_workers = active_workers
logger.info(f"Discovered {len(active_workers)} active workers")
except Exception as e:
logger.error(f"Error discovering workers: {e}")
# Check every 5 seconds
await asyncio.sleep(5)
except asyncio.CancelledError:
logger.debug("Worker discovery task cancelled")
async def _select_worker(self) -> Optional[str]:
"""Select a worker using round-robin."""
if not self._active_workers:
return None
# Use Redis to maintain round-robin index across restarts
try:
index = await self.redis.incr(RedisKeys.round_robin_index())
worker_index = (index - 1) % len(self._active_workers)
return self._active_workers[worker_index]
except Exception as e:
logger.error(f"Error selecting worker: {e}")
# Fallback to first worker if Redis operation fails
return self._active_workers[0] if self._active_workers else None
def _get_worker_for_channel(self, channel_id: str) -> str:
"""Get the assigned worker for a channel (for sending commands)."""
# Return the worker ID that was assigned to this channel
return self._channel_to_worker.get(channel_id, "")
async def _monitor_channel_commands(
self, channel_id: str, connection: ARIManagerConnection
):
"""Listen for commands from workers for this channel."""
# TODO: Not sure if its a good idea to monitor command for every channel
# using pubsub. What happens if there are more number of calls than number
# of tcp connections redis can support? We can do something similar to
# Campaign Orchestrator, where we can subscribe to one channel and have
# commands for every channel there.
command_channel = RedisChannels.channel_commands(channel_id)
pubsub = None
try:
pubsub = self.redis.pubsub()
await pubsub.subscribe(command_channel)
# Store the pubsub connection for cleanup
self._pubsubs[channel_id] = pubsub
logger.debug(f"channelID: {channel_id} Monitoring commands for channel")
async for message in pubsub.listen():
if message["type"] == "message":
try:
command = parse_command(message["data"])
if command:
await self._handle_worker_command(
channel_id, command, connection
)
else:
logger.warning(
f"Failed to parse command for {channel_id}: {message['data']}"
)
except Exception as e:
logger.exception(
f"Error handling command for {channel_id}: {e}"
)
except asyncio.CancelledError:
logger.debug(f"channelID: {channel_id} Command monitor cancelled")
raise # Re-raise to maintain proper cancellation semantics
except (ConnectionError, redis.exceptions.ConnectionError) as e:
# We close the pubsub before cancelling the task. So, the code
# flow will arrive here
pass
except Exception as e:
logger.exception(f"Error in command monitor for {channel_id}: {e}")
async def _handle_worker_command(
self,
channel_id: str,
command: BaseWorkerToARIManagerCommand,
connection: ARIManagerConnection,
):
"""Execute commands from workers."""
if isinstance(command, DisconnectCommand):
logger.info(
f"channelID: {channel_id} Worker requested disconnect: {command.reason}"
)
await connection.disconnect(command.reason)
elif isinstance(command, TransferCommand):
logger.info(f"channelID: {channel_id} Worker requested transfer")
await connection.transfer(command.context)
elif isinstance(command, SocketClosedCommand):
logger.info(f"channelID: {channel_id} Worker notified socket closed")
# Mark socket as closed
if channel_id in self._socket_closed:
self._socket_closed[channel_id] = True
# Release port immediately
await self._release_port_for_channel(channel_id)
# Check if both flags are set to cleanup
await self._check_and_cleanup_channel(channel_id)
else:
logger.warning(
f"channelID: {channel_id} Received unknown command: {command}"
)
async def _check_and_cleanup_channel(self, channel_id: str):
"""Check if both flags are set and cleanup channel if so."""
channel_disposed = self._channel_disposed.get(channel_id, False)
socket_closed = self._socket_closed.get(channel_id, False)
logger.debug(
f"channelID: {channel_id} Check cleanup - disposed: {channel_disposed}, socket_closed: {socket_closed}"
)
if channel_disposed and socket_closed:
# Remove from active channels and connections
self._active_channels.discard(channel_id)
self._channel_connections.pop(channel_id, None)
# Close pubsub connection first (before cancelling task)
if channel_id in self._pubsubs:
pubsub = self._pubsubs[channel_id]
try:
command_channel = RedisChannels.channel_commands(channel_id)
await pubsub.unsubscribe(command_channel)
await pubsub.aclose()
logger.debug(
f"channelID: {channel_id} Closed pubsub connection in cleanup"
)
except Exception as e:
logger.warning(f"Error closing pubsub for {channel_id}: {e}")
finally:
del self._pubsubs[channel_id]
# Cancel command monitor task
if channel_id in self._tasks:
task = self._tasks[channel_id]
if not task.done():
# Task is still running, cancel it
task.cancel()
try:
# Wait for task to complete
await task
logger.debug(
f"channelID: {channel_id} Task completed after cancel"
)
except asyncio.CancelledError:
logger.debug(
f"channelID: {channel_id} Task cancelled successfully"
)
except Exception as e:
logger.warning(
f"channelID: {channel_id} Task raised exception: {e}"
)
else:
# Task already completed
logger.debug(
f"channelID: {channel_id} Monitor task already completed"
)
try:
# Still await to get any exception that might have occurred
await task
except Exception as e:
logger.warning(
f"channelID: {channel_id} Completed task had exception: {e}"
)
del self._tasks[channel_id]
# Clean up the flag tracking
self._channel_disposed.pop(channel_id, None)
self._socket_closed.pop(channel_id, None)
logger.info(f"channelID: {channel_id} Completed cleanup of all resources")
async def _cleanup_orphaned_ports(self):
"""Clean up ports from previous ungraceful shutdowns."""
try:
# Get all channel-port mappings
channel_ports = await self.redis.hgetall("channel_ports")
if not channel_ports:
return
logger.info(
f"Found {len(channel_ports)} existing port allocations, checking for orphans..."
)
cleaned = 0
current_time = int(time.time())
max_age_seconds = 3600 # 1 hour
# On startup, we can safely assume any existing allocations are orphaned
# since this is a fresh instance with no active channels yet
if not self._active_channels:
# Clean up all existing allocations on startup
for channel_id, port in channel_ports.items():
allocation_time = await self.redis.hget(
"channel_allocation_time", channel_id
)
age_str = ""
if allocation_time:
age = current_time - int(allocation_time)
age_str = f" (aged {age}s)"
await self._release_port_for_channel(channel_id)
logger.info(
f"Cleaned up orphaned port {port} for channel {channel_id}{age_str}"
)
cleaned += 1
else:
# During runtime, only clean up channels not being tracked
for channel_id, port in channel_ports.items():
if channel_id not in self._active_channels:
# Check allocation age
allocation_time = await self.redis.hget(
"channel_allocation_time", channel_id
)
if allocation_time:
age = current_time - int(allocation_time)
if age > max_age_seconds:
# Too old, clean up regardless
await self._release_port_for_channel(channel_id)
logger.info(
f"Cleaned up stale port {port} for channel {channel_id} (aged {age}s)"
)
cleaned += 1
continue
# Not tracked by this instance, might be orphaned
# For safety, only clean up if reasonably old (5 minutes)
if (
allocation_time
and (current_time - int(allocation_time)) > 300
):
await self._release_port_for_channel(channel_id)
logger.info(
f"Cleaned up orphaned port {port} for untracked channel {channel_id}"
)
cleaned += 1
if cleaned > 0:
logger.info(f"Cleaned up {cleaned} orphaned port allocations")
except Exception as e:
logger.exception(f"Error during orphaned port cleanup: {e}")
async def _periodic_cleanup(self):
"""Periodically clean up orphaned ports."""
cleanup_interval = 1800 # 30 minutes
while self._running:
try:
await asyncio.sleep(cleanup_interval)
if self._running: # Check again after sleep
logger.info("Running periodic orphaned port cleanup...")
await self._cleanup_orphaned_ports()
except asyncio.CancelledError:
logger.debug("Periodic cleanup task cancelled")
break
except Exception as e:
logger.exception(f"Error in periodic cleanup: {e}")
async def run(self):
"""Main run loop for ARI Manager."""
self._running = True
# Setup ARI connection with supervisor
try:
self._ari_client_supervisor = await setup_ari_client_supervisor(
self.on_channel_start, self.on_channel_end
)
if not self._ari_client_supervisor:
logger.error("Failed to setup ARI connection")
return
# Start worker discovery task
self._worker_discovery_task = asyncio.create_task(self._discover_workers())
# Wait a moment for initial worker discovery
await asyncio.sleep(1)
logger.info(
f"ARI Manager started with {len(self._active_workers)} active workers"
)
# Clean up any orphaned ports from previous runs
await self._cleanup_orphaned_ports()
# Start periodic cleanup task
cleanup_task = asyncio.create_task(self._periodic_cleanup())
# Keep running until shutdown
while self._running:
await asyncio.sleep(1)
logger.debug("ARIManager._running is false. Will cleanup and shutdown")
# Cancel cleanup task
cleanup_task.cancel()
try:
await cleanup_task
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"ARI Manager error: {e}")
finally:
if self._ari_client_supervisor:
await self._ari_client_supervisor.close()
logger.info("ARI Manager stopped")
async def shutdown(self):
"""Graceful shutdown."""
logger.info("Shutting down ARI Manager...")
# Close supervisor first to prevent reconnection attempts
if self._ari_client_supervisor:
await self._ari_client_supervisor.close()
# Cancel worker discovery task
if self._worker_discovery_task:
self._worker_discovery_task.cancel()
try:
await self._worker_discovery_task
except asyncio.CancelledError:
pass
self._worker_discovery_task = None
# Now set running to False
self._running = False
# Clean up all active channel ports before shutting down
if self._active_channels:
logger.info(f"Cleaning up {len(self._active_channels)} active channels...")
for channel_id in list(
self._active_channels
): # Copy to avoid modification during iteration
await self._release_port_for_channel(channel_id)
logger.info(
f"Released port for active channel {channel_id} during shutdown"
)
self._active_channels.clear()
# Clear flag tracking
self._channel_disposed.clear()
self._socket_closed.clear()
# Cancel all monitoring tasks
for task in self._tasks.values():
task.cancel()
# Wait for tasks to complete
if self._tasks:
await asyncio.gather(*self._tasks.values(), return_exceptions=True)
async def main():
"""Main entry point for ARI Manager service."""
# Setup Redis connection
redis = await aioredis.from_url(REDIS_URL, decode_responses=True)
# Create and run manager
manager = ARIManager(redis)
# Create a shutdown event for clean coordination
shutdown_event = asyncio.Event()
# Setup signal handlers
loop = asyncio.get_event_loop()
def signal_handler(signum):
logger.info(f"Received shutdown signal {signum}")
# Set the shutdown event which will trigger shutdown
shutdown_event.set()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, lambda s=sig: signal_handler(s))
# Run manager with shutdown monitoring
manager_task = asyncio.create_task(manager.run())
shutdown_task = asyncio.create_task(shutdown_event.wait())
try:
# Wait for either normal completion or shutdown signal
done, pending = await asyncio.wait(
[manager_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED
)
# If shutdown was triggered, perform graceful shutdown
if shutdown_task in done:
await manager.shutdown()
# Cancel the manager task if still running
if manager_task in pending:
manager_task.cancel()
try:
await manager_task
except asyncio.CancelledError:
pass
finally:
await redis.aclose()
# --- Ensure Axiom logging listener is stopped gracefully ---
if logging_queue_listener is not None:
logging_queue_listener.stop()
if __name__ == "__main__":
# Configure logging
logger.add("logs/ari_manager.log", rotation="10 MB")
asyncio.run(main())

View file

@ -0,0 +1,323 @@
"""ARI-specific Stasis connection for use by ARI Manager.
This connection has direct access to the ARI client and manages
the actual Asterisk channels, bridges, and RTP setup.
"""
import json
import os
import uuid
from typing import Optional
import httpx
from loguru import logger
from pipecat.utils.base_object import BaseObject
from api.services.telephony.ari_client import AsyncARIClient, Bridge, Channel
from api.services.telephony.ari_client_singleton import ari_client_singleton
class ARIManagerConnection(BaseObject):
"""ARI Manager's connection that directly controls Asterisk resources.
This class is used only by the ARI Manager process and has full
access to the ARI client for creating bridges, channels, etc.
"""
def __init__(
self,
caller_channel: Channel,
host: str,
port: int,
) -> None:
"""Initialize ARI Stasis connection.
Args:
caller_channel: The caller's channel object.
host: Host address for RTP transport.
port: Port number for RTP transport.
"""
super().__init__()
# External dependencies.
self._host: str = host
self._port: int = port
# Store channel IDs instead of Channel objects to avoid stale references
self.caller_channel_id: str = caller_channel.id
self.em_channel_id: Optional[str] = None # externalMedia channel ID
# Store bridge ID to avoid stale references after reconnection
self.bridge_id: Optional[str] = None
# RTP addressing information
self.local_addr = ("0.0.0.0", port)
self.remote_addr = None
# Internal state.
self._closed: bool = False
self._is_connected: bool = False
def is_connected(self) -> bool:
"""Check if the connection is established."""
return self._is_connected and not self._closed
@property
def _ari(self) -> Optional[AsyncARIClient]:
"""Get the current ARI client from singleton."""
return ari_client_singleton.get_client()
async def _get_channel(self, channel_id: str) -> Optional[Channel]:
"""Safely get a channel object by ID.
Returns None if the channel doesn't exist or can't be fetched.
"""
if not channel_id:
return None
try:
# Get current client from singleton
client = self._ari
if not client:
logger.warning(
f"Cannot get channel {channel_id} - No ARI client available"
)
return None
# Check if the session is still active
if not client._session or client._session.closed:
logger.warning(
f"Cannot get channel {channel_id} - ARI session is closed"
)
return None
return await client.channels.get(channelId=channel_id)
except Exception as e:
logger.warning(f"Could not get channel {channel_id} - {e}")
return None
async def _get_bridge(self, bridge_id: str) -> Optional[Bridge]:
"""Safely get a bridge object by ID.
Returns None if the bridge doesn't exist or can't be fetched.
"""
if not bridge_id:
return None
try:
# Get current client from singleton
client = self._ari
if not client:
logger.warning(
f"Cannot get bridge {bridge_id} - No ARI client available"
)
return None
# Check if the session is still active
if not client._session or client._session.closed:
logger.warning(f"Cannot get bridge {bridge_id} - ARI session is closed")
return None
return await client.bridges.get(bridgeId=bridge_id)
except Exception as e:
logger.warning(f"Could not get bridge {bridge_id}: {e}")
return None
async def _cleanup_resources(self):
"""Clean up external media channel and bridge."""
# Cleanup external media channel
try:
if self.em_channel_id:
em_channel = await self._get_channel(self.em_channel_id)
if em_channel:
await em_channel.hangup()
logger.debug(
f"channelID: {self.em_channel_id} Hung up external media"
)
self.em_channel_id = None
except Exception as exc:
logger.warning(
f"Failed to hang-up externalMedia channel: {self.em_channel_id}"
f"Error: {exc}"
)
# Cleanup bridge
try:
if self.bridge_id:
bridge = await self._get_bridge(self.bridge_id)
if bridge:
await bridge.destroy()
logger.debug(f"bridgeID: {self.bridge_id} Destroyed bridge")
self.bridge_id = None
except Exception as exc:
logger.warning(f"Failed to destroy bridge: {self.bridge_id}Error: {exc}")
async def _sync_call_data(self, call_transfer_context: dict):
"""Sync call data to ARI_DATA_SYNCING_URI."""
if not os.getenv("ARI_DATA_SYNCING_URI"):
return
lead_id = call_transfer_context.get("lead_id")
status = call_transfer_context.get("disposition")
# {'lead_id': '299154', 'disposition': 'VM', 'agent_name': 'Alex', 'decision_maker': 'False', 'employment': 'N/A', 'debts': 'N/A', 'number_of_credit_cards': 'N/A', 'time': '2025-08-07T13:16:02-04:00'}
full_name = call_transfer_context.get("full_name", "")
phone = call_transfer_context.get("phone", "")
debts = call_transfer_context.get("debts", "")
employment = call_transfer_context.get("employment", "")
time = call_transfer_context.get("time", "")
comment = f"Type:Qualified!NName:{full_name}!NPhone:{phone}!NDebts:{debts}!NCC:N/A!NDM:Yes!NEmployment:{employment}!NTime:{time}!NVendor Id:!NStatus:{status}"
try:
if lead_id and status:
ari_data_uri = os.getenv("ARI_DATA_SYNCING_URI")
# Add URL params to the base URL
sync_url = f"{ari_data_uri}&lead_id={lead_id}&status={status}&comments={comment}"
logger.debug(
f"channelID: {self.caller_channel_id} Syncing data to ARI_DATA_SYNCING_URI: {sync_url}"
)
async with httpx.AsyncClient() as client:
response = await client.post(sync_url, timeout=10.0)
response.raise_for_status()
logger.info(
f"channelID: {self.caller_channel_id} Successfully synced data for lead_id: {lead_id} with status: {status}"
)
else:
logger.warning(
f"channelID: {self.caller_channel_id} Missing lead_id or status for syncing"
)
except Exception as e:
logger.error(
f"channelID: {self.caller_channel_id} Failed to sync data to ARI_DATA_SYNCING_URI: {e}"
)
async def disconnect(self, reason: str):
"""Instruct Asterisk to hang-up the call and perform cleanup."""
if self._closed:
return
# Lets mark it as closed so that when we receive StasisEnd, we don't
# try to cleanup resource again
self._closed = True
# Clean up resources first
await self._cleanup_resources()
try:
if self.caller_channel_id:
caller_channel = await self._get_channel(self.caller_channel_id)
if caller_channel:
logger.debug(
f"channelID: {self.caller_channel_id} Hanging up caller channel due to reason: {reason}"
)
await caller_channel.hangup()
except Exception:
logger.exception("Failed to hangup caller channel")
async def transfer(self, call_transfer_context: dict):
"""Transfer the call by continuing in dialplan with extracted variables."""
if self._closed:
return
# Lets mark it as closed so that when we receive StasisEnd, we don't
# try to cleanup resource again
self._closed = True
try:
# Clean up resources before transferring
await self._cleanup_resources()
if self.caller_channel_id:
caller_channel = await self._get_channel(self.caller_channel_id)
if caller_channel:
logger.debug(
f"channelID: {self.caller_channel_id} User qualified, continuing in dialplan "
f"REMOTE_DISPO_CALL_VARIABLES: {json.dumps(call_transfer_context)}"
)
# Sync data to ARI_DATA_SYNCING_URI
await self._sync_call_data(
call_transfer_context=call_transfer_context
)
await caller_channel.continueInDialplan()
except Exception:
logger.exception("Failed to transfer caller channel")
async def setup_call(self):
"""Setup the bridge and external media channel.
This must be called after initialization to establish the connection.
"""
await self._setup_call(self._host, self._port)
async def _setup_call(self, host: str, port: int):
"""Create externalMedia + bridge and notify that the call is connected."""
try:
em_channel_id = str(uuid.uuid4())
logger.debug(
f"channelID: {em_channel_id} Creating externalMedia channel on {host}:{port}"
)
client = self._ari
if not client:
raise RuntimeError("No ARI client available")
em_channel = await client.channels.externalMedia(
app=client.app,
channelId=em_channel_id,
external_host=f"{host}:{port}",
format="ulaw",
direction="both",
)
# Store the channel ID
self.em_channel_id = em_channel.id
# Create a mixing bridge and add both legs.
bridge = await client.bridges.create(type="mixing")
self.bridge_id = bridge.id
# Add channels individually as AsyncARIClient expects single channel per call
await bridge.addChannel(channel=self.caller_channel_id)
await bridge.addChannel(channel=self.em_channel_id)
# TODO: Figure out how can we get the remote public IP. Till then
# just pick it from the environment variable
# Get RTP addressing information
# ip = await em_channel.getChannelVar(
# variable="UNICASTRTP_LOCAL_ADDRESS"
# )
port = await em_channel.getChannelVar(variable="UNICASTRTP_LOCAL_PORT")
self.remote_addr = (
os.environ.get("ASTERISK_REMOTE_IP"),
int(port["value"]),
)
logger.debug(
f"channelID: {self.caller_channel_id} ARIManagerConnection connection resources ready "
f"(bridgeID: {self.bridge_id}), (emChannelID: {self.em_channel_id})"
f"remote address: {self.remote_addr}, local address: {self.local_addr}"
)
self._is_connected = True
except Exception as exc:
logger.exception(f"Error setting up ARIManagerConnection: {exc}")
await self._cleanup_resources()
async def notify_channel_end(self):
"""Notify that a channel has ended. Received after we get StasisEnd on the caller channel"""
if self._closed:
return
self._closed = True
self._is_connected = False
# Cleanup resources using the shared method
await self._cleanup_resources()
def __repr__(self):
"""Return string representation of connection."""
return (
f"<ARIManagerConnection id={self.id} caller={self.caller_channel_id} "
f"em={self.em_channel_id} bridge={self.bridge_id} state={'closed' if self._closed else 'open'}>"
)

View file

@ -0,0 +1,184 @@
"""Redis communication protocol for distributed ARI architecture.
Defines message formats and helpers for ARI Manager <-> Worker communication.
"""
import json
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
class EventType(str, Enum):
"""Types of events sent from ARI Manager to Workers."""
STASIS_START = "stasis_start"
STASIS_END = "stasis_end"
CHANNEL_UPDATE = "channel_update"
ERROR = "error"
class CommandType(str, Enum):
"""Types of commands sent from Workers to ARI Manager."""
DISCONNECT = "disconnect"
TRANSFER = "transfer"
UPDATE_STATE = "update_state"
SOCKET_CLOSED = "socket_closed"
@dataclass
class BaseWorkerToARIManagerCommand:
"""Base class for all commands sent from Workers to ARI Manager."""
type: str
channel_id: str = ""
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, data: str):
return cls(**json.loads(data))
@dataclass
class StasisStartEvent:
"""Event sent when a new call is bridged and ready."""
type: str = EventType.STASIS_START
channel_id: str = ""
caller_channel_id: str = ""
em_channel_id: Optional[str] = None
bridge_id: Optional[str] = None
local_addr: List[Any] = None # [host, port]
remote_addr: Optional[List[Any]] = None # [host, port] with UNICASTRTP_LOCAL_PORT
call_context_vars: Dict[str, Any] = None
def __post_init__(self):
if self.local_addr is None:
self.local_addr = []
if self.call_context_vars is None:
self.call_context_vars = {}
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, data: str) -> "StasisStartEvent":
return cls(**json.loads(data))
@dataclass
class StasisEndEvent:
"""Event sent when a call ends."""
type: str = EventType.STASIS_END
channel_id: str = ""
reason: Optional[str] = None
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, data: str) -> "StasisEndEvent":
return cls(**json.loads(data))
@dataclass
class DisconnectCommand(BaseWorkerToARIManagerCommand):
"""Command to disconnect a call."""
type: str = CommandType.DISCONNECT
reason: str = "worker_requested"
@dataclass
class TransferCommand(BaseWorkerToARIManagerCommand):
"""Command to transfer a call."""
type: str = CommandType.TRANSFER
context: Dict[str, Any] = None
def __post_init__(self):
if self.context is None:
self.context = {}
@dataclass
class SocketClosedCommand(BaseWorkerToARIManagerCommand):
"""Command to notify that RTP sockets have been closed."""
type: str = CommandType.SOCKET_CLOSED
class RedisChannels:
"""Redis channel naming conventions."""
@staticmethod
def worker_events(worker_id: str) -> str:
"""Channel for events sent to a specific worker."""
return f"ari:events:worker:{worker_id}"
@staticmethod
def channel_commands(channel_id: str) -> str:
"""Channel for commands related to a specific call channel."""
return f"ari:commands:{channel_id}"
@staticmethod
def channel_updates(channel_id: str) -> str:
"""Channel for state updates about a specific call."""
return f"ari:updates:{channel_id}"
class RedisKeys:
"""Redis key naming conventions for worker registration and discovery."""
@staticmethod
def worker_active(worker_id: str) -> str:
"""Key for active worker status and metadata."""
return f"workers:active:{worker_id}"
@staticmethod
def workers_set() -> str:
"""Set containing all registered worker IDs."""
return "workers:set"
@staticmethod
def round_robin_index() -> str:
"""Counter for round-robin worker selection."""
return "workers:round_robin:index"
def parse_event(data: str) -> Any:
"""Parse a Redis event message."""
try:
parsed = json.loads(data)
event_type = parsed.get("type")
if event_type == EventType.STASIS_START:
return StasisStartEvent(**parsed)
elif event_type == EventType.STASIS_END:
return StasisEndEvent(**parsed)
else:
return parsed
except Exception:
return None
def parse_command(data: str) -> Any:
"""Parse a Redis command message."""
try:
parsed = json.loads(data)
cmd_type = parsed.get("type")
if cmd_type == CommandType.DISCONNECT:
return DisconnectCommand(**parsed)
elif cmd_type == CommandType.TRANSFER:
return TransferCommand(**parsed)
elif cmd_type == CommandType.SOCKET_CLOSED:
return SocketClosedCommand(**parsed)
else:
return parsed
except Exception:
return None

View file

@ -0,0 +1,361 @@
"""Low-level RTP transport for Asterisk externalMedia sessions.
stasis_rtp_client.py
~~~~~~~~~~~~~~~~~~~~
* Sends and receives **proper RTP/UDP** (PT 0 PCMU/μ-law).
* Uses 20 ms frames (160 bytes payload) by default; automatically
chunks or concatenates data so timestamps stay correct.
* Verifies the RTP header on the receive path (SSRC and PT).
"""
import asyncio
import secrets
import socket
import struct
from typing import TYPE_CHECKING, AsyncIterator, Optional
from loguru import logger
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
# ─────────────────────────────────────────────────────────────────── helpers
_RTP_HDR = struct.Struct("!BBHII") # v/p/x/cc, m/pt, seq, ts, ssrc
_PT_PCMU = 0 # static payload type for μ-law
class _RTPEncoder:
"""Builds PCMU RTP headers for the packets we SEND to Asterisk."""
def __init__(self):
self.ssrc = secrets.randbits(32)
self.seq = secrets.randbits(16)
self.ts = 0 # incremented by #payload bytes
def pack(self, payload: bytes, mark=False) -> bytes:
b0 = 0x80 # V=2
b1 = (0x80 if mark else 0x00) | _PT_PCMU
hdr = _RTP_HDR.pack(b0, b1, self.seq, self.ts, self.ssrc)
self.seq = (self.seq + 1) & 0xFFFF
self.ts += len(payload) # 1 sample/byte @ 8 kHz
return hdr + payload
class _RTPDecoder:
"""Very forgiving RTP decoder.
Latches on the first valid packet and then insists
that SSRC & PT match afterwards. Returns *None* if the packet
should be ignored.
"""
def __init__(self):
self.peer_ssrc: int | None = None # learned from first packet
def unpack(self, packet: bytes) -> bytes | None:
if len(packet) < _RTP_HDR.size:
return None
b0, b1, seq, ts, ssrc = _RTP_HDR.unpack_from(packet)
if (b0 & 0xC0) != 0x80: # RTP v2?
return None
if (b1 & 0x7F) != _PT_PCMU: # payload-type 0?
return None
if self.peer_ssrc is None:
self.peer_ssrc = ssrc # latch on first good packet
elif ssrc != self.peer_ssrc:
return None # stray stream drop
return packet[_RTP_HDR.size :]
# ──────────────────────────────────────────────────────────────── client
class StasisRTPClient:
"""Low-level wrapper around StasisRTPConnection.
Public API
await setup(start_frame) kept for parity (does nothing)
await connect()
async for payload in receive(): # μ-law bytes (20 ms each)
await send(data) # any length; will be chunked
await disconnect()
"""
_FRAME_SIZE = 160 # 20 ms @ 8 kHz PCMU
def __init__(
self,
connection: "StasisRTPConnection",
callbacks: "StasisRTPCallbacks",
):
"""Initialize Stasis RTP client.
Args:
connection: RTP connection parameters.
callbacks: Callback handlers for transport events.
"""
from typing import Any
self._connection = connection
self._callbacks = callbacks
self._encoder = _RTPEncoder()
self._decoder = _RTPDecoder()
self._recv_sock: Optional[socket.socket] = None
self._send_sock: Optional[socket.socket] = None
self._closing = False
self._recv_sock_ready = asyncio.Event() # Signal when recv socket is ready
self._leave_counter = 0 # Track input/output transport usage
self._fallback_disconnect_timer: Optional[asyncio.Task] = (
None # Safety timer for disconnect
)
# ── wire event handlers to the connection ────────────────
@self._connection.event_handler("connected")
async def _on_connected(_: Any):
await self._setup_sockets()
await self._callbacks.on_client_connected(
self._connection.caller_channel_id
)
@self._connection.event_handler("disconnected")
async def _on_disconnected(_: Any, reason: str):
# Cancel the safety timer if it exists. We start the safety timer when
# sending disconnect or transfer from the engine, i.e when the disconnect()
# method of the StasisRTPClient is called during wind down of the pipeline.
# We start the timer so that if we don't get the remote hangup in a given
# duration, we will call client disconnected handler.
if (
self._fallback_disconnect_timer
and not self._fallback_disconnect_timer.done()
):
self._fallback_disconnect_timer.cancel()
self._fallback_disconnect_timer = None
if not self._closing:
# Mark the client as closing, so that when the pipeline is
# cancelled or getting closed, we don't try start the fallback
# disconnect timer and return safely from disconnect
self._closing = True
await self._callbacks.on_client_disconnected(
self._connection.caller_channel_id, reason
)
# ─── public helpers ──────────────────────────────────────────
async def setup(self, _):
"""Setup method for compatibility."""
self._leave_counter += 1
async def connect(self):
"""Connect to the RTP socket."""
if self._connection.is_connected():
return
await self._connection.connect()
async def disconnect(
self,
reason: str = EndTaskReason.UNKNOWN.value,
call_transfer_context: dict = {}, # Keep parameter for backward compatibility
):
"""Disconnect from the RTP socket."""
# Decrement leave counter when disconnect is called
logger.debug(f"StasisRTPClient.disconnect leave_counter: {self._leave_counter}")
self._leave_counter -= 1
if self._leave_counter > 0:
# Early return - InputTransport called first, OutputTransport will call later
return
# Only proceed when counter reaches 0 (OutputTransport's call)
# Close sockets
logger.debug("Going to close sockets")
await self._close_sockets()
if self._closing:
# We might have received the disconnected callback from the StasisRTPConnection
# due to user hangup. We will just return. We have already closed the sockets
# in disconnected callback handler.
return
self._closing = True
# Create a safety timer that will call on_client_disconnected if we don't
# get StasisEnd from the dialer within 5 seconds. StasisEnd is needed to
# trigger on_client_disconnected handler in the event_handlers
async def _fallback_disconnect_timeout():
await asyncio.sleep(5.0)
logger.warning(
"Disconnect event not received within 5 seconds, calling on_client_disconnected as fallback"
)
await self._callbacks.on_client_disconnected(
self._connection.caller_channel_id
)
self._fallback_disconnect_timer = asyncio.create_task(
_fallback_disconnect_timeout()
)
# Only call disconnect if not a transfer (transfer already handled in PipecatEngine)
# NOTE: Transfer now happens immediately in PipecatEngine.send_end_task_frame()
if reason != EndTaskReason.USER_QUALIFIED.value:
try:
await self._connection.disconnect(reason)
except Exception as exc:
logger.error(f"Failed to disconnect RTP connection: {exc}")
else:
logger.debug(
"Skipping disconnect call for USER_QUALIFIED - transfer already initiated by engine"
)
# ─── socket management ──────────────────────────────────────
async def _setup_sockets(self):
if self._recv_sock and self._send_sock:
return
logger.debug(
f"Setting up Sockets - local {self._connection.local_addr}, remote: {self._connection.remote_addr}"
)
# receive socket bind to local address provided by connection
if not self._recv_sock:
rs = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
rs.setblocking(False)
rs.bind(self._connection.local_addr)
self._recv_sock = rs
self._recv_sock_ready.set() # Signal that recv socket is ready
# send socket connect to remote (Asterisk) address
if not self._send_sock:
ss = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ss.setblocking(False)
ss.connect(self._connection.remote_addr)
self._send_sock = ss
logger.debug(
f"Socket setup complete - recv_fd: {self._recv_sock.fileno()}, send_fd: {self._send_sock.fileno()}"
)
async def _close_sockets(self):
"""Safely close sockets with proper error handling."""
for sock_name, sock in [("recv", self._recv_sock), ("send", self._send_sock)]:
if sock:
try:
# Shutdown the socket first to break any pending operations
sock.shutdown(socket.SHUT_RDWR)
except OSError:
# Socket might already be closed or in a bad state
pass
try:
sock.close()
except Exception as exc:
logger.debug(f"Error closing {sock_name} socket: {exc}")
self._recv_sock = None
self._send_sock = None
self._recv_sock_ready.clear() # Reset the event for potential reconnection
# Notify the connection that sockets are closed so ARI Manager can clean up ports
await self._connection.notify_sockets_closed()
logger.debug("Closed sockets in StasisRTPClient")
# ─── receive path ────────────────────────────────────────────
async def receive(self) -> AsyncIterator[bytes]:
"""Async generator yielding μ-law frames (exactly 160 bytes each).
Silently drops any packet whose RTP header does not match our SSRC/PT.
"""
loop = asyncio.get_running_loop()
# Wait for recv socket to be created
try:
await self._recv_sock_ready.wait()
except asyncio.CancelledError:
return
logger.debug("Going to receive from the socket now")
while not self._closing:
try:
# each loop gets 172 bytes UDP packet, which is 160 bytes of
# audio data (Asterisk sends 20ms audio chunks with 8k sample rate)
# and 12 bytes of RTP header
data = await loop.sock_recv(self._recv_sock, 2048)
except asyncio.CancelledError:
logger.debug("RTP receive task cancelled")
break
except (OSError, socket.error) as exc:
logger.warning(f"RTP receive failed (socket closed): {exc}")
break
except Exception as exc:
logger.debug(f"Unexpected error in receive: {exc}")
break
payload = self._decoder.unpack(data)
if payload is None:
continue # header failed validation
# In practice Asterisk sends 20 ms frames assert just in case.
if len(payload) != self._FRAME_SIZE:
logger.warning(f"Dropping non-20 ms packet len={len(payload)}")
continue
yield payload
# ─── send path ───────────────────────────────────────────────
async def send(self, data: bytes):
"""Send μ-law data of arbitrary length.
Splits/aggregates into 160-byte chunks before RTP-wrapping.
"""
if self._closing or not self._send_sock:
return
loop = asyncio.get_running_loop()
# chunk/concat to 160-byte frames
chunks = self._chunk_ulaw(data, self._FRAME_SIZE)
for i, chunk in enumerate(chunks):
mark = i == 0 # set marker on the first packet of talk-spurt
packet = self._encoder.pack(chunk, mark=mark)
try:
await loop.sock_sendall(self._send_sock, packet)
except (OSError, socket.error) as exc:
logger.warning(f"RTP send failed (socket closed): {exc}")
break
except Exception as exc:
logger.error(f"RTP send failed: {exc}")
break
def _chunk_ulaw(self, buf: bytes, size: int) -> list[bytes]:
"""Split / aggregate μ-law bytes to exact *size* multiples.
If buf length is not a multiple of *size*, pad the last chunk with 0xFF
(silence). That keeps timestamps monotonic.
"""
if not buf:
return []
if len(buf) % size:
pad = size - (len(buf) % size)
buf += b"\xff" * pad
return [buf[i : i + size] for i in range(0, len(buf), size)]
# ─── properties ──────────────────────────────────────────────
@property
def is_connected(self) -> bool:
"""Check if client is connected."""
return self._connection.is_connected() and not self._closing
@property
def is_closing(self) -> bool:
"""Check if client is closing."""
return self._closing

View file

@ -0,0 +1,182 @@
"""Stasis RTP connection for worker processes.
This connection works without direct ARI access and communicates with
the ARI Manager via Redis for all control operations.
"""
from typing import Optional, Tuple
import redis.asyncio as aioredis
from loguru import logger
from pipecat.utils.base_object import BaseObject
from pipecat.utils.enums import EndTaskReason
from api.services.telephony.stasis_event_protocol import (
DisconnectCommand,
RedisChannels,
SocketClosedCommand,
TransferCommand,
)
class StasisRTPConnection(BaseObject):
"""Worker-side connection that communicates with ARI Manager via Redis.
This class provides the same API as the original StasisRTPConnection but
without direct ARI client access. All channel operations are delegated
to the ARI Manager process via Redis.
"""
_SUPPORTED_EVENTS = [
"connecting",
"connected",
"disconnected",
"closed",
"failed",
"new",
]
def __init__(
self,
redis_client: aioredis.Redis,
channel_id: str,
caller_channel_id: str,
em_channel_id: Optional[str],
bridge_id: Optional[str],
local_addr: Optional[Tuple[str, int]],
remote_addr: Optional[Tuple[str, int]],
workflow_run_id: Optional[int] = None,
):
"""Initialize distributed connection with pre-established details.
Args:
redis_client: Redis client for communication
channel_id: Primary channel ID for this connection
caller_channel_id: Caller's channel ID
em_channel_id: External media channel ID
bridge_id: Bridge ID (already created by ARI Manager)
local_addr: Local RTP address (host, port)
remote_addr: Remote RTP address with UNICASTRTP_LOCAL_PORT
workflow_run_id: Workflow run ID for logging context
"""
super().__init__()
self.redis = redis_client
self.channel_id = channel_id
self.caller_channel_id = caller_channel_id
self.em_channel_id = em_channel_id
self.bridge_id = bridge_id
self.workflow_run_id = workflow_run_id
# RTP addressing (same as StasisRTPConnection)
self.local_addr = local_addr
self.remote_addr = remote_addr
# State tracking
# self._closed_by_stasis_end should only be set True after we get
# StasisEnd from the transport
self._closed_by_stasis_end = False
self._connect_invoked = False
# Register event handlers
for evt in self._SUPPORTED_EVENTS:
self._register_event_handler(evt)
logger.debug(
f"channelID: {channel_id} StasisRTPConnection created: "
f"bridgeID: {bridge_id}, local_addr={local_addr}, remote_addr={remote_addr}"
)
async def connect(self):
"""Signal readiness to start the call.
Since the bridge is already established by ARI Manager,
we can immediately trigger the connected event.
"""
self._connect_invoked = True
if self.is_connected():
await self._call_event_handler("connected")
else:
logger.warning(
"StasisRTPConnection is not connected - did not call connected handler"
)
async def disconnect(self, reason: str):
"""Request disconnection via Redis command to ARI Manager. Usually called
when there is a disconnect triggered by workflow"""
# If we have already received user hangup via StasisEnd, lets
# return
if self._closed_by_stasis_end:
return
logger.info(f"channelID: {self.channel_id} Requesting disconnect: {reason}")
# Send disconnect command to ARI Manager
command = DisconnectCommand(channel_id=self.channel_id, reason=reason)
channel = RedisChannels.channel_commands(self.channel_id)
await self.redis.publish(channel, command.to_json())
async def transfer(self, call_transfer_context: dict):
"""Request call transfer via Redis command to ARI Manager."""
# If we have already received user hangup via StasisEnd, lets
# return
if self._closed_by_stasis_end:
return
logger.info(f"channelID: {self.channel_id} Requesting transfer")
# Send transfer command to ARI Manager
command = TransferCommand(
channel_id=self.channel_id, context=call_transfer_context
)
channel = RedisChannels.channel_commands(self.channel_id)
await self.redis.publish(channel, command.to_json())
async def notify_sockets_closed(self):
"""Notify ARI Manager that RTP sockets have been closed."""
logger.info(
f"channelID: {self.channel_id} Notifying ARI Manager that sockets are closed"
)
# Send socket_closed command to ARI Manager
command = SocketClosedCommand(channel_id=self.channel_id)
channel = RedisChannels.channel_commands(self.channel_id)
await self.redis.publish(channel, command.to_json())
def is_connected(self) -> bool:
"""Check if connection is established.
Returns True once connect() has been called and connection is not closed.
"""
return self._connect_invoked and not self._closed_by_stasis_end
async def handle_remote_disconnect(self, reason: str = EndTaskReason.UNKNOWN.value):
"""Handle disconnection initiated by ARI Manager. Is called when the user hangs up."""
if self._closed_by_stasis_end:
return
self._closed_by_stasis_end = True
if self._connect_invoked:
# Unless self._connect_invoked is True, the event handlers won't be registered. We only
# register the event handler of client when the transports are initiated during pipeline
# initialisation. Any caller must check and wait for _connect_invoked before
# calling the method
await self._call_event_handler("disconnected", reason)
else:
logger.warning(
f"ChannelID: {self.channel_id} Got remote disconnect before connection was invoked"
)
logger.info(
f"channelID: {self.channel_id} StasisRTPConnection disconnected: {reason}"
)
def __repr__(self):
"""String representation of connection."""
return (
f"<StasisRTPConnection id={self.id} channel={self.channel_id} "
f"caller={self.caller_channel_id} em={self.em_channel_id} "
f"state={'closed' if self._closed_by_stasis_end else 'open'}>"
)

View file

@ -0,0 +1,120 @@
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
"""Stasis RTP frame serializer.
This serializer converts between Pipecat frames and the raw μ-law RTP payload
stream expected by an Stasis *External Media* channel.
The serializer:
* Down-samples PCM to 8-kHz μ-law for **outgoing** audio (:class:`AudioRawFrame`).
* Up-samples μ-law to the pipeline's native rate for **incoming** audio.
"""
from typing import Optional
from loguru import logger
from pipecat.audio.utils import create_default_resampler, pcm_to_ulaw, ulaw_to_pcm
from pipecat.frames.frames import (
AudioRawFrame,
Frame,
InputAudioRawFrame,
StartFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
from pydantic import BaseModel
class StasisRTPFrameSerializer(FrameSerializer):
"""Serializer for Asterisk External Media streams (raw μ-law)."""
class InputParams(BaseModel):
"""Configuration parameters.
Attributes:
----------
stasis_sample_rate : int, default 8000
The sample-rate used by Stasis when sending μ-law (PCMU).
sample_rate : Optional[int]
Override for the pipeline's *input* sample-rate. When omitted the
value from the :class:`StartFrame` is used.
"""
stasis_sample_rate: int = 8000
sample_rate: Optional[int] = None
def __init__(self, params: Optional[InputParams] = None):
"""Initialize Stasis RTP frame serializer.
Args:
params: Optional configuration parameters for the serializer.
"""
self._params = params or self.InputParams()
# Wire / pipeline rates
self._stasis_sample_rate = self._params.stasis_sample_rate
self._sample_rate = 0 # pipeline rate, filled in *setup*
# Resampler shared between encode / decode paths
self._resampler = create_default_resampler()
@property
def type(self) -> FrameSerializerType:
"""Stasis uses raw bytes → BINARY."""
return FrameSerializerType.BINARY
async def setup(self, frame: StartFrame):
"""Remember pipeline configuration."""
self._sample_rate = self._params.sample_rate or frame.audio_in_sample_rate
async def serialize(self, frame: Frame) -> bytes | str | None:
"""Convert a Pipecat frame to a wire payload.
Only :class:`AudioRawFrame` instances are translated all other frame
types are silently ignored, allowing higher-level transports to deal
with them as needed.
"""
if isinstance(frame, AudioRawFrame):
try:
# Pipeline PCM → 8-kHz μ-law
encoded = await pcm_to_ulaw(
frame.audio,
frame.sample_rate,
self._stasis_sample_rate,
self._resampler,
)
return encoded # raw bytes
except Exception as exc: # pragma: no cover robustness
logger.error(
f"StasisRTPFrameSerializer.serialize: encode failed: {exc}"
)
return None
# Non-audio frames are not transmitted on the media path
return None
async def deserialize(self, data: bytes | str) -> Frame | None:
"""Convert wire payloads to Pipecat frames.
The Stasis media socket delivers bare μ-law bytes, therefore *data*
must be *bytes*. Any *str* is ignored.
"""
if not isinstance(data, (bytes, bytearray)):
return None
try:
pcm = await ulaw_to_pcm(
bytes(data),
self._stasis_sample_rate,
self._sample_rate,
self._resampler,
)
return InputAudioRawFrame(
audio=pcm,
sample_rate=self._sample_rate,
num_channels=1,
)
except Exception as exc: # pragma: no cover
logger.error(f"StasisRTPFrameSerializer.deserialize: decode failed: {exc}")
return None

View file

@ -0,0 +1,324 @@
# transports/ari_external_media.py (new file)
"""Stasis RTP transport for Asterisk External Media integration."""
import asyncio
import time
from typing import Awaitable, Callable, Optional
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
InputAudioRawFrame,
OutputAudioRawFrame,
StartFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import (
BaseOutputTransport,
TransportClientNotConnectedException,
)
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.enums import EndTaskReason
from pydantic import BaseModel
from api.services.telephony.stasis_rtp_client import StasisRTPClient
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
class StasisRTPTransportParams(TransportParams):
"""Transport parameters for Stasis RTP transport."""
serializer: FrameSerializer
class StasisRTPCallbacks(BaseModel):
"""Callbacks for Stasis RTP transport events."""
on_client_connected: Callable[[str], Awaitable[None]]
on_client_disconnected: Callable[
[str, Optional[str]], Awaitable[None]
] # Added optional disconnect reason
on_client_closed: Callable[[str], Awaitable[None]]
# ------------------------------------------------ Input Transport -------------------------
"""
Transport calls client receive to receive the audio from the socket. This happens in the self._receive_audio task.
Then the audio frames are pushed to _audio_in_queue using push_audio_frame method. Then the _audio_task_handler processes
the frames from the _audio_in_queue and pushes them to the VAD analyzer, turn analyzer and pushes the audio
further downstream to tts.
The BaseInputTransport pipeline is responsible for:
- Resampling the audio to the correct sample rate
- Applying the audio filter
- Pushing the audio frames to the VAD analyzer
- Pushing the audio frames to the turn analyzer
- Pushing the audio frames to the bot interruption analyzer
- Pushing the audio frames down the pipeline to the tts
stop method is called from process_frame of the BaseInputTransport. super.stop() stops _audio_task_handler. It then
calls _client.disconnect. Transport's callbacks are sent to the client using StasisRTPCallbacks.
"""
class StasisRTPInputTransport(BaseInputTransport):
"""Input transport for receiving audio over Stasis RTP."""
def __init__(
self,
transport: BaseTransport,
client: StasisRTPClient,
params: StasisRTPTransportParams,
**kwargs,
):
"""Initialize Stasis RTP input transport.
Args:
transport: Parent transport instance.
client: Stasis RTP client for socket communication.
params: Transport parameters including serializer.
**kwargs: Additional keyword arguments for BaseInputTransport.
"""
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
self._params = params
self._receive_task: Optional[asyncio.Task] = None
async def start(self, frame: StartFrame):
"""Start the input transport."""
await super().start(frame)
await self._client.setup(frame)
await self._params.serializer.setup(frame)
# Ensure underlying connection is established and socket ready.
await self._client.connect()
if not self._receive_task:
self._receive_task = self.create_task(self._receive_audio())
await self.set_transport_ready(frame)
async def _stop_tasks(self):
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
async def stop(self, frame: EndFrame):
"""Stop the input transport."""
await super().stop(frame)
await self._stop_tasks()
# Call disconnect on the client when EndFrame is encountered
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.UNKNOWN.value),
frame.metadata.get("call_transfer_context", {}),
)
logger.debug("Successfully disconnected from StasisRTPClient")
async def cancel(self, frame: CancelFrame):
"""Cancel the input transport."""
await super().cancel(frame)
await self._stop_tasks()
# Call disconnect on the client when CancelFrame is encountered
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.SYSTEM_CANCELLED.value),
frame.metadata.get("call_transfer_context", {}),
)
async def _receive_audio(self):
try:
async for payload in self._client.receive():
frame = await self._params.serializer.deserialize(payload)
if not frame:
continue
if isinstance(frame, InputAudioRawFrame):
await self.push_audio_frame(frame)
else:
await self.push_frame(frame)
except Exception as exc:
logger.error(f"StasisRTPInputTransport receive error: {exc}")
# No app-messages in RTP path, but keep compatibility
async def push_app_message(self, message):
"""Push app message (not supported in RTP transport)."""
logger.debug("StasisRTPInputTransport received app message ignored (RTP only)")
# ------------------------------------------------ Output Transport ------------------------
class StasisRTPOutputTransport(BaseOutputTransport):
"""Output transport for sending audio over Stasis RTP."""
def __init__(
self,
transport: BaseTransport,
client: StasisRTPClient,
params: StasisRTPTransportParams,
**kwargs,
):
"""Initialize Stasis RTP output transport.
Args:
transport: Parent transport instance.
client: Stasis RTP client for socket communication.
params: Transport parameters including serializer.
**kwargs: Additional keyword arguments for BaseOutputTransport.
"""
super().__init__(params, **kwargs)
self._transport = transport
self._client = client
self._params = params
# Pace outgoing audio so we don't dump buffers instantly (simulate 10-ms chunks)
self._send_interval: float = 0
self._next_send_time: float = 0
async def start(self, frame: StartFrame):
"""Start the output transport."""
await super().start(frame)
await self._client.setup(frame)
await self._params.serializer.setup(frame)
self._send_interval = self._params.audio_out_10ms_chunks * 10 / 1000 # ms
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
"""Stop the output transport."""
await super().stop(frame)
# Call disconnect on the client when EndFrame is encountered
# The client will check its _leave_counter and decide whether to close sockets
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.UNKNOWN.value),
frame.metadata.get("call_transfer_context", {}),
)
async def cancel(self, frame: CancelFrame):
"""Cancel the output transport."""
await super().cancel(frame)
# Call disconnect on the client when CancelFrame is encountered
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.SYSTEM_CANCELLED.value),
frame.metadata.get("call_transfer_context", {}),
)
async def send_message(
self, frame: TransportMessageFrame | TransportMessageUrgentFrame
):
"""Send message frame (not supported in RTP transport)."""
# RTP path has no generic message channel; ignore.
pass
async def write_audio_frame(self, frame: OutputAudioRawFrame):
"""Write audio frame to RTP stream."""
if self._client.is_closing:
raise TransportClientNotConnectedException()
if not self._client.is_connected:
# If not connected yet, just simulate playback delay.
await self._write_audio_sleep()
return
payload = await self._params.serializer.serialize(frame)
if payload:
await self._client.send(payload)
await self._write_audio_sleep()
async def _write_audio_sleep(self):
"""Simulates real-time audio playback timing by introducing controlled delays.
This method implements a clock simulation to pace audio transmission at realistic
intervals. Without this pacing, audio frames would be sent as fast as possible,
which could overwhelm receivers or cause buffering issues.
The method:
1. Calculates how long to sleep based on when the next frame should be sent
2. Sleeps for the calculated duration (or 0 if we're already behind schedule)
3. Updates _next_send_time for the next audio chunk
The _send_interval is computed as: (audio_chunk_size / sample_rate) / 2
This creates timing that simulates how an actual audio device would output
audio at the proper rate (e.g., every 10ms for 10ms audio chunks).
"""
current_time = time.monotonic()
sleep_duration = max(0, self._next_send_time - current_time)
await asyncio.sleep(sleep_duration)
if sleep_duration == 0:
self._next_send_time = time.monotonic() + self._send_interval
else:
self._next_send_time += self._send_interval
class StasisRTPTransport(BaseTransport):
"""Main transport class for Stasis RTP communication."""
def __init__(
self,
stasis_connection: StasisRTPConnection,
params: StasisRTPTransportParams,
input_name: Optional[str] = None,
output_name: Optional[str] = None,
):
"""Initialize Stasis RTP transport.
Args:
stasis_connection: Connection parameters for Stasis RTP.
params: Transport parameters including serializer.
input_name: Optional name for input transport.
output_name: Optional name for output transport.
"""
super().__init__(input_name=input_name, output_name=output_name)
self._params = params
client_callbacks = StasisRTPCallbacks(
on_client_connected=self._on_client_connected,
on_client_disconnected=self._on_client_disconnected,
on_client_closed=self._on_client_closed,
)
self._client = StasisRTPClient(stasis_connection, client_callbacks)
self._input = StasisRTPInputTransport(
self, self._client, self._params, name=self._input_name
)
self._output = StasisRTPOutputTransport(
self, self._client, self._params, name=self._output_name
)
# expose handlers
self._register_event_handler("on_client_connected")
self._register_event_handler("on_client_disconnected")
self._register_event_handler("on_client_closed")
def input(self) -> StasisRTPInputTransport:
"""Get the input transport."""
return self._input
def output(self) -> StasisRTPOutputTransport:
"""Get the output transport."""
return self._output
# ------------------------------------------------ event adapters ----------
async def _on_client_connected(self, chan_id: str):
await self._call_event_handler("on_client_connected", chan_id)
async def _on_client_disconnected(self, chan_id: str, reason: Optional[str] = None):
await self._call_event_handler("on_client_disconnected", chan_id, reason)
async def _on_client_closed(self, chan_id: str):
await self._call_event_handler("on_client_closed", chan_id)

View file

@ -0,0 +1,105 @@
#!/usr/bin/env python3
"""Test script to verify asyncari ping functionality."""
import asyncio
import os
import sys
from pathlib import Path
# Add the asyncari src to Python path for testing
asyncari_path = Path(__file__).parent.parent.parent.parent.parent / "asyncari" / "src"
sys.path.insert(0, str(asyncari_path))
import asyncari
from loguru import logger
async def test_ping():
"""Test the ping functionality with asyncari."""
# Configure from environment or use defaults
base_url = os.getenv("ARI_STASIS_ENDPOINT", "http://localhost:8088")
username = os.getenv("ARI_STASIS_USER", "asterisk")
password = os.getenv("ARI_STASIS_USER_PASSWORD", "asterisk")
apps = os.getenv("ARI_STASIS_APP_NAME", "test-app")
logger.info(f"Connecting to ARI at {base_url}")
try:
async with asyncari.connect(
base_url=base_url, apps=apps, username=username, password=password
) as client:
logger.info("Connected to ARI")
# Test REST API ping
logger.info("Testing REST API ping...")
result = await client.asterisk.ping()
logger.info(f"REST API ping successful: {result}")
# Test WebSocket ping (should work with our wrapper)
logger.info("Testing WebSocket ping...")
for ws in client.websockets:
try:
await ws.ping()
logger.info("WebSocket ping() called successfully (no-op)")
except AttributeError:
logger.error("WebSocket doesn't have ping() method")
except Exception as e:
logger.error(f"WebSocket ping failed: {e}")
# Test the keep_alive function
from ari_client_manager import keep_alive
logger.info("Starting keep_alive task...")
keep_alive_task = asyncio.create_task(keep_alive(client, interval=5.0))
# Run for 20 seconds to see several pings
await asyncio.sleep(20)
# Cancel keep_alive
keep_alive_task.cancel()
try:
await keep_alive_task
except asyncio.CancelledError:
logger.info("keep_alive task cancelled")
logger.info("Test completed successfully!")
except Exception as e:
logger.exception(f"Test failed: {e}")
return False
return True
async def test_with_manager():
"""Test using the ARI client manager."""
from ari_client_manager import setup_ari_client_supervisor
async def on_stasis_call(client, channel, context_vars):
logger.info(f"Received call: {channel.id}")
# Enable ARI Stasis for testing
os.environ["ENABLE_ARI_STASIS"] = "true"
supervisor = await setup_ari_client_supervisor(on_stasis_call)
if supervisor:
logger.info("ARI Stasis supervisor started with ping support")
# Run for 30 seconds
await asyncio.sleep(30)
await supervisor.close()
logger.info("Supervisor closed")
else:
logger.error("Failed to start supervisor")
if __name__ == "__main__":
import sys
if len(sys.argv) > 1 and sys.argv[1] == "manager":
asyncio.run(test_with_manager())
else:
asyncio.run(test_ping())

View file

@ -0,0 +1,83 @@
#!/usr/bin/env python3
"""Test script to verify real WebSocket ping frames are being sent."""
import asyncio
import os
import sys
from pathlib import Path
# Add the asyncari src to Python path
asyncari_path = Path(__file__).parent.parent.parent.parent.parent / "asyncari" / "src"
sys.path.insert(0, str(asyncari_path))
import asyncari
from loguru import logger
# Enable debug logging to see ping frames
logger.add(sys.stderr, level="DEBUG")
async def test_real_ping():
"""Test that real WebSocket ping frames are sent."""
# Configure from environment or use defaults
base_url = os.getenv("ARI_STASIS_ENDPOINT", "http://localhost:8088")
username = os.getenv("ARI_STASIS_USER", "asterisk")
password = os.getenv("ARI_STASIS_USER_PASSWORD", "asterisk")
apps = os.getenv("ARI_STASIS_APP_NAME", "test-app")
logger.info(f"Connecting to ARI at {base_url}")
try:
async with asyncari.connect(
base_url=base_url, apps=apps, username=username, password=password
) as client:
logger.info("Connected to ARI")
# Get the WebSocket
for ws in client.websockets:
logger.info(f"WebSocket type: {type(ws)}")
logger.info(
f"WebSocket wrapper active: {'WebSocketWrapper' in str(type(ws))}"
)
# Check internal structure
if hasattr(ws, "_websocket"):
inner_ws = ws._websocket
logger.info(f"Inner WebSocket type: {type(inner_ws)}")
logger.info(f"Has _connection: {hasattr(inner_ws, '_connection')}")
logger.info(f"Has _sock: {hasattr(inner_ws, '_sock')}")
# Send a test ping
logger.info("Sending test ping...")
try:
await ws.ping(b"test-ping-123")
logger.info("Ping sent successfully!")
except Exception as e:
logger.error(f"Ping failed: {e}")
# Test the keep_alive function
logger.info("\nTesting keep_alive function...")
from ari_client_manager import keep_alive
# Run keep_alive for a short time
keep_alive_task = asyncio.create_task(keep_alive(client, interval=3.0))
# Let it run for 10 seconds to see multiple pings
await asyncio.sleep(10)
# Cancel and cleanup
keep_alive_task.cancel()
try:
await keep_alive_task
except asyncio.CancelledError:
pass
logger.info("Test completed!")
except Exception as e:
logger.exception(f"Test failed: {e}")
if __name__ == "__main__":
asyncio.run(test_real_ping())

View file

@ -0,0 +1,207 @@
import random
from typing import Any, Dict, List, Optional
from urllib.parse import urlencode
import aiohttp
from loguru import logger
from pydantic import ValidationError
from twilio.request_validator import RequestValidator
from api.constants import (
BACKEND_API_ENDPOINT,
TWILIO_ACCOUNT_SID,
TWILIO_AUTH_TOKEN,
TWILIO_DEFAULT_FROM_NUMBER,
)
from api.db import db_client
class TwilioService:
"""Service for interacting with Twilio API."""
def __init__(self):
if (
not TWILIO_DEFAULT_FROM_NUMBER
or not TWILIO_ACCOUNT_SID
or not TWILIO_AUTH_TOKEN
):
raise ValidationError(
"Please set TWILIO_DEFAULT_FROM_NUMBER, TWILIO_ACCOUNT_SID, and TWILIO_AUTH_TOKEN environment"
"variables to use TwilioService"
)
self.account_sid = TWILIO_ACCOUNT_SID
self.auth_token = TWILIO_AUTH_TOKEN
self.default_from_number = TWILIO_DEFAULT_FROM_NUMBER
self.base_url = f"https://api.twilio.com/2010-04-01/Accounts/{self.account_sid}"
async def get_organization_phone_numbers(self, organization_id: int) -> List[str]:
"""
Get the list of Twilio phone numbers configured for an organization.
Args:
organization_id: The organization ID
Returns:
List of phone numbers, or default if none configured
"""
try:
from api.enums import OrganizationConfigurationKey
config = await db_client.get_configuration(
organization_id,
OrganizationConfigurationKey.TWILIO_PHONE_NUMBERS.value,
)
if config and config.value:
# Expect the value to be a list of phone numbers
phone_numbers = config.value.get("value", [])
if isinstance(phone_numbers, list) and phone_numbers:
return phone_numbers
except Exception as e:
logger.warning(
f"Error getting phone numbers for org {organization_id}: {e}"
)
# Fall back to default from environment
return [self.default_from_number]
async def initiate_call(
self,
to_number: str,
url_args: Dict[str, Any] = {},
workflow_run_id: Optional[int] = None,
organization_id: Optional[int] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""
Initiates a Twilio call using the Calls API.
Args:
to_number: The destination phone number
url_args: Dictionary of URL parameters to append to the base URL
workflow_run_id: The workflow run ID for tracking callbacks
organization_id: The organization ID for selecting phone numbers
**kwargs: Additional parameters to pass to the Twilio API
Returns:
Dict containing the Twilio API response
"""
endpoint = f"{self.base_url}/Calls.json"
if not BACKEND_API_ENDPOINT:
raise ValidationError(
"Please set BACKEND_API_ENDPOINT environment variable to a tunnel or persistant URL"
)
# Construct the URL with parameters if any
url: str = f"https://{BACKEND_API_ENDPOINT}/api/v1/twilio/twiml"
if url_args:
query_string = urlencode(url_args)
url = f"{url}?{query_string}"
logger.debug(f"Initiating call with URL: {url}")
# Get phone numbers for organization and select one randomly
if organization_id:
phone_numbers = await self.get_organization_phone_numbers(organization_id)
from_number = random.choice(phone_numbers)
logger.info(
f"Selected phone number {from_number} from {len(phone_numbers)} "
f"available numbers for org {organization_id}"
)
else:
from_number = self.default_from_number
# Prepare call data
data = {"To": to_number, "From": from_number, "Url": url}
if not BACKEND_API_ENDPOINT:
raise ValidationError(
"Please set BACKEND_API_ENDPOINT environment variable to a tunnel or persistant URL"
)
# Add status callback configuration if workflow_run_id is provided
if workflow_run_id:
callback_url = f"https://{BACKEND_API_ENDPOINT}/api/v1/twilio/status-callback/{workflow_run_id}"
data.update(
{
"StatusCallback": callback_url,
"StatusCallbackEvent": [
"initiated",
"ringing",
"answered",
"completed",
],
"StatusCallbackMethod": "POST",
}
)
# Add any additional kwargs
data.update(kwargs)
# Make the API request
async with aiohttp.ClientSession() as session:
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
async with session.post(endpoint, data=data, auth=auth) as response:
if response.status != 201:
error_data = await response.json()
raise Exception(f"Failed to initiate call: {error_data}")
return await response.json()
async def get_start_call_twiml(
self, workflow_id: int, user_id: int, workflow_run_id: int
) -> str:
if not BACKEND_API_ENDPOINT:
raise ValidationError(
"Please set BACKEND_API_ENDPOINT environment variable to a tunnel or persistant URL"
)
twiml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
<Response>
<Connect>
<Stream url="wss://{BACKEND_API_ENDPOINT}/api/v1/twilio/ws/{workflow_id}/{user_id}/{workflow_run_id}"></Stream>
</Connect>
<Pause length="40"/>
</Response>"""
return twiml_content
async def get_call(self, call_sid: str) -> Dict[str, Any]:
"""
Retrieves information about a specific call.
Args:
call_sid: The SID of the call to retrieve
Returns:
Dict containing the call information
"""
endpoint = f"{self.base_url}/Calls/{call_sid}.json"
async with aiohttp.ClientSession() as session:
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
async with session.get(endpoint, auth=auth) as response:
if response.status != 200:
error_data = await response.json()
raise Exception(f"Failed to get call: {error_data}")
return await response.json()
def verify_signature(
self, url: str, params: Dict[str, Any], signature: str
) -> bool:
"""
Verify Twilio request signature using official Twilio SDK.
Args:
url: The full URL of the webhook
params: The POST parameters (form data) as a dictionary
signature: The X-Twilio-Signature header value
Returns:
bool: True if signature is valid, False otherwise
"""
validator = RequestValidator(self.auth_token)
return validator.validate(url, params, signature)

View file

@ -0,0 +1,371 @@
"""Worker Event Subscriber for distributed ARI architecture.
This component runs in each FastAPI worker process and subscribes to
Redis events from the ARI Manager. It creates pipelines for assigned calls
without any direct ARI connection.
"""
import asyncio
import json
import uuid
from typing import Awaitable, Callable, Optional
import redis.asyncio as aioredis
from loguru import logger
from pipecat.utils.context import set_current_run_id
from api.routes.stasis_rtp import on_stasis_call
from api.services.telephony.stasis_event_protocol import (
DisconnectCommand,
RedisChannels,
RedisKeys,
StasisEndEvent,
StasisStartEvent,
parse_event,
)
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
class WorkerEventSubscriber:
"""Subscribes to ARI events from Redis and processes them in the worker."""
def __init__(
self,
redis_client: aioredis.Redis,
on_stasis_call: Callable[[StasisRTPConnection, dict], Awaitable[None]],
):
self.redis = redis_client
self.worker_id = str(uuid.uuid4()) # Generate unique worker ID
self.on_stasis_call = on_stasis_call
self._running = False
self._task: Optional[asyncio.Task] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._active_connections: dict[str, StasisRTPConnection] = {}
self._active_tasks: dict[str, asyncio.Task] = {}
self._cleanup_tasks: dict[str, asyncio.Task] = {}
self._shutting_down = False
self._shutdown_event = asyncio.Event()
async def start(self):
"""Start the event subscriber."""
if self._task is None:
self._running = True
# Register worker in Redis
await self._register_worker()
# Start main event loop
self._task = asyncio.create_task(
self._run(), name=f"worker_subscriber_{self.worker_id}"
)
# Start heartbeat task
self._heartbeat_task = asyncio.create_task(
self._heartbeat_loop(), name=f"worker_heartbeat_{self.worker_id}"
)
logger.info(f"Worker {self.worker_id} event subscriber started")
async def _register_worker(self):
"""Register this worker in Redis."""
worker_key = RedisKeys.worker_active(self.worker_id)
worker_data = json.dumps({"status": "ready", "active_calls": 0})
# Set with TTL of 30 seconds (will be refreshed by heartbeat)
await self.redis.setex(worker_key, 30, worker_data)
# Add to workers set
await self.redis.sadd(RedisKeys.workers_set(), self.worker_id)
logger.info(f"Worker {self.worker_id} registered in Redis")
async def _heartbeat_loop(self):
"""Send periodic heartbeats to Redis."""
try:
while self._running:
# Update worker status with current active call count
worker_key = RedisKeys.worker_active(self.worker_id)
worker_data = json.dumps(
{
"status": "draining" if self._shutting_down else "ready",
"active_calls": len(self._active_tasks),
}
)
# Refresh TTL to 30 seconds
await self.redis.setex(worker_key, 30, worker_data)
# Wait 10 seconds before next heartbeat
await asyncio.sleep(10)
except asyncio.CancelledError:
logger.debug(f"Worker {self.worker_id} heartbeat cancelled")
except Exception as e:
logger.exception(f"Worker {self.worker_id} heartbeat error: {e}")
async def graceful_shutdown(self, max_wait_seconds: int = 300):
"""Gracefully shutdown the worker, waiting for calls to complete.
Args:
max_wait_seconds: Maximum time to wait for calls to complete (default 5 minutes)
"""
logger.info(f"Worker {self.worker_id} starting graceful shutdown")
# Mark as shutting down to prevent new calls
self._shutting_down = True
# Update status in Redis to 'draining'
worker_key = RedisKeys.worker_active(self.worker_id)
worker_data = json.dumps(
{"status": "draining", "active_calls": len(self._active_tasks)}
)
await self.redis.setex(worker_key, 30, worker_data)
# Wait for active tasks to complete (with timeout)
start_time = asyncio.get_event_loop().time()
while (
self._active_tasks
and (asyncio.get_event_loop().time() - start_time) < max_wait_seconds
):
active_count = len(self._active_tasks)
logger.info(
f"Worker {self.worker_id} waiting for {active_count} active calls to complete"
)
# Update Redis with current status
worker_data = json.dumps(
{"status": "draining", "active_calls": active_count}
)
await self.redis.setex(worker_key, 30, worker_data)
# Wait a bit before checking again
await asyncio.sleep(5)
# Force stop if timeout reached
if self._active_tasks:
logger.warning(
f"Worker {self.worker_id} forcefully stopping {len(self._active_tasks)} active calls after timeout channel_ids: {list(self._active_tasks.keys())}"
)
await self.stop()
async def stop(self):
"""Stop the event subscriber and deregister from Redis."""
self._running = False
# Deregister from Redis
await self._deregister_worker()
# Cancel all active call processing tasks
for channel_id, task in list(self._active_tasks.items()):
if not task.done():
logger.info(f"Cancelling active call task for channel {channel_id}")
task.cancel()
# Cancel all cleanup tasks
for channel_id, task in list(self._cleanup_tasks.items()):
if not task.done():
logger.info(f"Cancelling cleanup task for channel {channel_id}")
task.cancel()
# Wait for all tasks to complete
all_tasks = list(self._active_tasks.values()) + list(
self._cleanup_tasks.values()
)
if all_tasks:
await asyncio.gather(*all_tasks, return_exceptions=True)
# Cancel heartbeat task
if self._heartbeat_task:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
logger.info(f"Worker {self.worker_id} event subscriber stopped")
async def _deregister_worker(self):
"""Remove this worker from Redis."""
try:
# Remove from active workers
await self.redis.delete(RedisKeys.worker_active(self.worker_id))
# Remove from workers set
await self.redis.srem(RedisKeys.workers_set(), self.worker_id)
logger.info(f"Worker {self.worker_id} deregistered from Redis")
except Exception as e:
logger.error(f"Error deregistering worker {self.worker_id}: {e}")
async def _run(self):
"""Main subscriber loop."""
self._running = True
channel = RedisChannels.worker_events(self.worker_id)
pubsub = self.redis.pubsub()
try:
await pubsub.subscribe(channel)
logger.info(f"Worker {self.worker_id} subscribed to {channel}")
async for message in pubsub.listen():
if not self._running:
break
if message["type"] == "message":
try:
await self._handle_event(message["data"])
except Exception as e:
logger.exception(f"Error handling event: {e}")
except asyncio.CancelledError:
logger.debug(f"Worker {self.worker_id} subscriber cancelled")
except Exception as e:
logger.exception(f"Worker {self.worker_id} subscriber error: {e}")
finally:
await pubsub.unsubscribe(channel)
await pubsub.aclose()
async def _handle_event(self, data: str):
"""Handle an event from the ARI Manager."""
event = parse_event(data)
if not event:
logger.warning(f"Failed to parse event: {data}")
return
if isinstance(event, StasisStartEvent):
await self._handle_stasis_start(event)
elif isinstance(event, StasisEndEvent):
await self._handle_stasis_end(event)
else:
logger.warning(
f"channelID: {event.channel_id} Unhandled event type: {type(event)}"
)
async def _handle_stasis_start(self, event: StasisStartEvent):
"""Handle a new call assignment."""
channel_id = event.channel_id
logger.info(
f"channelID: {channel_id} Worker {self.worker_id} handling StasisStart"
)
try:
# Create StasisRTPConnection without ARI client
connection = StasisRTPConnection(
redis_client=self.redis,
channel_id=channel_id,
caller_channel_id=event.caller_channel_id,
em_channel_id=event.em_channel_id,
bridge_id=event.bridge_id,
local_addr=tuple(event.local_addr) if event.local_addr else None,
remote_addr=tuple(event.remote_addr) if event.remote_addr else None,
)
# Store connection for cleanup
self._active_connections[channel_id] = connection
# Create a background task to handle the call
task = asyncio.create_task(
self._process_call(connection, event.call_context_vars, channel_id),
name=f"call_handler_{channel_id}",
)
self._active_tasks[channel_id] = task
except Exception as e:
logger.exception(f"Error handling StasisStart for {channel_id}: {e}")
# Send disconnect command if setup fails
await self._send_disconnect(channel_id, "setup_failed")
async def _process_call(
self, connection: StasisRTPConnection, call_context_vars: dict, channel_id: str
):
"""Process a call in the background."""
try:
await self.on_stasis_call(connection, call_context_vars)
except Exception as e:
logger.exception(f"Error processing call for {channel_id}: {e}")
# Send disconnect command if call processing fails
await self._send_disconnect(channel_id, "processing_failed")
finally:
# Clean up task reference
if channel_id in self._active_tasks:
del self._active_tasks[channel_id]
async def _process_cleanup(self, channel_id: str, reason: str):
"""Process call cleanup in the background."""
try:
if channel_id in self._active_connections:
connection: StasisRTPConnection = self._active_connections[channel_id]
# We must wait for the connection's invocation
# before sending in remote disconnect. Otherwise,
# the event handlers won't be registered and we won't
# be able to call on_client_disconnected to cancel the
# pipeline
while not connection._connect_invoked:
await asyncio.sleep(0.1)
# Set the run_id context so that we can have it in logs
if connection.workflow_run_id:
set_current_run_id(connection.workflow_run_id)
await connection.handle_remote_disconnect(reason)
del self._active_connections[channel_id]
except Exception as e:
logger.exception(f"Error during cleanup for {channel_id}: {e}")
finally:
# Clean up task reference from cleanup tasks dictionary
if channel_id in self._cleanup_tasks:
del self._cleanup_tasks[channel_id]
async def _handle_stasis_end(self, event: StasisEndEvent):
"""Handle call termination."""
channel_id = event.channel_id
logger.info(
f"channelID: {channel_id} Worker {self.worker_id} handling StasisEnd, Reason: {event.reason}"
)
# Create a background task to handle the cleanup
if channel_id in self._active_connections:
# Check if there's already a cleanup task for this channel
if (
channel_id not in self._cleanup_tasks
or self._cleanup_tasks[channel_id].done()
):
# Lets start a new task, since we need to poll for
# connection to be invoked from the pipeline before
# caling remote disconnect
task = asyncio.create_task(
self._process_cleanup(channel_id, event.reason),
name=f"cleanup_handler_{channel_id}",
)
self._cleanup_tasks[channel_id] = task
else:
logger.warning(
f"channelID: {channel_id} Cleanup skipped - cleanup task still running"
)
async def _send_disconnect(self, channel_id: str, reason: str):
"""Send disconnect command to ARI Manager."""
command = DisconnectCommand(channel_id=channel_id, reason=reason)
channel = RedisChannels.channel_commands(channel_id)
await self.redis.publish(channel, command.to_json())
async def setup_worker_subscriber(
redis_client: aioredis.Redis,
) -> WorkerEventSubscriber:
"""Setup the worker event subscriber with dynamic registration."""
subscriber = WorkerEventSubscriber(redis_client, on_stasis_call)
logger.info(f"Setting up worker event subscriber with ID {subscriber.worker_id}")
await subscriber.start()
return subscriber

View file

View file

@ -0,0 +1,77 @@
"""Utility module for applying disposition code mapping."""
from typing import Optional
from loguru import logger
from api.db import db_client
from api.enums import OrganizationConfigurationKey
async def apply_disposition_mapping(value: str, organization_id: Optional[int]) -> str:
"""Apply disposition code mapping if configured.
Args:
value: The original disposition value to map
organization_id: The organization ID
Returns:
The mapped value if found in configuration, otherwise the original value
"""
if not organization_id or not value:
return value
try:
disposition_mapping = await db_client.get_configuration_value(
organization_id,
OrganizationConfigurationKey.DISPOSITION_CODE_MAPPING.value,
default={},
)
if not disposition_mapping:
return value
# Return mapped value if exists, otherwise original
# DISPOSITION_CODE_MAPPING looks like {"user_idle_max_duration_exceeded": "DAIR"} etc.
mapped_value = disposition_mapping.get(value, value)
if mapped_value != value:
logger.debug(
f"Mapped disposition code from '{value}' to '{mapped_value}' "
f"for organization {organization_id}"
)
return mapped_value
except Exception as e:
logger.error(f"Error applying disposition mapping: {e}")
return value
async def get_organization_id_from_workflow_run(
workflow_run_id: Optional[int],
) -> Optional[int]:
"""Get organization_id from workflow_run_id through the model relationships.
Args:
workflow_run_id: The workflow run ID
Returns:
The organization ID if found, otherwise None
"""
if not workflow_run_id:
return None
try:
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run or not workflow_run.workflow:
return None
workflow = workflow_run.workflow
if not workflow.user:
return None
return workflow.user.selected_organization_id
except Exception as e:
logger.error(f"Error getting organization_id from workflow_run: {e}")
return None

View file

@ -0,0 +1,96 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, ValidationError, model_validator
class NodeType(str, Enum):
startNode = "startCall"
endNode = "endCall"
agentNode = "agentNode"
globalNode = "globalNode"
class Position(BaseModel):
x: float
y: float
class VariableType(str, Enum):
string = "string"
number = "number"
boolean = "boolean"
class ExtractionVariableDTO(BaseModel):
name: str = Field(..., min_length=1)
type: VariableType
prompt: Optional[str] = None
class NodeDataDTO(BaseModel):
name: str = Field(..., min_length=1)
prompt: str = Field(..., min_length=1)
is_static: bool = False
is_start: bool = False
is_end: bool = False
allow_interrupt: bool = False
extraction_enabled: bool = False
extraction_prompt: Optional[str] = None
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
add_global_prompt: bool = True
wait_for_user_response: bool = False
wait_for_user_response_timeout: Optional[float] = None
detect_voicemail: bool = True
delayed_start: bool = False
delayed_start_duration: Optional[float] = None
class RFNodeDTO(BaseModel):
id: str
type: NodeType = Field(default=NodeType.agentNode)
position: Position
data: NodeDataDTO
class EdgeDataDTO(BaseModel):
label: str = Field(..., min_length=1)
condition: str = Field(..., min_length=1)
class RFEdgeDTO(BaseModel):
id: str
source: str
target: str
data: EdgeDataDTO
class ReactFlowDTO(BaseModel):
nodes: List[RFNodeDTO]
edges: List[RFEdgeDTO]
@model_validator(mode="after")
def _referential_integrity(self):
node_ids = {n.id for n in self.nodes}
line_errors: list[dict[str, str]] = []
for idx, edge in enumerate(self.edges):
for endpoint in (edge.source, edge.target):
if endpoint not in node_ids:
line_errors.append(
dict(
loc=("edges", idx),
type="missing_node",
msg="Edge references missing node",
input=edge.model_dump(mode="python"),
ctx={"edge_id": edge.id, "endpoint": endpoint},
)
)
if line_errors:
raise ValidationError.from_exception_data(
title="ReactFlowDTO validation failed",
line_errors=line_errors,
)
return self

View file

@ -0,0 +1,16 @@
# api/services/workflow/errors.py
from enum import Enum
from typing import TypedDict
class ItemKind(str, Enum):
node = "node"
edge = "edge"
workflow = "workflow"
class WorkflowError(TypedDict):
kind: ItemKind # "node" | "edge"
id: str | None # nodeId or edgeId
field: str | None # “data.prompt”, “position.x”, … (optional)
message: str # human-readable text

View file

@ -0,0 +1,939 @@
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.enums import EndTaskReason
from api.constants import VOICEMAIL_RECORDING_DURATION
from api.services.gender.gender_service import GenderService
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
)
from api.services.workflow.pipecat_engine_voicemail_detector import (
VoicemailDetector,
)
from api.services.workflow.workflow import Node, WorkflowGraph
if TYPE_CHECKING:
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
LLMService = Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService]
import asyncio
from loguru import logger
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.tracing.context_registry import get_current_turn_context
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
render_template,
update_llm_context,
)
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
get_time_tools,
)
class PipecatEngine:
def __init__(
self,
*,
task: Optional[PipelineTask] = None,
llm: Optional["LLMService"] = None,
context: Optional[OpenAILLMContext] = None,
tts: Optional[Any] = None,
transport: Optional[BaseTransport] = None,
workflow: WorkflowGraph,
call_context_vars: dict,
audio_buffer: Optional["AudioBuffer"] = None,
workflow_run_id: Optional[int] = None,
):
self.task = task
self.llm = llm
self.context = context
self.tts = tts
self.transport = transport
self.workflow = workflow
self._call_context_vars = call_context_vars
self._audio_buffer = audio_buffer
self._workflow_run_id = workflow_run_id
self._initialized = False
self._pending_function_calls = 0
self._current_node: Optional[Node] = None
self._gathered_context: dict = {}
self._user_response_timeout_task: Optional[asyncio.Task] = None
self._call_disposition: Optional[str] = None
# Stasis connection for immediate transfers
self._stasis_connection: Optional["StasisRTPConnection"] = None
# Will be set later in initialize() when we have
# access to _context
self._variable_extraction_manager = None
self._gender_service = GenderService(confidence_threshold=0.5)
# Voicemail detection state
self._detect_voicemail = False
self._voicemail_detector = None
self._voicemail_detection_task: Optional[asyncio.Task] = None
# This transition is generated by the llm as part of tool call. This can
# also be accompanied with some content which can be played using TTS. If the
# bot is interrupted, we would cancel this transition (we do cancel this currently when
# the next generation starts in handle_generation_started callback handler.)
self._pending_generated_transition_after_context_push: Optional[
Callable[[], Awaitable[None]]
] = None
# This is the transtion which is typically programmatic transition, and not goes as
# tool call to LLM. This is not interrupted by the user and is done on context push
self._pending_control_transition_after_context_push: Optional[
Callable[[], Awaitable[None]]
] = None
# Flag to determine if the current llm generation has a text completion
self._defer_context_push: bool = False
# Lazy loaded built-in function schemas
self._builtin_function_schemas: Optional[list[dict]] = None
# Flag to control whether to queue context frame
self._queue_context_frame: bool = True
# Track current LLM reference text for TTS aggregation correction
self._current_llm_reference_text: str = ""
@property
def builtin_function_schemas(self) -> list[dict]:
"""Get built-in function schemas (calculator and timezone tools)."""
if self._builtin_function_schemas is None:
self._builtin_function_schemas = []
# Transform calculator tools to get_function_schema format
for tool in get_calculator_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
# Transform timezone tools to get_function_schema format
for tool in get_time_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
return self._builtin_function_schemas
async def initialize(self):
# TODO: May be set_node in a separate task so that we return from initialize immediately
if self._initialized:
logger.warning(f"{self.__class__.__name__} already initialized")
return
try:
self._initialized = True
# Helper that encapsulates variable extraction logic
self._variable_extraction_manager = VariableExtractionManager(self)
# Add current time in EST (America/New_York) to gathered context
try:
est_time_result = get_current_time("America/New_York")
# The get_current_time utility returns a dict with 'datetime' field
# Store the ISO formatted datetime string under the key 'time'
self._gathered_context["time"] = est_time_result.get("datetime")
except Exception as e:
logger.error(f"Failed to fetch current EST time: {e}")
# Register built-in functions with the LLM
await self._register_builtin_functions()
# Set gender in initial context predicted from first name
if "first_name" in self._call_context_vars:
salutation = await self._gender_service.get_salutation(
self._call_context_vars["first_name"]
)
self._call_context_vars["salutation"] = salutation
await self.set_node(self.workflow.start_node_id)
logger.debug(f"{self.__class__.__name__} initialized")
except Exception as e:
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
raise
def _get_function_schema(self, function_name: str, description: str):
"""Thin wrapper around utils.get_function_schema for backwards compatibility."""
return get_function_schema(function_name, description)
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
"""Delegate context update to the shared workflow.utils implementation."""
update_llm_context(self.context, system_message, functions)
def _format_prompt(self, prompt: str) -> str:
"""Delegate prompt formatting to the shared workflow.utils implementation."""
return render_template(prompt, self._call_context_vars)
async def _create_transition_func(self, name: str, transition_to_node: str):
async def transition_func(function_call_params: FunctionCallParams) -> None:
"""Inner function that handles the actual tool invocation."""
try:
# Track pending function call
self._pending_function_calls += 1
logger.debug(
f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
)
# For edge functions, prevent LLM completion until transition (run_llm=False)
# For node functions, allow immediate completion (run_llm=True)
async def on_context_updated() -> None:
"""
Framework will run this function after the function call result has been updated in the context.
This way, when we do set_node from within this function, and go for LLM completion with updated
system prompts, the context is updated with function call result.
"""
self._pending_function_calls -= 1
# Perform variable extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(
self._current_node
)
await self.set_node(transition_to_node)
result = {"status": "done"}
properties = FunctionCallResultProperties(
run_llm=False,
on_context_updated=on_context_updated,
)
async def _invoke_result_callback():
"""
Functions are executed immediately when they come from LLM as part of text completion.
But, if the LLM completion also has some text, we would want to not call the function if the user interrupts the speech.
We would also not want the function to be added to context, so that the LLM can call the function again. Hence, we
defer the function invocation until we receive on_context_updated callback, i.e the bot has finished speaking
the text that was generated.
"""
await function_call_params.result_callback(
result, properties=properties
)
if self._defer_context_push:
"""
We set the flag to _defer_context_push when we receive text in the current generation from LLM.
This is set in the handle_llm_generated_text callback handler.
"""
logger.debug(
"Deferring transition function result until context push"
)
# Only one deferred transition should exist at any time.
# Overwrite if one is somehow already set (unexpected).
self._pending_generated_transition_after_context_push = (
_invoke_result_callback
)
else:
"""
If there was no text in the current generation, and we only had function call,
lets invoke the result callback, so that framework can call on_context_updated and
we can do switch node.
"""
await _invoke_result_callback()
except Exception as e:
logger.error(f"Error in transition function {name}: {str(e)}")
self._pending_function_calls = 0
error_result = {"status": "error", "error": str(e)}
await function_call_params.result_callback(error_result)
return transition_func
async def _register_transition_function_with_llm(
self, name: str, transition_to_node: str
):
logger.debug(
f"Registering function {name} to transition to node {transition_to_node} with LLM"
)
# Create transition function
transition_func = await self._create_transition_func(name, transition_to_node)
# Register function with LLM
self.llm.register_function(
name,
transition_func,
cancel_on_interruption=True,
)
async def _register_builtin_functions(self):
"""Register built-in functions (calculator and timezone) with the LLM."""
logger.debug("Registering built-in functions with LLM")
properties = FunctionCallResultProperties(run_llm=True)
# Register calculator function
async def calculate_func(function_call_params: FunctionCallParams) -> None:
try:
expr = function_call_params.arguments.get("expression", "")
result = safe_calculator(expr)
await function_call_params.result_callback(
{"expression": expr, "result": result}, properties=properties
)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
# Register timezone functions
async def get_current_time_func(
function_call_params: FunctionCallParams,
) -> None:
try:
timezone = function_call_params.arguments.get("timezone", "UTC")
result = get_current_time(timezone)
await function_call_params.result_callback(
result, properties=properties
)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
try:
result = convert_time(
function_call_params.arguments.get("source_timezone"),
function_call_params.arguments.get("time"),
function_call_params.arguments.get("target_timezone"),
)
await function_call_params.result_callback(
result, properties=properties
)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
# Register all built-in functions
self.llm.register_function("safe_calculator", calculate_func)
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _queue_tts_response(self, text: str) -> None:
"""Queue TTS frames for static text response."""
await self.task.queue_frames(
[
LLMFullResponseStartFrame(),
TTSSpeakFrame(text=text),
LLMFullResponseEndFrame(),
]
)
async def _setup_static_start_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static start nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
if not node.wait_for_user_response:
# Normal static start node - transition immediately after context push
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
"""Perform variable extraction if the previous node had extraction enabled."""
if (
previous_node
and previous_node.extraction_enabled
and previous_node.extraction_variables
):
logger.debug(
f"Scheduling background variable extraction for node: {previous_node.name}"
)
# Capture the current turn context before creating the background task
parent_context = get_current_turn_context()
extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
extraction_variables = previous_node.extraction_variables
async def _background_extraction():
try:
extracted_data = (
await self._variable_extraction_manager._perform_extraction(
extraction_variables, parent_context, extraction_prompt
)
)
self._gathered_context.update(extracted_data)
logger.debug(
f"Background variable extraction completed. Extracted: {extracted_data}"
)
except Exception as e:
logger.error(
f"Error during background variable extraction: {str(e)}"
)
# Fire and forget - extraction happens in background without blocking
asyncio.create_task(_background_extraction())
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
"""Common method to set up LLM context and queue context frame for non-static nodes."""
# Set node name for tracing
try:
self.context.set_node_name(node.name)
except AttributeError:
logger.warning(f"context has no set_node_name method")
# Register transition functions if not an end node
if not node.is_end:
for outgoing_edge in node.out_edges:
await self._register_transition_function_with_llm(
outgoing_edge.get_function_name(), outgoing_edge.target
)
# Set up system message and functions
(
system_message,
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
# Queue context frame if needed
if self._queue_context_frame:
await self.task.queue_frame(OpenAILLMContextFrame(self.context))
else:
logger.debug(
f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
)
# Reset _queue_context_frame as default behavior
self._queue_context_frame = True
async def set_node(self, node_id: str):
"""
Simplified set_node implementation according to v2 PRD.
"""
node = self.workflow.nodes[node_id]
logger.debug(
f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
)
# Set current node for all nodes (including static ones) so STT mute filter works
self._current_node = node
# Handle start nodes
if node.is_start:
await self._handle_start_node(node)
# Handle end nodes
elif node.is_end:
await self._handle_end_node(node)
# Handle normal agent nodes
else:
await self._handle_agent_node(node)
async def _handle_start_node(self, node: Node) -> None:
"""Handle start node execution."""
# Handle voicemail detection setup (before any returns)
if node.detect_voicemail:
if not self._audio_buffer:
logger.warning(
"Voicemail detection enabled but no audio buffer available - skipping detection"
)
else:
logger.debug(
"Start node has detect_voicemail enabled - setting up audio-based detector"
)
self._detect_voicemail = True
self._voicemail_detector = VoicemailDetector(
detection_duration=VOICEMAIL_RECORDING_DURATION,
workflow_run_id=self._workflow_run_id,
)
# Register audio handler on the audio buffer input processor
audio_input = self._audio_buffer.input()
@audio_input.event_handler("on_input_audio_data")
async def handle_voicemail_audio(
processor, pcm, sample_rate, num_channels
):
if (
self._voicemail_detector
and self._voicemail_detector.is_detecting
):
await self._voicemail_detector.handle_audio_data(
processor, pcm, sample_rate, num_channels
)
# Start detection
await self._voicemail_detector.start_detection(self)
# Check if delayed start is enabled
if node.delayed_start:
# Use configured duration or default to 3 seconds
delay_duration = node.delayed_start_duration or 2.0
logger.debug(
f"Delayed start enabled - waiting {delay_duration} seconds before speaking"
)
await asyncio.sleep(delay_duration)
if node.is_static:
# Queue TTS for static start node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static start nodes
await self._setup_static_start_node_transition(node)
else:
# Start generation for non-static start node
await self._setup_llm_context_and_start_generation(node)
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:
# Queue TTS for static end node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
else:
# Start generation for non-static end node
await self._setup_llm_context_and_start_generation(node)
# If this end node has extraction enabled, perform extraction immediately
if node.extraction_enabled and node.extraction_variables:
await self._perform_variable_extraction_if_needed(node)
# TODO: Extract disposition code from extracted variables
# Defer send_end_task_frame using _pending_control_transition_after_context_push
# Decide the end-task reason dynamically depending on call_disposition.
async def _deferred_end_task():
# call_disposition is the disposition which is generated from
# llm call based on the conversation so far.
# TODO: Make this more generic based on configuration or llm prompting
disposition = self._gathered_context.get("call_disposition")
if disposition == "XFER":
reason = EndTaskReason.USER_QUALIFIED.value
else:
reason = EndTaskReason.USER_DISQUALIFIED.value
await self.send_end_task_frame(reason)
self._pending_control_transition_after_context_push = _deferred_end_task
async def _handle_agent_node(self, node: Node) -> None:
"""Handle agent node execution."""
if node.is_static:
# Queue TTS for static agent node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static agent nodes
await self._setup_agent_node_transition(node)
else:
# Set context and functions for non-static agent node
await self._setup_llm_context_and_start_generation(node)
async def _setup_agent_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static agent nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def send_end_task_frame(
self,
reason: str,
additional_metadata: dict = None,
abort_immediately: bool = False,
):
"""
Centralized method to send EndTaskFrame with metadata including
call_transfer_context and call_context_vars
"""
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
# Customer disposition code using their mapping
mapped_disposition = ""
# Apply disposition mapping - first try call_disposition if it is,
# extracted from the call conversation then fall back to reason
call_disposition = self._gathered_context.get("call_disposition", "")
organization_id = await get_organization_id_from_workflow_run(
self._workflow_run_id
)
if call_disposition:
# If call_disposition exists, map it
mapped_disposition = await apply_disposition_mapping(
call_disposition, organization_id
)
# Store the original and mapped values
self._gathered_context["extracted_call_disposition"] = call_disposition
self._gathered_context["call_disposition"] = mapped_disposition
else:
# Otherwise, map the disconnect reason
mapped_disposition = await apply_disposition_mapping(
reason, organization_id
)
# Store the mapped disconnect reason
self._gathered_context["call_disposition"] = mapped_disposition
# TODO: Generalise this, currently tailored to Kapil's use case
self._gathered_context["address"] = ", ".join(
[
self._call_context_vars.get("address1", ""),
self._call_context_vars.get("address2", ""),
self._call_context_vars.get("address3", ""),
self._call_context_vars.get("city", ""),
self._call_context_vars.get("state", ""),
self._call_context_vars.get("province", ""),
self._call_context_vars.get("postal_code", ""),
]
)
self._gathered_context["full_name"] = " ".join(
[
self._call_context_vars.get("first_name", ""),
self._call_context_vars.get("middle_initial", ""),
self._call_context_vars.get("last_name", ""),
]
)
self._gathered_context["agent_name"] = "Alex"
self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
"phone", ""
)
self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
self._gathered_context["vendor_id"] = self._call_context_vars.get(
"vendor_lead_code", ""
)
decision_maker = self._gathered_context.get("primary_cardholder", False)
employment_status = self._gathered_context.get("employment_status", "N/A")
call_transfer_context = {
"first_name": self._call_context_vars.get("first_name", ""),
"full_name": self._gathered_context.get("full_name", ""),
"phone": self._call_context_vars.get("phone", ""),
"lead_id": self._call_context_vars.get("lead_id"),
"disposition": mapped_disposition,
"agent_name": self._gathered_context.get("agent_name", "Alex"),
"decision_maker": str(decision_maker),
"employment": employment_status.title() if employment_status else "N/A",
"debts": self._gathered_context.get("total_debt", "N/A"),
"number_of_credit_cards": self._gathered_context.get(
"number_of_credit_cards", "N/A"
),
"time": self._gathered_context.get("time"),
}
logger.debug(
f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
)
# Initiate immediate transfer for Stasis connections when user is qualified
if (
reason == EndTaskReason.USER_QUALIFIED.value
and self._stasis_connection is not None
and not abort_immediately
):
try:
logger.info(
f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
)
await self._stasis_connection.transfer(call_transfer_context)
logger.info("Immediate transfer initiated successfully")
except Exception as e:
logger.error(f"Failed to initiate immediate transfer: {e}")
# Continue with normal flow even if immediate transfer fails
if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
await self.task.queue_frame(
TTSSpeakFrame(
"Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
)
)
metadata = {
# Keep original reason in metadata, which would be used to decide
# whether to disconnect or to transfer the call in the transport
"reason": reason,
"call_transfer_context": call_transfer_context,
}
# Add any additional metadata
if additional_metadata:
metadata.update(additional_metadata)
frame_to_push.metadata = metadata
# Store the original reason for later retrieval in event handler
self._call_disposition = mapped_disposition
logger.debug(
f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
)
await self.task.queue_frame(frame_to_push)
async def _compose_system_message_functions_for_node(
self, node: "Node"
) -> tuple[list[dict], list[dict]]:
"""Generate the system messages and function schemas for the given node.
This performs the same formatting logic used when entering a node but
does **not** register the functions with the LLM; callers are
responsible for that.
"""
global_prompt = ""
if self.workflow.global_node_id and node.add_global_prompt:
global_node = self.workflow.nodes[self.workflow.global_node_id]
global_prompt = self._format_prompt(global_node.prompt)
functions: list[dict] = []
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Transition functions (schema only; registration handled elsewhere)
for outgoing_edge in node.out_edges:
function_schema = self._get_function_schema(
outgoing_edge.get_function_name(), outgoing_edge.condition
)
functions.append(function_schema)
formatted_node_prompt = self._format_prompt(node.prompt)
system_message = {
"role": "system",
"content": "\n\n".join(
p for p in (global_prompt, formatted_node_prompt) if p
),
}
return system_message, functions
# ------------------------------------------------------------------
# Pending transition handling
# ------------------------------------------------------------------
async def flush_pending_transitions(self, *, source: str = "context_push"):
"""Execute and clear any pending transitions.
Args:
source: Indicates the trigger that caused this flush:
- "context_push": the assistant context aggregator completed a push.
"""
if source != "context_push":
raise ValueError("Invalid flush source expected 'context_push'")
len_pending_functions = 0
if self._pending_generated_transition_after_context_push is not None:
len_pending_functions += 1
if self._pending_control_transition_after_context_push is not None:
len_pending_functions += 1
# Nothing to do
if len_pending_functions == 0:
return
logger.debug(
f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
)
# Generated transition
if self._pending_generated_transition_after_context_push is not None:
pending_cb = self._pending_generated_transition_after_context_push
self._pending_generated_transition_after_context_push = None
try:
await pending_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred transition: {exc}")
# Control transition (context push)
if self._pending_control_transition_after_context_push is not None:
logger.debug("Executing control transition after context push")
static_cb = self._pending_control_transition_after_context_push
self._pending_control_transition_after_context_push = None
try:
await static_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred static node transition: {exc}")
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
"""
This callback is called by STTMuteFilter to determine if the STT should be muted.
"""
return engine_callbacks.create_should_mute_callback(self)
def create_user_idle_callback(self):
"""
This callback is called when the user is idle for a certain duration.
We use this to either play the static text or end the call
"""
return engine_callbacks.create_user_idle_callback(self)
def create_max_duration_callback(self):
"""
This callback is called when the call duration exceeds the max duration.
We use this to send the EndTaskFrame.
"""
return engine_callbacks.create_max_duration_callback(self)
def create_llm_generated_text_callback(self):
"""
This callback is called when some text is generated by the LLM.
We use this to defer the result_callback of the node transition functions if
there is set_node called along with some text generated. This way, we will
have the context sent in the next generation from new node.
"""
return engine_callbacks.create_llm_generated_text_callback(self)
def create_generation_started_callback(self):
"""
This callback is called when a new generation starts.
This is used to reset the flags that control the flow of the engine.
"""
return engine_callbacks.create_generation_started_callback(self)
def create_user_stopped_speaking_callback(self):
"""
This callback is called when the user stops speaking.
We use this to handle transitions when wait_for_user_response is enabled.
"""
return engine_callbacks.create_user_stopped_speaking_callback(self)
def create_user_started_speaking_callback(self):
"""
This callback is called when the user starts speaking.
We use this to handle wait_for_user_greeting functionality.
"""
return engine_callbacks.create_user_started_speaking_callback(self)
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
"""Create a callback that corrects corrupted aggregation using reference text."""
return engine_callbacks.create_aggregation_correction_callback(self)
def get_call_disposition(self) -> Optional[str]:
"""Get the disconnect reason set by the engine."""
return self._call_disposition
def get_gathered_context(self) -> dict:
"""Get the gathered context including extracted variables."""
return self._gathered_context.copy()
def set_context(self, context: OpenAILLMContext) -> None:
"""Set the OpenAI LLM context.
This allows setting the context after the engine has been created,
which is useful when the context needs to be created after the engine.
"""
self.context = context
def set_task(self, task: PipelineTask) -> None:
"""Set the pipeline task.
This allows setting the task after the engine has been created,
which is useful when the task needs to be created after the engine.
"""
self.task = task
def set_audio_buffer(self, audio_buffer: "AudioBuffer") -> None:
"""Set the audio buffer.
This allows setting the audio buffer after the engine has been created,
which is useful when the audio buffer needs to be created after the engine.
"""
self._audio_buffer = audio_buffer
def set_stasis_connection(
self, connection: Optional["StasisRTPConnection"]
) -> None:
"""Set the Stasis RTP connection for immediate transfers.
This allows the engine to initiate transfers immediately when XFER
disposition is detected, without waiting for pipeline shutdown.
Args:
connection: The StasisRTPConnection instance, or None for non-Stasis transports
"""
self._stasis_connection = connection
if connection:
logger.debug(
f"Stasis connection set for immediate transfers: {connection.channel_id}"
)
async def handle_llm_text_frame(self, text: str):
"""Accumulate LLM text frames to build reference text."""
self._current_llm_reference_text += text
async def cleanup(self):
"""Clean up engine resources on disconnect."""
# Cancel any pending timeout tasks
if (
self._user_response_timeout_task
and not self._user_response_timeout_task.done()
):
self._user_response_timeout_task.cancel()
# Stop voicemail detection if active
if self._voicemail_detector and hasattr(
self._voicemail_detector, "stop_detection"
):
await self._voicemail_detector.stop_detection()

View file

@ -0,0 +1,305 @@
from __future__ import annotations
"""Callback factory helpers for :pyclass:`~api.services.workflow.pipecat_engine.PipecatEngine`.
Each helper takes a :class:`PipecatEngine` instance and returns an async
callback function suitable for passing to the various pipeline processors.
Separating these helpers into their own module keeps
``pipecat_engine.py`` focused on high-level engine orchestration logic while
encapsulating the callback implementations here for easier maintenance and
unit-testing.
"""
import re
from typing import TYPE_CHECKING, Awaitable, Callable
from loguru import logger
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from pipecat.processors.user_idle_processor import UserIdleProcessor
from api.services.workflow.pipecat_engine import PipecatEngine
# ---------------------------------------------------------------------------
# STT mute handling
# ---------------------------------------------------------------------------
def create_should_mute_callback(
engine: "PipecatEngine",
) -> Callable[[STTMuteFilter], Awaitable[bool]]:
"""Return a callback indicating whether STT should be muted.
STT is muted when *interruptions are **not*** allowed on the current node.
"""
async def callback(_: STTMuteFilter) -> bool: # noqa: D401
if engine._current_node is None:
# Default to not muting if we have no active node yet.
return False
logger.debug(
f"STT mute callback: allow_interrupt={engine._current_node.allow_interrupt}"
)
return not engine._current_node.allow_interrupt
return callback
# ---------------------------------------------------------------------------
# User-idle handling
# ---------------------------------------------------------------------------
def create_user_idle_callback(engine: "PipecatEngine"):
"""Return a callback that handles user-idle timeouts."""
async def handle_user_idle(
user_idle: "UserIdleProcessor", retry_count: int
) -> bool:
logger.debug(f"Handling user_idle, attempt: {retry_count}")
# Check if we're on a StartNode - if yes, directly disconnect
if engine._current_node and engine._current_node.is_start:
logger.debug("User idle on StartNode - disconnecting immediately")
await engine.send_end_task_frame(
EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value
)
return False
if retry_count == 1:
# Simulate an LLM generation, so that we can have the LLM context
# updated with the new message
await engine.task.queue_frames(
[
LLMFullResponseStartFrame(),
TTSSpeakFrame("Just checking in to see if you're still there."),
LLMFullResponseEndFrame(),
]
)
return True
# Second attempt: terminate the call due to inactivity.
await user_idle.push_frame(
TTSSpeakFrame("It seems like you're busy right now. Have a nice day!")
)
await engine.send_end_task_frame(
EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value
)
return False
return handle_user_idle
# ---------------------------------------------------------------------------
# Max-duration handling
# ---------------------------------------------------------------------------
def create_max_duration_callback(engine: "PipecatEngine"):
"""Return a callback that ends the task when the max call duration is exceeded."""
async def handle_max_duration():
logger.debug("Max call duration exceeded. Terminating call")
await engine.send_end_task_frame(EndTaskReason.CALL_DURATION_EXCEEDED.value)
return handle_max_duration
# ---------------------------------------------------------------------------
# LLM-generated-text handling
# ---------------------------------------------------------------------------
def create_llm_generated_text_callback(engine: "PipecatEngine"):
"""Return a callback invoked when the LLM emits text (not only tool calls)."""
async def handle_llm_generated_text(): # noqa: D401
logger.debug(
"Generation has text content in current response - deferring context push from set_node"
)
engine._defer_context_push = True
return handle_llm_generated_text
# ---------------------------------------------------------------------------
# Generation-started handling
# ---------------------------------------------------------------------------
def create_generation_started_callback(engine: "PipecatEngine"):
"""Return a callback that resets flags at the start of each LLM generation."""
async def handle_generation_started(): # noqa: D401
logger.debug("LLM generation started - resetting defer flags and tool counters")
engine._defer_context_push = False
engine._pending_function_calls = 0
engine._pending_generated_transition_after_context_push = None
# Clear reference text from previous generation
engine._current_llm_reference_text = ""
return handle_generation_started
# ---------------------------------------------------------------------------
# User-stopped-speaking handling
# ---------------------------------------------------------------------------
def create_user_stopped_speaking_callback(engine: "PipecatEngine"):
"""Return a callback that handles when the user stops speaking.
According to simplified flow:
- For start nodes with wait_for_user_response=True:
- Cancel timeout task if still active
- Transition to next node with _queue_context_frame=False
"""
async def handle_user_stopped_speaking():
# Only handle if current node is a start node with wait_for_user_response
if (
engine._current_node
and engine._current_node.is_start
and engine._current_node.wait_for_user_response
and engine._current_node.out_edges
):
# Cancel timeout task if it's still active
if (
engine._user_response_timeout_task
and not engine._user_response_timeout_task.done()
):
logger.debug("Cancelling user response timeout - user responded")
engine._user_response_timeout_task.cancel()
engine._user_response_timeout_task = None
# Transition to next node
next_node_id = engine._current_node.out_edges[0].target
logger.debug(
f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}"
)
# Set flag to not queue context frame since
# it will be pushed by user context aggregator
# we are just setting the context with next node's
# functions and prompts
engine._queue_context_frame = False
# Transition to next node
await engine.set_node(next_node_id)
return handle_user_stopped_speaking
# ---------------------------------------------------------------------------
# User-started-speaking handling
# ---------------------------------------------------------------------------
def create_user_started_speaking_callback(engine: "PipecatEngine"):
"""Return a callback that handles when the user starts speaking.
According to simplified flow:
- For start nodes with wait_for_user_response=True:
- Cancel the timeout timer if it exists (but don't set to None)
"""
async def handle_user_started_speaking():
# Only handle if current node is a start node with wait_for_user_response
if (
engine._current_node
and engine._current_node.is_start
and engine._current_node.wait_for_user_response
and engine._user_response_timeout_task
and not engine._user_response_timeout_task.done()
):
logger.debug(
"User started speaking during wait_for_user_response - cancelling timeout timer"
)
engine._user_response_timeout_task.cancel()
# Don't set to None here - let user_stopped_speaking handle the transition
return handle_user_started_speaking
def create_aggregation_correction_callback(engine: "PipecatEngine"):
"""Create a callback that uses engine's reference text to correct corrupted aggregation."""
def correct_corrupted_aggregation(ref: str, corrupted: str) -> str:
"""Correct corrupted text by aligning it with reference text.
This is a pure function that doesn't depend on engine instance.
"""
# 1) Safety check: if ref (minus spaces) is shorter than corrupted, bail out
# also if corrupted is less than 10 characters, lets also return that since most likely
# Elevenlabs returned the right alignment
alnum_corr = "".join(ch for ch in corrupted if ch.isalnum())
alnum_ref = "".join(ch for ch in ref if ch.isalnum())
if corrupted in ref or len(alnum_ref) < len(alnum_corr) or len(alnum_corr) < 10:
return corrupted
# 2) Find where in `ref` we should start aligning.
# We take the first N (N=10) characters of `corrupted`
# and look for all their occurrences in `ref`.
# We pick the *last* one
prefix = corrupted[:10]
# find all startindices of that prefix in ref
starts = [m.start() for m in re.finditer(re.escape(prefix), ref)]
start_idx = starts[-1] if starts else 0
# 3) Now run the same twopointer scan from start_idx
i, j = start_idx, 0
out_chars = []
while i < len(ref) and j < len(corrupted):
r_ch, c_ch = ref[i], corrupted[j]
if r_ch == c_ch:
out_chars.append(r_ch)
i += 1
j += 1
elif c_ch == " ":
# extra space in corrupted → skip it
j += 1
elif r_ch == " " or r_ch in ".,;:!?":
# missing structural char in corrupted → emit from ref
out_chars.append(r_ch)
i += 1
else:
# letter mismatch → besteffort copy from ref
out_chars.append(r_ch)
i += 1
j += 1
# 4) A final check - the final created output should be exactly
# as corrupted sentence sans whitespace.
alnum_out = "".join([ch for ch in out_chars if ch.isalnum()])
if alnum_out != alnum_corr:
return corrupted
# 5) Join and return exactly what we built
return "".join(out_chars)
def correct_aggregation(corrupted: str) -> str:
reference = engine._current_llm_reference_text
if not reference:
logger.warning("No reference text available for aggregation correction")
return corrupted
# Apply the correction algorithm
corrected = correct_corrupted_aggregation(reference, corrupted)
return corrected
return correct_aggregation

View file

@ -0,0 +1,90 @@
from __future__ import annotations
from typing import Any, Dict, List
from google.genai.types import (
Content,
Part,
)
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.services.google.llm import GoogleLLMContext
from pipecat.services.openai.llm import OpenAILLMContext
from api.utils.template_renderer import render_template
__all__ = [
"get_function_schema",
"update_llm_context",
"render_template",
]
def get_function_schema(
function_name: str,
description: str,
*,
properties: Dict[str, Any] | None = None,
required: List[str] | None = None,
) -> FunctionSchema:
"""Create a FunctionSchema definition that can later be transformed into
the provider-specific format (OpenAI, Gemini, etc.).
The helper keeps the public signature backward-compatible callers that
only pass ``function_name`` and ``description`` continue to work and will
define a parameter-less function.
"""
return FunctionSchema(
name=function_name,
description=description,
properties=properties or {},
required=required or [],
)
def update_llm_context(
context: OpenAILLMContext,
system_message: Dict[str, Any],
functions: List[FunctionSchema],
) -> None:
"""Update *context* with an up-to-date system message and tool list.
This helper removes any previous system messages before inserting the new
*system_message* at the top of the conversation history and then instructs
the LLM which *functions* (a.k.a. tools) are currently available.
"""
# Wrap the provided function schemas in a ToolsSchema so that the adapter
# associated with the current LLM service can convert them to the correct
# provider-specific representation when required.
tools_schema = ToolsSchema(standard_tools=functions)
if isinstance(context, GoogleLLMContext):
context.system_message = system_message["content"]
if functions:
# Lets only call set_tools if we have functions, else Gemini will
# throw an exception
context.set_tools(tools_schema)
if context.messages[-1].role != "user":
# Google expects the last message should end with user message
context.add_message(Content(role="user", parts=[Part(text="...")]))
return
# In case of OpenAILLMContext, replace the system message with incoming system message
previous_interactions = context.messages
# Filter out old system messages but keep user/assistant/function content.
messages: List[Dict[str, Any]] = [system_message]
messages.extend(
interaction
for interaction in previous_interactions
if interaction["role"] != "system"
)
context.set_messages(messages)
if functions:
context.set_tools(tools_schema)

View file

@ -0,0 +1,192 @@
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any, List
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import trace
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.services.workflow.dto import ExtractionVariableDTO
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
class VariableExtractionManager:
"""Helper that registers and executes the \"extract_variables\" tool.
The manager is responsible for two things:
1. Registering a callable with the LLM service so that the tool can be
invoked from within the model.
2. Executing the extraction in a background task while maintaining
correct bookkeeping and optional OpenTelemetry tracing.
"""
def __init__(self, engine: "PipecatEngine") -> None: # noqa: F821
# We keep a reference to the engine so we can reuse its context
# and update internal counters / extracted variable state.
self._engine = engine
self._context = engine.context
self._model = "gpt-4o"
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _perform_extraction(
self,
extraction_variables: List[ExtractionVariableDTO],
parent_ctx: Any,
extraction_prompt: str = "",
) -> dict:
"""Run the actual extraction chat completion and post-process the result."""
# ------------------------------------------------------------------
# Build the prompt that instructs the model to extract the variables.
# ------------------------------------------------------------------
vars_description = "\n".join(
f"- {v.name} ({v.type}): {v.prompt}" for v in extraction_variables
)
# ------------------------------------------------------------------
# Build a normalized representation of the existing conversation so the
# extractor works with both OpenAI-style (dict) messages and Google
# Gemini `Content` objects.
# ------------------------------------------------------------------
def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
"""Return a pair of (role, content) for the given message.
The logic supports both OpenAI-style dict messages and Google
`Content` objects that expose ``role`` and ``parts`` attributes.
Only plain textual content is extracted image parts, tool call
placeholders, etc. are ignored for the purpose of variable
extraction.
"""
# --------------------------------------------------------------
# OpenAI format → simple dict with ``role`` and ``content`` keys
# --------------------------------------------------------------
if isinstance(msg, dict):
role = msg.get("role")
content_field = msg.get("content")
# Content can be a str, list of segments, or None.
if isinstance(content_field, str):
content = content_field
elif isinstance(content_field, list):
# Collapse all text parts into a single string.
texts = [
segment.get("text", "")
for segment in content_field
if isinstance(segment, dict) and segment.get("type") == "text"
]
content = " ".join(texts) if texts else None
else:
content = None
return role, content
# --------------------------------------------------------------
# Google Gemini format → ``Content`` object with ``parts`` list
# --------------------------------------------------------------
role_attr = getattr(msg, "role", None)
parts_attr = getattr(msg, "parts", None)
if role_attr is None or parts_attr is None:
return None, None # Unrecognised message format
role = (
"assistant" if role_attr == "model" else role_attr
) # Normalise role name
# Collect textual parts only (ignore images, function calls, etc.)
texts: list[str] = []
for part in parts_attr:
text_val = getattr(part, "text", None)
if text_val:
texts.append(text_val)
content = " ".join(texts) if texts else None
return role, content
conversation_lines: list[str] = []
for msg in self._context.messages:
role, content = _get_role_and_content(msg)
if role in ("assistant", "user") and content:
conversation_lines.append(f"{role}: {content}")
conversation_history = "\n".join(conversation_lines)
system_prompt = (
"You are an assistant tasked with extracting structured data from the conversation. "
"Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown." # noqa: E501
)
# Use provided extraction_prompt as system prompt, or default
system_prompt = (
system_prompt + "\n\n" + extraction_prompt
if extraction_prompt
else system_prompt
)
user_prompt = (
"\n\nVariables to extract:\n"
f"{vars_description}"
"\n\nConversation history:\n"
f"{conversation_history}"
)
extraction_context = OpenAILLMContext()
extraction_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
extraction_context.set_messages(extraction_messages)
# ------------------------------------------------------------------
# Use independent OpenAI client for LLM call
# ------------------------------------------------------------------
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Direct API call - no pipeline involvement
response = await client.chat.completions.create(
model=self._model,
messages=extraction_messages,
temperature=0.0,
response_format={"type": "json_object"},
)
llm_response = response.choices[0].message.content
if is_tracing_enabled():
tracer = trace.get_tracer("pipecat")
with tracer.start_as_current_span(
"variable_extraction", context=parent_ctx
) as span:
add_llm_span_attributes(
span,
service_name="OpenAILLMService",
model=self._model,
operation_name="variable_extraction",
messages=json.dumps(extraction_messages),
output=llm_response,
stream=False,
parameters={"temperature": 0.0, "response_format": "json_object"},
)
# ------------------------------------------------------------------
# Parse the assistant output fall back to raw text if it is not valid JSON.
# ------------------------------------------------------------------
try:
extracted = json.loads(llm_response)
except json.JSONDecodeError:
logger.warning(
"Extractor returned invalid JSON; storing raw content instead."
)
extracted = {"raw": llm_response}
logger.debug(f"Extracted variables: {extracted}")
return extracted

View file

@ -0,0 +1,448 @@
from __future__ import annotations
import asyncio
import io
import json
import os
import tempfile
import wave
from typing import TYPE_CHECKING, Optional
from langfuse import get_client
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import context as otel_context
from pipecat.utils.enums import EndTaskReason
from pipecat.utils.tracing.context_registry import get_current_turn_context
from api.db import db_client
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
DEFAULT_VOICEMAIL_PROMPT = """
You are analyzing the beginning of a phone call to determine if it's a voicemail greeting.
Common voicemail indicators:
- "You've reached the voicemail of..."
- "Please leave a message after the beep"
- "I'm not available right now"
- "Press 1 to leave a message"
- Robotic or pre-recorded voice quality mentioned
- Background music or hold music references
Transcript: {transcript}
Respond with a JSON object:
{
"is_voicemail": true/false,
"confidence": 0.0-1.0,
"reasoning": "Brief explanation"
}
"""
class VoicemailDetector:
"""
Autonomous voicemail detection system that operates independently of the main pipeline.
"""
def __init__(self, detection_duration: float = 15.0, workflow_run_id: int = None):
self.detection_duration = detection_duration
self.audio_buffer = bytearray()
self.is_detecting = False
self.workflow_run_id = workflow_run_id
self._langfuse_client = get_client()
# We will set the sample rate when we receive the audio packet
self._sample_rate = None
# Task management
self._detection_task: Optional[asyncio.Task] = None
self._is_cancelled = False
self._engine: Optional[PipecatEngine] = None
# Event for audio collection completion
self._audio_collected_event = asyncio.Event()
# ------------------------------------------------------------------
# Utility helpers
# ------------------------------------------------------------------
def _current_duration_seconds(self) -> float:
"""Return the duration (in seconds) of the audio currently in the buffer."""
if self._sample_rate:
return len(self.audio_buffer) / (self._sample_rate * 2)
return 0.0
async def handle_audio_data(
self, processor, pcm: bytes, sample_rate: int, num_channels: int
):
"""Handle incoming audio data without affecting pipeline."""
if not self.is_detecting or self._is_cancelled:
return
# Store the actual sample rate from the first audio packet
if self._sample_rate is None:
self._sample_rate = sample_rate
logger.debug(f"Voicemail detector using sample rate: {sample_rate}")
# Add to buffer without resampling
self.audio_buffer.extend(pcm)
# Check if we've collected enough audio
current_duration = self._current_duration_seconds()
if current_duration >= self.detection_duration:
self._audio_collected_event.set()
async def start_detection(self, engine: PipecatEngine):
"""Start voicemail detection process."""
logger.info("Starting voicemail detection")
self.is_detecting = True
self._is_cancelled = False
self._engine = engine
self._audio_collected_event.clear()
# Start detection in background
self._detection_task = asyncio.create_task(self._run_detection_with_timeout())
async def stop_detection(self):
"""Stop detection immediately (called on disconnect)."""
logger.info("Stopping voicemail detection due to disconnect")
self._is_cancelled = True
self.is_detecting = False
# Set the event to unblock any waiting tasks
self._audio_collected_event.set()
# Cancel ongoing detection task
if self._detection_task and not self._detection_task.done():
self._detection_task.cancel()
# Clear audio buffer
self.audio_buffer.clear()
# Wait for tasks to complete cancellation
if self._detection_task:
try:
await self._detection_task
except asyncio.CancelledError:
pass
async def _run_detection_with_timeout(self):
"""Run detection with proper timeout and cancellation handling."""
try:
# Wait for audio collection or cancellation directly
await self._wait_for_audio_collection()
# Check if cancelled during collection
if self._is_cancelled:
logger.info("Detection cancelled during audio collection")
return
# Process detection
await self._process_detection()
except asyncio.CancelledError:
logger.info("Voicemail detection task cancelled")
except Exception as e:
logger.error(f"Error in voicemail detection: {e}")
finally:
self.is_detecting = False
async def _wait_for_audio_collection(self):
"""Wait for audio buffer to fill or timeout."""
try:
# Wait for either audio collection completion or timeout
await asyncio.wait_for(
self._audio_collected_event.wait(),
timeout=self.detection_duration + 2.0,
)
if not self._is_cancelled:
current_duration = self._current_duration_seconds()
logger.info(
f"Collected {current_duration:.1f}s of audio for voicemail detection (sample rate: {self._sample_rate}Hz)"
)
except asyncio.TimeoutError:
if not self._is_cancelled:
current_duration = self._current_duration_seconds()
logger.warning("Audio collection timeout exceeded")
logger.info(
f"Proceeding with {current_duration:.1f}s of audio (sample rate: {self._sample_rate}Hz)"
)
async def _process_detection(self):
"""Process the collected audio to detect voicemail."""
if not self.audio_buffer or not self._engine:
logger.warning("No audio buffer or engine available for detection")
return
try:
# Convert PCM to WAV once for both transcription and storage
wav_data = self._create_wav_from_pcm(bytes(self.audio_buffer))
# Transcribe audio
logger.info("Transcribing audio for voicemail detection")
transcript = await self._transcribe_audio(wav_data)
if not transcript:
logger.warning("No transcript obtained from audio")
# Still upload the raw recording so data pipeline has it
if self.workflow_run_id:
await self._save_voicemail_audio(wav_data, 0.0, False)
return
logger.info(
f"Voicemail detection transcript obtained: {transcript[:100]}..."
)
# Analyze transcript
result = await self._analyze_transcript(transcript)
# Extract common fields
confidence = result.get("confidence", 0.0)
reasoning = result.get("reasoning", "No reasoning provided")
# Save voicemail audio to S3 once for data pipeline (include duration in filename)
s3_path = None
if self.workflow_run_id:
s3_path = await self._save_voicemail_audio(
wav_data, confidence, result.get("is_voicemail")
)
# Take action based on result
if result.get("is_voicemail", False):
logger.info(
f"Voicemail detected with confidence {confidence}: {reasoning}"
)
# Update workflow run with voicemail tags
if self.workflow_run_id:
# Fetch the workflow run from database
workflow_run = await db_client.get_workflow_run_by_id(
self.workflow_run_id
)
if workflow_run:
call_tags = workflow_run.gathered_context.get("call_tags", [])
call_tags.extend(["voicemail_detected", "not_connected"])
await db_client.update_workflow_run(
run_id=workflow_run.id,
gathered_context={
"call_tags": call_tags,
"voicemail_transcript": transcript,
"voicemail_confidence": confidence,
},
)
# Send end task frame with metadata (including optional S3 path)
await self._engine.send_end_task_frame(
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
additional_metadata={
"voicemail_transcript": transcript,
"voicemail_confidence": confidence,
"voicemail_reasoning": reasoning,
"voicemail_detection_duration": self.detection_duration,
"voicemail_audio_s3_path": s3_path,
},
abort_immediately=True,
)
else:
logger.info("No voicemail detected, continuing normal conversation")
except Exception as e:
logger.error(f"Error processing voicemail detection: {e}")
async def _transcribe_audio(self, wav_data: bytes) -> str:
"""Transcribe audio using OpenAI API directly.
Args:
wav_data: WAV formatted audio data
"""
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Direct API call - no pipeline involvement
response = await client.audio.transcriptions.create(
file=("audio.wav", wav_data, "audio/wav"),
model="whisper-1", # Using whisper-1 as it's more stable for transcription
language="en",
temperature=0.0,
)
return response.text.strip()
def _create_wav_from_pcm(self, pcm_data: bytes) -> bytes:
"""Convert raw PCM data to WAV format."""
wav_buffer = io.BytesIO()
with wave.open(wav_buffer, "wb") as wav_file:
wav_file.setnchannels(1) # Mono
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(self._sample_rate)
wav_file.writeframes(pcm_data)
wav_buffer.seek(0)
return wav_buffer.read()
async def _analyze_transcript(self, transcript: str) -> dict:
"""Analyze transcript using independent OpenAI client."""
# Capture the current turn context for proper span nesting
parent_context = get_current_turn_context()
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
langfuse_prompt = None
try:
langfuse_prompt = self._langfuse_client.get_prompt(
"production/voicemail_detection"
)
prompt = langfuse_prompt.compile(transcript=transcript)
except Exception as e:
logger.warning(f"Error getting prompt from Langfuse: {e}")
prompt = DEFAULT_VOICEMAIL_PROMPT.replace("{transcript}", transcript)
messages = [
{
"role": "system",
"content": prompt,
}
]
# When we have a parent OpenTelemetry context, we need to activate it
# so that Langfuse's OTEL tracer will automatically pick it up
if parent_context and is_tracing_enabled():
# Activate the parent context for this scope
token = otel_context.attach(parent_context)
try:
# Start Langfuse generation - it will automatically use the active OTEL context
langfuse_generation = None
try:
langfuse_generation = self._langfuse_client.start_generation(
name="voicemail_detection",
model="gpt-4o",
input=messages,
metadata={
"temperature": 0.0,
"detection_duration": self.detection_duration,
"transcript_length": len(transcript),
},
prompt=langfuse_prompt,
)
except Exception as e:
logger.warning(f"Error starting Langfuse generation: {e}")
# Direct API call
response = await client.chat.completions.create(
model="gpt-4o",
messages=messages,
temperature=0.0,
response_format={"type": "json_object"},
)
llm_response = response.choices[0].message.content
# Update and end Langfuse generation
if langfuse_generation:
try:
langfuse_generation.update(
output=llm_response,
usage_details={
"prompt_tokens": response.usage.prompt_tokens
if response.usage
else 0,
"completion_tokens": response.usage.completion_tokens
if response.usage
else 0,
"total_tokens": response.usage.total_tokens
if response.usage
else 0,
},
)
langfuse_generation.end()
except Exception as e:
logger.warning(f"Error updating Langfuse generation: {e}")
finally:
# Detach the context
otel_context.detach(token)
else:
# No parent context or tracing disabled - just make the API call
response = await client.chat.completions.create(
model="gpt-4o",
messages=messages,
temperature=0.0,
response_format={"type": "json_object"},
)
llm_response = response.choices[0].message.content
# Parse response
try:
return json.loads(llm_response)
except json.JSONDecodeError:
logger.warning("Invalid JSON response from voicemail detection")
return {
"is_voicemail": False,
"confidence": 0.0,
"reasoning": "Invalid response",
}
async def _save_voicemail_audio(
self, wav_data: bytes, confidence: float, is_voicemail: bool
) -> Optional[str]:
"""Save voicemail audio to temp file and enqueue task to upload to S3.
Args:
wav_data: WAV formatted audio data
confidence: Detection confidence score
is_voicemail: Whether it was detected as voicemail
Returns:
The expected S3 object key (bucket path). The actual upload happens asynchronously.
"""
try:
# Create filename with prediction, confidence and duration
duration_seconds = self._current_duration_seconds()
prediction = "voicemail" if is_voicemail else "not_voicemail"
confidence_int = int(confidence * 100)
duration_int = int(duration_seconds)
s3_key = f"voicemail_detections/{self.workflow_run_id}_{prediction}_{confidence_int}_{duration_int}.wav"
# Write WAV data to temp file - DO NOT delete it here, the async task will handle cleanup
with tempfile.NamedTemporaryFile(
suffix=".wav",
delete=False, # Important: don't delete immediately
prefix=f"voicemail_{self.workflow_run_id}_",
) as tmp_file:
tmp_file.write(wav_data)
tmp_file.flush()
temp_file_path = tmp_file.name
logger.info(f"Saved voicemail audio to temp file: {temp_file_path}")
# Enqueue async task to upload to S3
await enqueue_job(
FunctionNames.UPLOAD_VOICEMAIL_AUDIO_TO_S3,
self.workflow_run_id,
temp_file_path,
s3_key,
)
logger.info(f"Enqueued voicemail audio upload task for: {s3_key}")
return s3_key
except Exception as e:
logger.error(f"Failed to save voicemail audio: {e}")
# Clean up temp file if task enqueue failed
if "temp_file_path" in locals() and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
except Exception as cleanup_error:
logger.warning(
f"Failed to cleanup temp file after error: {cleanup_error}"
)
return None

View file

View file

@ -0,0 +1,164 @@
{
"nodes": [
{
"id": "915",
"type": "agentNode",
"position": {
"x": 633,
"y": 324
},
"data": {
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "7598",
"type": "agentNode",
"position": {
"x": 460.1247806640531,
"y": 610.3714977079578
},
"data": {
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6919",
"type": "agentNode",
"position": {
"x": 914.666735413607,
"y": 642.9800281289787
},
"data": {
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6581",
"type": "startCall",
"position": {
"x": 648,
"y": 35
},
"data": {
"prompt": "Hello, I am Abhishek from Dograh. ",
"is_static": true,
"name": "Start Call",
"is_start": true
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "1802",
"type": "endCall",
"position": {
"x": 666.7733431033548,
"y": 987.4345801025363
},
"data": {
"prompt": "Thank you for calling Dograh. Have a great day!",
"is_static": true,
"name": "End Call"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
}
],
"edges": [
{
"animated": true,
"type": "custom",
"source": "915",
"target": "7598",
"id": "xy-edge__915-7598",
"selected": false,
"data": {
"condition": "The customer wants to talk to a customer service agent",
"label": "customer service agent"
}
},
{
"animated": true,
"type": "custom",
"source": "915",
"target": "6919",
"id": "xy-edge__915-6919",
"selected": false,
"data": {
"condition": "customer wants to talk to a sales representative",
"label": "sales representative"
}
},
{
"animated": true,
"type": "custom",
"source": "6581",
"target": "915",
"id": "xy-edge__6581-915",
"selected": false,
"data": {
"condition": "Always take this route",
"label": "Always take this route"
}
},
{
"animated": true,
"type": "custom",
"source": "7598",
"target": "1802",
"id": "xy-edge__7598-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
},
{
"animated": true,
"type": "custom",
"source": "6919",
"target": "1802",
"id": "xy-edge__6919-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
}
],
"viewport": {
"x": 0,
"y": 0,
"zoom": 1
}
}

View file

@ -0,0 +1,192 @@
from unittest.mock import Mock
from api.services.workflow.pipecat_engine_callbacks import (
create_aggregation_correction_callback,
)
def test_aggregation_fixer():
"""Validate the aggregation correction algorithm using a helper that
creates a fresh callback for every (reference, corrupted) pair.
The production callback now needs a PipecatEngine instance with the
`_current_llm_reference_text` set. For test-friendliness we mock a bare
object providing just that attribute for each assertion so the original
two-argument test cases remain unchanged.
"""
def fixer(reference: str, corrupted: str) -> str: # noqa: D401
mock_engine = Mock()
mock_engine._current_llm_reference_text = reference
return create_aggregation_correction_callback(mock_engine)(corrupted)
##### Trailing extra Chars #####
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "leading_whole_sentence"
# Whole sentences
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "whole_sentences"
# With a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services.",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services."
), "period_end"
# without a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "without_period_end"
# Extra space in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "extra_space"
# Multiple spaces in corruption
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "multiple_space"
# Multiple spaces in corruption ending in a whitespace
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces. ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. "
), "multiple_space_end_ws"
##### Leading extra Chars #####
# Whole sentences
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "leading_whole_sentence"
# With a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services.",
)
== "My name is Alex and I am calling you from Consumer Services."
), "leading_period_end"
# without a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_without_period_end"
# Extra space in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services ",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_extra_space"
# Multiple spaces in corruption
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Servi ces ",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_multiple_space"
# Multiple spaces in corruption ending in a whitespace
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Servi ces. ",
)
== "My name is Alex and I am calling you from Consumer Services. "
), "leading_multiple_space_end_ws"
# Whitespace
assert fixer("", "") == ""
# Missing reference
assert (
fixer("", "My name is Alex and I am calling you from Cons umer Servi ces.")
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "missing_reference"
# Smaller reference
assert (
fixer(
"My name is Alex",
"My name is Alex and I am calling you from Cons umer Servi ces.",
)
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "smaller_reference"
# Unrelated reference
assert (
fixer(
"Hello Hello",
"My name is Alex and I am calling you from Cons umer Servi ces.",
)
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "unrelated_reference"
def test_create_aggregation_correction_callback():
"""Test the new aggregation correction callback creator."""
# Mock engine with reference text
mock_engine = Mock()
mock_engine._current_llm_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
# Create callback
callback = create_aggregation_correction_callback(mock_engine)
# Test correction
corrected = callback(
"Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."
)
assert (
corrected
== "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
)
# Test with no reference text
mock_engine._current_llm_reference_text = ""
corrected = callback("Some corrupted text")
assert corrected == "Some corrupted text" # Should return as-is when no reference

View file

@ -0,0 +1,128 @@
from unittest.mock import Mock
import pytest
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_callbacks import (
create_generation_started_callback,
)
class TestAggregationIntegration:
"""Integration tests for the TTS aggregation correction flow."""
@pytest.mark.asyncio
async def test_engine_reference_text_tracking(self):
"""Test that the engine properly tracks LLM reference text."""
# Create mock dependencies
mock_task = Mock()
mock_llm = Mock()
mock_context = Mock(spec=OpenAILLMContext)
mock_tts = Mock()
mock_workflow = Mock()
mock_workflow.start_node_id = "start"
mock_workflow.nodes = {
"start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
}
# Create engine
engine = PipecatEngine(
task=mock_task,
llm=mock_llm,
context=mock_context,
tts=mock_tts,
workflow=mock_workflow,
call_context_vars={},
workflow_run_id=1,
)
# Test initial state
assert engine._current_llm_reference_text == ""
# Test accumulating LLM text
await engine.handle_llm_text_frame("Hello ")
assert engine._current_llm_reference_text == "Hello "
await engine.handle_llm_text_frame("world!")
assert engine._current_llm_reference_text == "Hello world!"
# Test generation started callback clears reference text
callback = create_generation_started_callback(engine)
await callback()
assert engine._current_llm_reference_text == ""
@pytest.mark.asyncio
async def test_aggregation_correction_callback_creation(self):
"""Test creating the aggregation correction callback."""
# Create mock engine
mock_task = Mock()
mock_llm = Mock()
mock_context = Mock(spec=OpenAILLMContext)
mock_workflow = Mock()
engine = PipecatEngine(
task=mock_task,
llm=mock_llm,
context=mock_context,
workflow=mock_workflow,
call_context_vars={},
workflow_run_id=1,
)
# Set reference text
engine._current_llm_reference_text = "Hello, world! How are you?"
# Create correction callback
callback = engine.create_aggregation_correction_callback()
# Test correction - note that trailing punctuation might be stripped if not in corrupted text
corrected = callback("Hello world How are you")
assert corrected == "Hello, world! How are you"
def test_llm_assistant_aggregator_params_with_callback(self):
"""Test that LLMAssistantAggregatorParams accepts correction callback."""
def mock_callback(text: str) -> str:
return text.upper()
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=mock_callback
)
assert params.expect_stripped_words is True
assert params.correct_aggregation_callback is not None
assert params.correct_aggregation_callback("hello") == "HELLO"
@pytest.mark.asyncio
async def test_pipeline_callbacks_processor_llm_text_frame(self):
"""Test that PipelineEngineCallbacksProcessor handles LLMTextFrame."""
from pipecat.frames.frames import LLMTextFrame
from pipecat.processors.frame_processor import FrameDirection
from api.services.pipecat.pipeline_engine_callbacks_processor import (
PipelineEngineCallbacksProcessor,
)
# Track callback invocations
callback_invoked = False
callback_text = None
async def mock_llm_text_callback(text: str):
nonlocal callback_invoked, callback_text
callback_invoked = True
callback_text = text
# Create processor with callback
processor = PipelineEngineCallbacksProcessor(
llm_text_frame_callback=mock_llm_text_callback
)
# Process LLMTextFrame
frame = LLMTextFrame(text="Hello world")
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
# Verify callback was invoked
assert callback_invoked is True
assert callback_text == "Hello world"

View file

@ -0,0 +1,31 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
"""Test function to verify cost calculation works"""
sample_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 45380,
"completion_tokens": 496,
"total_tokens": 45876,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
"call_duration_seconds": 179,
}
result = cost_calculator.calculate_total_cost(sample_usage)
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
assert (
abs(
result["total"]
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
)
< 1e-10
)

View file

@ -0,0 +1,11 @@
import pytest
from api.services.workflow.dto import ReactFlowDTO
@pytest.mark.asyncio
async def test_dto():
# assert no exceptions are raised
with open("services/workflow/test/definitions/rf-1.json", "r") as f:
dto = ReactFlowDTO.model_validate_json(f.read())
assert dto is not None

View file

@ -0,0 +1,159 @@
from unittest.mock import AsyncMock, Mock
import pytest
from pipecat.frames.frames import StartInterruptionFrame
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAILLMContext,
)
class TestInterruptionCorrection:
"""Test that TTS aggregation correction works during interruptions."""
@pytest.mark.asyncio
async def test_openai_interruption_with_correction(self):
"""Test OpenAI assistant context aggregator applies correction during interruption."""
# Create mock context
mock_context = Mock(spec=OpenAILLMContext)
mock_context.get_messages.return_value = []
mock_context.add_message = Mock()
# Create correction callback
def correction_callback(text: str) -> str:
# Simulate fixing corrupted text
if text == "Hello world how are you":
return "Hello world, how are you"
return text
# Create aggregator with correction callback
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=correction_callback
)
aggregator = OpenAIAssistantContextAggregator(
context=mock_context, params=params
)
# Set up aggregation state
aggregator._aggregation = "Hello world how are you"
aggregator._current_llm_response_id = "test-id"
aggregator._response_function_messages = {}
aggregator._function_calls_in_progress = {}
aggregator._started = 1
# Mock push_context_frame and reset methods
aggregator.push_context_frame = AsyncMock()
aggregator.reset = AsyncMock()
# Process interruption
interruption_frame = StartInterruptionFrame()
await aggregator._handle_interruptions(interruption_frame)
# Verify the corrected text was added to context
mock_context.add_message.assert_called_once()
added_message = mock_context.add_message.call_args[0][0]
assert added_message["role"] == "assistant"
assert (
added_message["content"]
== "Hello world, how are you <<interrupted_by_user>>"
)
@pytest.mark.asyncio
async def test_google_interruption_with_correction(self):
"""Test Google assistant context aggregator applies correction during interruption."""
from pipecat.services.google.llm import (
Content,
GoogleAssistantContextAggregator,
)
# Create mock context
mock_context = Mock(spec=OpenAILLMContext)
mock_context.get_messages.return_value = []
mock_context.add_message = Mock()
# Create correction callback
def correction_callback(text: str) -> str:
# Simulate fixing corrupted text
if text == "I am here to help":
return "I am here to help"
return text
# Create aggregator with correction callback
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=correction_callback
)
aggregator = GoogleAssistantContextAggregator(
context=mock_context, params=params
)
# Set up aggregation state
aggregator._aggregation = "I am here to help"
aggregator._current_llm_response_id = "test-id"
aggregator._response_function_messages = {}
aggregator._function_calls_in_progress = {}
aggregator._started = 1
# Mock push_context_frame and reset methods
aggregator.push_context_frame = AsyncMock()
aggregator.reset = AsyncMock()
# Process interruption
interruption_frame = StartInterruptionFrame()
await aggregator._handle_interruptions(interruption_frame)
# Verify the corrected text was added to context
mock_context.add_message.assert_called_once()
added_content = mock_context.add_message.call_args[0][0]
# Google uses Content objects
assert isinstance(added_content, Content)
assert added_content.role == "model"
assert len(added_content.parts) == 1
assert (
added_content.parts[0].text == "I am here to help <<interrupted_by_user>>"
)
@pytest.mark.asyncio
async def test_interruption_correction_error_handling(self):
"""Test that interruption handling continues even if correction callback fails."""
# Create mock context
mock_context = Mock(spec=OpenAILLMContext)
mock_context.get_messages.return_value = []
mock_context.add_message = Mock()
# Create correction callback that raises error
def failing_callback(text: str) -> str:
raise ValueError("Correction failed")
# Create aggregator with failing callback
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=failing_callback
)
aggregator = OpenAIAssistantContextAggregator(
context=mock_context, params=params
)
# Set up aggregation state
aggregator._aggregation = "Some text"
aggregator._current_llm_response_id = "test-id"
aggregator._response_function_messages = {}
aggregator._function_calls_in_progress = {}
aggregator._started = 1
# Mock push_context_frame and reset methods
aggregator.push_context_frame = AsyncMock()
aggregator.reset = AsyncMock()
# Process interruption - should not raise
interruption_frame = StartInterruptionFrame()
await aggregator._handle_interruptions(interruption_frame)
# Verify the original text was still added (fallback behavior)
mock_context.add_message.assert_called_once()
added_message = mock_context.add_message.call_args[0][0]
assert added_message["role"] == "assistant"
assert added_message["content"] == "Some text <<interrupted_by_user>>"

Some files were not shown because too many files have changed in this diff Show more