mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
0
api/services/auth/__init__.py
Normal file
0
api/services/auth/__init__.py
Normal file
330
api/services/auth/depends.py
Normal file
330
api/services/auth/depends.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import Header, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import (
|
||||
DograhLLMModel,
|
||||
DograhSTTModel,
|
||||
DograhTTSModel,
|
||||
DograhVoice,
|
||||
ServiceProviders,
|
||||
)
|
||||
|
||||
|
||||
async def get_user(
    authorization: Annotated[str | None, Header()] = None,
) -> UserModel:
    """Resolve the ``Authorization`` header to a local ``UserModel``.

    In OSS deployment mode the request is delegated to ``_handle_oss_auth``.
    Otherwise the header is validated against Stack Auth, the local user row
    is fetched or created, and the user's selected Stack team is mirrored as
    a local organization. When that organization was just created and the
    user has no LLM/TTS/STT configuration yet, a default configuration is
    requested via ``create_user_configuration_with_mps_key`` and stored.

    Args:
        authorization: Raw ``Authorization`` header value, if any.

    Returns:
        The local user, with ``selected_organization_id`` set to the
        organization that mirrors the selected Stack team.

    Raises:
        HTTPException: 401 when Stack Auth returns no user for the token,
            400 when the Stack user has no selected team, 500 when the user
            or the organization mapping cannot be persisted.
    """
    # OSS deployments bypass Stack Auth entirely.
    if DEPLOYMENT_MODE == "oss":
        return await _handle_oss_auth(authorization)

    # 1. Validate and fetch the authenticated Stack user.
    stack_user = await stackauth.get_user(authorization)
    if stack_user is None:
        raise HTTPException(status_code=401, detail="Unauthorized")

    # 2. Ensure the user has a team (Stack "selected_team_id"), falling back
    #    to the nested "selected_team" object when the flat key is absent.
    selected_team_id: str | None = stack_user.get("selected_team_id")
    if not selected_team_id and stack_user.get("selected_team"):
        selected_team_id = stack_user["selected_team"].get("id")

    if not selected_team_id:
        raise HTTPException(status_code=400, detail="No team selected")

    # 3. Persist/fetch the local user model.
    try:
        user_model = await db_client.get_or_create_user_by_provider_id(stack_user["id"])
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500, detail=f"Error while creating user from database {e}"
        ) from e

    # 4. Persist the organization (team) and the user <-> organization mapping.
    try:
        (
            organization,
            org_was_created,
        ) = await db_client.get_or_create_organization_by_provider_id(
            org_provider_id=selected_team_id, user_id=user_model.id
        )

        # The user switched teams (or is new): record membership and selection.
        if user_model.selected_organization_id != organization.id:
            await db_client.add_user_to_organization(user_model.id, organization.id)
            await db_client.update_user_selected_organization(
                user_model.id, organization.id
            )
            # Keep the in-memory object in step with the database.
            user_model.selected_organization_id = organization.id

        # Only create default configuration if the organization was just
        # created. This prevents race conditions where multiple concurrent
        # requests might try to create configurations.
        if org_was_created:
            existing_cfg = await db_client.get_user_configurations(user_model.id)
            if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
                mps_config = await create_user_configuration_with_mps_key(
                    user_model.id, organization.id, stack_user["id"]
                )
                if mps_config:
                    await db_client.update_user_configuration(
                        user_model.id, mps_config
                    )
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to map user to organization: {exc}",
        ) from exc

    return user_model
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
    """Handle authentication for OSS deployment mode.

    The bearer token itself is used as the user's provider id, and
    ``org_<token>`` as the provider id of the user's personal organization;
    both rows are created on first use.

    NOTE(review): this module defines ``_handle_oss_auth`` a second time
    further down. Python binds the later definition, so this copy is never
    called and is a candidate for deletion. It is kept in line with the
    later definition so the file stays coherent until it is removed.

    Args:
        authorization: Raw ``Authorization`` header value, if any.

    Returns:
        The local user, with ``selected_organization_id`` pointing at the
        user's organization.

    Raises:
        HTTPException: 401 when the header is missing or holds no token,
            500 when user or organization persistence fails.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header required")

    # Strip only a leading "Bearer " prefix; str.replace would also remove
    # any later occurrence inside the token.
    token = authorization.removeprefix("Bearer ")

    if not token:
        raise HTTPException(status_code=401, detail="Invalid authorization token")

    try:
        # Use token as provider_id for OSS mode.
        user_model = await db_client.get_or_create_user_by_provider_id(
            provider_id=token
        )

        # Each OSS user gets their own organization using their token as org
        # ID. The call returns (organization, created_flag), as at the other
        # call sites in this module.
        (
            organization,
            org_was_created,
        ) = await db_client.get_or_create_organization_by_provider_id(
            org_provider_id=f"org_{token}", user_id=user_model.id
        )

        # Ensure user is mapped to their organization.
        if user_model.selected_organization_id != organization.id:
            await db_client.add_user_to_organization(user_model.id, organization.id)
            await db_client.update_user_selected_organization(
                user_model.id, organization.id
            )
            user_model.selected_organization_id = organization.id

        # Only create default configuration if the organization was just
        # created, so concurrent requests do not each try to create one.
        if org_was_created:
            existing_cfg = await db_client.get_user_configurations(user_model.id)
            if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
                mps_config = await create_user_configuration_with_mps_key(
                    user_model.id, organization.id, token
                )
                if mps_config:
                    await db_client.update_user_configuration(
                        user_model.id, mps_config
                    )

        return user_model

    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error while handling OSS authentication: {e}"
        ) from e
|
||||
|
||||
|
||||
async def get_user_optional(
    authorization: Annotated[str | None, Header()] = None,
) -> UserModel | None:
    """Variant of ``get_user`` for endpoints that also serve anonymous callers.

    A 401 from ``get_user`` is translated into ``None``; every other
    ``HTTPException`` (and any other error) propagates unchanged.
    """
    try:
        resolved_user = await get_user(authorization)
    except HTTPException as auth_error:
        # Anything other than "unauthorized" is a real failure.
        if auth_error.status_code != 401:
            raise
        return None
    return resolved_user
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
    """Handle authentication for OSS deployment mode.

    The bearer token itself is used as the user's provider id, and
    ``org_<token>`` as the provider id of the user's personal organization;
    both rows are created on first use. When the organization was just
    created and the user has no LLM/TTS/STT configuration, a default one is
    requested via ``create_user_configuration_with_mps_key`` and stored.

    Args:
        authorization: Raw ``Authorization`` header value, if any.

    Returns:
        The local user, with ``selected_organization_id`` pointing at the
        user's organization.

    Raises:
        HTTPException: 401 when the header is missing or holds no token,
            500 when user, organization or configuration persistence fails.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header required")

    # Strip only a leading "Bearer " prefix; str.replace would also remove
    # any later occurrence inside the token.
    token = authorization.removeprefix("Bearer ")

    if not token:
        raise HTTPException(status_code=401, detail="Invalid authorization token")

    try:
        # Use token as provider_id for OSS mode.
        user_model = await db_client.get_or_create_user_by_provider_id(
            provider_id=token
        )

        # Create or get organization for OSS user.
        # Each OSS user gets their own organization using their token as org ID.
        (
            organization,
            org_was_created,
        ) = await db_client.get_or_create_organization_by_provider_id(
            org_provider_id=f"org_{token}", user_id=user_model.id
        )

        # Ensure user is mapped to their organization.
        if user_model.selected_organization_id != organization.id:
            # add_user_to_organization now handles race conditions with
            # ON CONFLICT DO NOTHING.
            await db_client.add_user_to_organization(user_model.id, organization.id)
            await db_client.update_user_selected_organization(
                user_model.id, organization.id
            )
            user_model.selected_organization_id = organization.id

        # Only create default configuration if organization was just created.
        # This prevents race conditions where multiple concurrent requests
        # might try to create configurations.
        if org_was_created:
            existing_cfg = await db_client.get_user_configurations(user_model.id)
            if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
                mps_config = await create_user_configuration_with_mps_key(
                    user_model.id, organization.id, token
                )
                if mps_config:
                    await db_client.update_user_configuration(
                        user_model.id, mps_config
                    )

        return user_model

    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500, detail=f"Error while handling OSS authentication: {e}"
        ) from e
|
||||
|
||||
|
||||
async def create_user_configuration_with_mps_key(
    user_id: int, organization_id: int, user_provider_id: str
) -> Optional[UserConfiguration]:
    """Create a default user configuration backed by an MPS service key.

    Requests a new service key from ``{MPS_API_URL}/api/v1/service-keys/``
    and, on success, returns a ``UserConfiguration`` whose LLM, TTS and STT
    sections all use the Dograh provider with that key and the registry's
    default models/voice.

    Args:
        user_id: The user's ID (currently not sent to MPS).
        organization_id: The organization's ID; sent to MPS only outside
            OSS mode.
        user_provider_id: The user's provider ID (for the ``created_by``
            field).

    Returns:
        The populated ``UserConfiguration``, or ``None`` when the request
        fails, MPS answers with a non-200 status, or the response carries no
        ``service_key``.

    Raises:
        ValueError: Outside OSS mode when ``DOGRAH_MPS_SECRET_KEY`` is not
            set.
    """
    endpoint = f"{MPS_API_URL}/api/v1/service-keys/"

    if DEPLOYMENT_MODE == "oss":
        # For OSS mode, create a temporary service key without authentication.
        payload = {
            "name": "Default Dograh Model Service Key",
            "description": "Auto-generated key for OSS user",
            "expires_in_days": 7,  # Short-lived for OSS
            "created_by": user_provider_id,
        }
        headers = None
    else:
        # For authenticated mode, use the secret key and organization ID.
        if not DOGRAH_MPS_SECRET_KEY:
            logger.warning(
                "Warning: DOGRAH_MPS_SECRET_KEY not set for authenticated mode"
            )
            # pydantic's ValidationError cannot be constructed from a plain
            # message, so raising it here would itself fail with TypeError.
            raise ValueError("Missing DOGRAH_MPS_SECRET_KEY in non oss mode")

        payload = {
            "name": "Default Dograh Model Service Key",
            "description": f"Auto-generated key for organization {organization_id}",
            "organization_id": organization_id,
            "expires_in_days": 90,  # Longer-lived for authenticated users
            "created_by": user_provider_id,
        }
        headers = {"X-Secret-Key": DOGRAH_MPS_SECRET_KEY}

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                endpoint, json=payload, headers=headers, timeout=10.0
            )
    except httpx.HTTPError as exc:
        # Treat an unreachable or timed-out MPS like any other failed
        # request: report it and fall back to "no default configuration".
        logger.warning(f"Failed to get MPS service key: {exc}")
        return None

    if response.status_code != 200:
        logger.warning(
            f"Failed to get MPS service key: {response.status_code} - {response.text}"
        )
        return None

    service_key = response.json().get("service_key")
    if not service_key:
        logger.warning("MPS response did not include a service_key")
        return None

    # Configuration JSON for storage in the database; the service factory
    # uses it to instantiate the actual services.
    provider = ServiceProviders.DOGRAH.value
    configuration = {
        "llm": {
            "provider": provider,
            "api_key": service_key,
            "model": DograhLLMModel.DEFAULT.value,  # Default model
        },
        "tts": {
            "provider": provider,
            "api_key": service_key,
            "model": DograhTTSModel.DEFAULT.value,  # Default model
            "voice": DograhVoice.DEFAULT.value,  # Default voice
        },
        "stt": {
            "provider": provider,
            "api_key": service_key,
            "model": DograhSTTModel.DEFAULT.value,  # Default model
        },
    }
    return UserConfiguration(**configuration)
|
||||
|
||||
|
||||
async def get_superuser(
    authorization: Annotated[str | None, Header()] = None,
) -> UserModel:
    """Dependency that admits only authenticated superusers.

    Authentication is delegated to ``get_user`` (whose ``HTTPException``s
    propagate unchanged); a 403 is raised when the resolved user does not
    have ``is_superuser`` set.
    """
    current_user = await get_user(authorization)

    if current_user.is_superuser:
        return current_user

    raise HTTPException(
        status_code=403, detail="Access denied. Superuser privileges required."
    )
|
||||
122
api/services/auth/stack_auth.py
Normal file
122
api/services/auth/stack_auth.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class StackAuth:
|
||||
def __init__(self):
|
||||
self.project_id = os.environ.get("STACK_AUTH_PROJECT_ID")
|
||||
self.secret_server_key = os.environ.get("STACK_SECRET_SERVER_KEY")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _strip_bearer(self, access_token: str | None) -> str | None:
|
||||
"""Remove the leading "Bearer " prefix from the token if present."""
|
||||
if not access_token:
|
||||
return None
|
||||
if access_token.startswith("Bearer "):
|
||||
return access_token.split(" ", 1)[1]
|
||||
return access_token
|
||||
|
||||
async def get_user(self, access_token: str):
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
access_token = self._strip_bearer(access_token)
|
||||
|
||||
url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/users/me"
|
||||
headers = {
|
||||
"x-stack-access-type": "server",
|
||||
"x-stack-project-id": self.project_id,
|
||||
"x-stack-secret-server-key": self.secret_server_key,
|
||||
"x-stack-access-token": access_token,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
response = await response.json()
|
||||
if "id" in response:
|
||||
return response
|
||||
else:
|
||||
return None
|
||||
|
||||
async def impersonate(self, stack_user_id: str):
|
||||
url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/auth/sessions"
|
||||
headers = {
|
||||
"x-stack-access-type": "server",
|
||||
"x-stack-project-id": self.project_id,
|
||||
"x-stack-secret-server-key": self.secret_server_key,
|
||||
}
|
||||
|
||||
data = {
|
||||
"user_id": stack_user_id,
|
||||
"expires_in_millis": 3600000,
|
||||
"is_impersonation": True,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
response = await response.json()
|
||||
return response
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Team & user management helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# async def create_team(
|
||||
# self,
|
||||
# access_token: str,
|
||||
# display_name: str,
|
||||
# profile_image_url: str | None = None,
|
||||
# client_metadata: dict | None = None,
|
||||
# ) -> dict:
|
||||
# """Create a new team for the authenticated user and return the API response."""
|
||||
# token = self._strip_bearer(access_token)
|
||||
# if token is None:
|
||||
# raise ValueError("Access token required to create team")
|
||||
|
||||
# url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/teams"
|
||||
# headers = {
|
||||
# "x-stack-access-type": "server",
|
||||
# "x-stack-project-id": self.project_id,
|
||||
# "x-stack-secret-server-key": self.secret_server_key,
|
||||
# "x-stack-access-token": token,
|
||||
# "Content-Type": "application/json",
|
||||
# }
|
||||
|
||||
# payload: dict = {
|
||||
# "display_name": display_name,
|
||||
# "creator_user_id": "me",
|
||||
# }
|
||||
# if profile_image_url is not None:
|
||||
# payload["profile_image_url"] = profile_image_url
|
||||
# if client_metadata is not None:
|
||||
# payload["client_metadata"] = client_metadata
|
||||
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# async with session.post(url, headers=headers, json=payload) as response:
|
||||
# return await response.json()
|
||||
|
||||
# async def update_user(self, access_token: str, data: dict) -> dict:
|
||||
# """Patch the current user with supplied data and return the API response."""
|
||||
# token = self._strip_bearer(access_token)
|
||||
# if token is None:
|
||||
# raise ValueError("Access token required to update user")
|
||||
|
||||
# url = os.environ.get("STACK_AUTH_API_URL") + "/api/v1/users/me"
|
||||
# headers = {
|
||||
# "x-stack-access-type": "server",
|
||||
# "x-stack-project-id": self.project_id,
|
||||
# "x-stack-secret-server-key": self.secret_server_key,
|
||||
# "x-stack-access-token": token,
|
||||
# "Content-Type": "application/json",
|
||||
# }
|
||||
|
||||
# async with aiohttp.ClientSession() as session:
|
||||
# async with session.patch(url, headers=headers, json=data) as response:
|
||||
# return await response.json()
|
||||
|
||||
|
||||
# Shared module-level client. Its project id and secret server key are read
# from the environment once, when this module is first imported.
stackauth = StackAuth()
|
||||
5
api/services/campaign/__init__.py
Normal file
5
api/services/campaign/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Campaign service package"""
|
||||
|
||||
from .rate_limiter import rate_limiter
|
||||
|
||||
__all__ = ["rate_limiter"]
|
||||
329
api/services/campaign/call_dispatcher.py
Normal file
329
api/services/campaign/call_dispatcher.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
import asyncio
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import QueuedRunModel, WorkflowRunModel
|
||||
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.telephony.twilio import TwilioService
|
||||
|
||||
|
||||
class CampaignCallDispatcher:
    """Manages rate-limited and concurrent-limited call dispatching"""

    def __init__(self):
        self._twilio_service = None
        # Used when an organization has no CONCURRENT_CALL_LIMIT configuration
        # or when reading it fails.
        self.default_concurrent_limit = 20

    @property
    def twilio_service(self):
        """Lazy initialization of TwilioService"""
        if self._twilio_service is None:
            self._twilio_service = TwilioService()
        return self._twilio_service

    async def get_org_concurrent_limit(self, organization_id: int) -> int:
        """Get the concurrent call limit for an organization.

        Falls back to ``default_concurrent_limit`` when no value is stored
        or when the lookup/parsing raises (the error is logged).
        """
        try:
            config = await db_client.get_configuration(
                organization_id,
                OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value,
            )
            if config and config.value:
                return int(config.value["value"])
        except Exception as e:
            logger.warning(
                f"Error getting concurrent limit for org {organization_id}: {e}"
            )
        return self.default_concurrent_limit

    async def process_batch(self, campaign_id: int, batch_size: int = 10) -> int:
        """Process a batch of queued runs, prioritising due scheduled retries.

        Args:
            campaign_id: Campaign whose queue should be drained.
            batch_size: Maximum number of queued runs to take in this batch.

        Returns:
            Number of runs that were dispatched and marked processed.

        Raises:
            ValueError: If the campaign does not exist.
        """
        campaign = await db_client.get_campaign_by_id(campaign_id)
        if not campaign:
            raise ValueError(f"Campaign {campaign_id} not found")

        # Check if campaign is in running state
        if campaign.state != "running":
            logger.info(
                f"Campaign {campaign_id} is not in running state: {campaign.state}"
            )
            return 0

        # First, get any scheduled retries that are due
        scheduled_runs = await db_client.get_scheduled_queued_runs(
            campaign_id=campaign_id,
            scheduled_before=datetime.now(UTC),
            limit=batch_size,
        )

        # Then fill the remaining capacity with regular queued runs
        remaining_slots = batch_size - len(scheduled_runs)
        regular_runs = []
        if remaining_slots > 0:
            regular_runs = await db_client.get_queued_runs(
                campaign_id=campaign_id,
                state="queued",
                scheduled_for=False,  # Exclude scheduled runs
                limit=remaining_slots,
            )

        queued_runs = scheduled_runs + regular_runs
        if not queued_runs:
            logger.info(f"No more queued runs for campaign {campaign_id}")
            return 0

        # Snapshot of the counter as loaded above. The campaign object is not
        # re-read inside the loop, so each write must add the number of runs
        # processed so far in this batch, not a constant 1.
        # NOTE(review): still an absolute write; concurrent batches for the
        # same campaign could overwrite each other - confirm they cannot run.
        base_processed_rows = campaign.processed_rows

        processed_count = 0
        for queued_run in queued_runs:
            try:
                # Apply rate limiting
                await self.apply_rate_limit(
                    campaign.organization_id, campaign.rate_limit_per_second
                )

                # Dispatch the call
                workflow_run = await self.dispatch_call(queued_run, campaign)

                # Update queued run as processed
                await db_client.update_queued_run(
                    queued_run_id=queued_run.id,
                    state="processed",
                    workflow_run_id=workflow_run.id,
                    processed_at=datetime.now(UTC),
                )

                processed_count += 1

                # Update campaign processed count
                await db_client.update_campaign(
                    campaign_id=campaign_id,
                    processed_rows=base_processed_rows + processed_count,
                )

            except Exception as e:
                logger.warning(f"Error processing queued run {queued_run.id}: {e}")

                # Mark the queued run as failed to prevent infinite retry loops
                try:
                    await db_client.update_queued_run(
                        queued_run_id=queued_run.id,
                        state="failed",
                        processed_at=datetime.now(UTC),
                    )
                    logger.info(
                        f"Marked queued run {queued_run.id} as failed due to error: {e}"
                    )
                except Exception as update_error:
                    logger.error(
                        f"Failed to mark queued run {queued_run.id} as failed: {update_error}"
                    )

        return processed_count

    async def _wait_for_concurrent_slot(self, campaign, max_concurrent: int):
        """Poll once per second until a concurrent slot is acquired.

        Never gives up; after ten minutes of waiting an error is logged on
        every further attempt so the stall can be alerted on.
        """
        wait_start = time.time()
        while True:
            slot_id = await rate_limiter.try_acquire_concurrent_slot(
                campaign.organization_id, max_concurrent
            )
            if slot_id:
                return slot_id

            # Check if we've been waiting too long
            wait_time = time.time() - wait_start
            if wait_time > 600:  # 10 minutes
                logger.error(
                    f"Waiting for concurrent slot for {wait_time:.1f}s, "
                    f"org: {campaign.organization_id}, campaign: {campaign.id}"
                )

            logger.debug(
                f"Attempting to get a slot for {campaign.organization_id} {campaign.id}"
            )

            # Wait before retrying
            await asyncio.sleep(1)

    async def _release_slot_for_run(self, workflow_run_id: int) -> None:
        """Release the slot mapped to a workflow run and drop the mapping."""
        mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id)
        if mapping:
            org_id, slot_id = mapping
            await rate_limiter.release_concurrent_slot(org_id, slot_id)
            await rate_limiter.delete_workflow_slot_mapping(workflow_run_id)

    # NOTE(review): ``any`` below is the builtin function, not ``typing.Any``;
    # kept as-is because ``Any`` is not imported in this module.
    async def dispatch_call(
        self, queued_run: QueuedRunModel, campaign: any
    ) -> Optional[WorkflowRunModel]:
        """Creates workflow run and initiates call with concurrent limiting.

        Waits for a concurrent slot, creates the workflow run, records the
        run -> slot mapping and asks Twilio to place the call. The slot is
        released again on every failure path.

        Returns:
            The created workflow run.

        Raises:
            ValueError: If the campaign's workflow does not exist or the
                queued run has no ``phone_number`` in its context.
            Exception: Any error from the database, the rate limiter or
                Twilio is re-raised after cleanup.
        """
        max_concurrent = await self.get_org_concurrent_limit(campaign.organization_id)
        slot_id = await self._wait_for_concurrent_slot(campaign, max_concurrent)

        # Until the run -> slot mapping exists, the slot can only be released
        # through the id held here, so every failure in this section must do so.
        try:
            workflow = await db_client.get_workflow_by_id(campaign.workflow_id)
            if not workflow:
                raise ValueError(f"Workflow {campaign.workflow_id} not found")

            # Tolerate unset context columns instead of failing on ``**None``.
            queued_context = queued_run.context_variables or {}

            # Merge context variables (queued_run context already includes
            # retry info if applicable)
            initial_context = {
                **(workflow.template_context_variables or {}),
                **queued_context,
                "campaign_id": campaign.id,
            }

            phone_number = queued_context.get("phone_number")
            if not phone_number:
                raise ValueError(f"No phone number in queued run {queued_run.id}")

            workflow_run = await db_client.create_workflow_run(
                name=f"WR-CAMPAIGN-{campaign.id}-{queued_run.id}",
                workflow_id=campaign.workflow_id,
                mode=WorkflowRunMode.TWILIO.value,
                user_id=campaign.created_by,
                initial_context=initial_context,
                campaign_id=campaign.id,
                queued_run_id=queued_run.id,  # Link to queued run for retry tracking
            )

            # Store slot_id mapping in Redis for cleanup later
            await rate_limiter.store_workflow_slot_mapping(
                workflow_run.id, campaign.organization_id, slot_id
            )
        except Exception:
            await rate_limiter.release_concurrent_slot(
                campaign.organization_id, slot_id
            )
            raise

        # Add "retry" tag if this is a retry call
        if queued_context.get("is_retry"):
            retry_reason = queued_context.get("retry_reason", "unknown")
            try:
                await db_client.update_workflow_run(
                    run_id=workflow_run.id,
                    gathered_context={
                        "call_tags": ["retry", f"retry_reason_{retry_reason}"]
                    },
                )
            except Exception:
                # No call was placed, so nothing else will free the slot.
                await self._release_slot_for_run(workflow_run.id)
                raise

        # Initiate call via Twilio
        try:
            call_result = await self.twilio_service.initiate_call(
                to_number=phone_number,
                workflow_run_id=workflow_run.id,
                organization_id=campaign.organization_id,
                url_args={
                    "workflow_id": campaign.workflow_id,
                    "user_id": campaign.created_by,
                    "workflow_run_id": workflow_run.id,
                    "campaign_id": campaign.id,
                },
            )

            logger.info(
                f"Call initiated for workflow run {workflow_run.id}, SID: {call_result.get('sid')}"
            )

        except Exception as e:
            logger.error(
                f"Failed to initiate call for workflow run {workflow_run.id}: {e}"
            )

            try:
                # Update workflow run as failed
                twilio_callback_logs = workflow_run.logs.get(
                    "twilio_status_callbacks", []
                )
                twilio_callback_logs.append(
                    {
                        "status": "failed",
                        "timestamp": datetime.now(UTC).isoformat(),
                        "data": {"error": str(e)},
                    }
                )
                await db_client.update_workflow_run(
                    run_id=workflow_run.id,
                    is_completed=True,
                    gathered_context={
                        "error": str(e),
                    },
                    logs={
                        "twilio_status_callbacks": twilio_callback_logs,
                    },
                )
            finally:
                # Release the slot even when recording the failure itself fails.
                await self._release_slot_for_run(workflow_run.id)

            raise

        return workflow_run

    async def apply_rate_limit(self, organization_id: int, rate_limit: int) -> None:
        """
        Enforces rate limiting - waits if necessary to comply with rate limit

        Example usage:
        ```
        # This will wait up to 1 second if needed to respect rate limit
        await self.apply_rate_limit(org_id, 1)  # 1 call per second
        await twilio.initiate_call(...)  # Now safe to call
        ```

        Raises:
            TimeoutError: If waiting for the next slot would take the total
                wait past one second.
        """
        max_wait = 1.0  # Maximum time to wait for a slot
        start_time = time.time()

        while True:
            # Try to acquire token
            if await rate_limiter.acquire_token(organization_id, rate_limit):
                return  # Got permission to proceed

            # Check how long to wait
            wait_time = await rate_limiter.get_next_available_slot(
                organization_id, rate_limit
            )

            # Don't wait forever
            if time.time() - start_time + wait_time > max_wait:
                raise TimeoutError("Rate limit timeout - try again later")

            # Wait for next available slot
            await asyncio.sleep(wait_time)

    async def release_call_slot(self, workflow_run_id: int) -> bool:
        """
        Release concurrent slot when a call completes.
        Called by Twilio webhooks or workflow completion handlers.

        Returns:
            The rate limiter's release result, or ``False`` when no slot is
            mapped to the workflow run. The mapping is deleted only after a
            successful release.
        """
        mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id)
        if mapping:
            org_id, slot_id = mapping
            success = await rate_limiter.release_concurrent_slot(org_id, slot_id)
            if success:
                await rate_limiter.delete_workflow_slot_mapping(workflow_run_id)
                logger.info(
                    f"Released concurrent slot for workflow run {workflow_run_id}"
                )
            return success
        return False
|
||||
|
||||
|
||||
# Global instance, constructed at import time. The Twilio client is not
# created here; it is built on first access of ``twilio_service``.
campaign_call_dispatcher = CampaignCallDispatcher()
|
||||
258
api/services/campaign/campaign_event_protocol.py
Normal file
258
api/services/campaign/campaign_event_protocol.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
"""Campaign event protocol for orchestrator communication.
|
||||
|
||||
Defines message formats and helpers for campaign event publishing and handling.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class CampaignEventType(str, Enum):
    """String-valued identifiers for every campaign event kind.

    Members are ``str`` instances, so they compare equal to their wire value
    and serialise as plain strings.
    """

    # --- batch processing ---
    BATCH_COMPLETED = "batch_completed"
    BATCH_FAILED = "batch_failed"

    # --- source sync ---
    SYNC_STARTED = "sync_started"
    SYNC_COMPLETED = "sync_completed"
    SYNC_FAILED = "sync_failed"

    # --- campaign lifecycle ---
    CAMPAIGN_STARTED = "campaign_started"
    CAMPAIGN_PAUSED = "campaign_paused"
    CAMPAIGN_RESUMED = "campaign_resumed"
    CAMPAIGN_COMPLETED = "campaign_completed"
    CAMPAIGN_FAILED = "campaign_failed"

    # --- retries ---
    RETRY_NEEDED = "retry_needed"
    RETRY_SCHEDULED = "retry_scheduled"
    RETRY_FAILED = "retry_failed"
|
||||
|
||||
|
||||
class RetryReason(str, Enum):
    """String-valued reasons attached to a retry."""

    BUSY = "busy"
    NO_ANSWER = "no_answer"
    VOICEMAIL = "voicemail"
    FAILED = "failed"
    ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseCampaignEvent:
|
||||
"""Base class for all campaign events."""
|
||||
|
||||
type: str
|
||||
campaign_id: int = 0
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp is None:
|
||||
from datetime import UTC, datetime
|
||||
|
||||
self.timestamp = datetime.now(UTC).isoformat()
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: str):
|
||||
return cls(**json.loads(data))
|
||||
|
||||
|
||||
@dataclass
class BatchCompletedEvent(BaseCampaignEvent):
    """Campaign event of type ``BATCH_COMPLETED``: a batch finished processing."""

    type: str = CampaignEventType.BATCH_COMPLETED
    processed_count: int = 0
    failed_count: int = 0
    batch_size: int = 0
    # None is replaced with a fresh dict in __post_init__.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        super().__post_init__()
        # Each instance gets its own dict rather than a shared default.
        self.metadata = {} if self.metadata is None else self.metadata
|
||||
|
||||
|
||||
@dataclass
class BatchFailedEvent(BaseCampaignEvent):
    """Campaign event of type ``BATCH_FAILED``: a batch failed to process."""

    type: str = CampaignEventType.BATCH_FAILED
    error: str = ""
    processed_count: int = 0
    # None is replaced with a fresh dict in __post_init__.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        super().__post_init__()
        # Each instance gets its own dict rather than a shared default.
        self.metadata = {} if self.metadata is None else self.metadata
|
||||
|
||||
|
||||
@dataclass
class SyncStartedEvent(BaseCampaignEvent):
    """Campaign event of type ``SYNC_STARTED``: a campaign source sync began."""

    type: str = CampaignEventType.SYNC_STARTED
    # Both source descriptors default to the empty string.
    source_type: str = ""
    source_id: str = ""
|
||||
|
||||
|
||||
@dataclass
class SyncCompletedEvent(BaseCampaignEvent):
    """Signals that synchronisation of a campaign's source has finished."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.SYNC_COMPLETED
    total_rows: int = 0
    # Identify the source that was synced; both default to the empty string.
    source_type: str = ""
    source_id: str = ""
    # Free-form extra data; None is replaced by an empty dict on construction.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Apply the base timestamp default and normalise ``metadata``."""
        super().__post_init__()
        # A fresh dict per instance avoids sharing one mutable default.
        self.metadata = {} if self.metadata is None else self.metadata
|
||||
|
||||
|
||||
@dataclass
class SyncFailedEvent(BaseCampaignEvent):
    """Signals that synchronisation of a campaign's source has failed."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.SYNC_FAILED
    # Error description supplied by the publisher.
    error: str = ""
    # Identify the source whose sync failed; both default to the empty string.
    source_type: str = ""
    source_id: str = ""
|
||||
|
||||
|
||||
@dataclass
class CampaignStartedEvent(BaseCampaignEvent):
    """Signals that a campaign has started."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.CAMPAIGN_STARTED
    workflow_id: int = 0
    # Optional row total; None when the publisher does not supply one.
    total_rows: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
class CampaignPausedEvent(BaseCampaignEvent):
    """Signals that a campaign has been paused."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.CAMPAIGN_PAUSED
    # Row counters reported with the pause.
    processed_rows: int = 0
    failed_rows: int = 0
|
||||
|
||||
|
||||
@dataclass
class CampaignResumedEvent(BaseCampaignEvent):
    """Signals that a paused campaign has been resumed."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.CAMPAIGN_RESUMED
    # Row counters reported with the resume.
    processed_rows: int = 0
    failed_rows: int = 0
|
||||
|
||||
|
||||
@dataclass
class CampaignCompletedEvent(BaseCampaignEvent):
    """Signals that a campaign has completed."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.CAMPAIGN_COMPLETED
    # Row counters reported at completion.
    total_rows: int = 0
    processed_rows: int = 0
    failed_rows: int = 0
    # Run time in seconds; None when the publisher cannot compute it.
    duration_seconds: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
class CampaignFailedEvent(BaseCampaignEvent):
    """Signals that a campaign has failed."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.CAMPAIGN_FAILED
    # Error description supplied by the publisher.
    error: str = ""
    # Row counters reported at the time of failure.
    processed_rows: int = 0
    failed_rows: int = 0
|
||||
|
||||
|
||||
@dataclass
class RetryNeededEvent(BaseCampaignEvent):
    """Signals that a call should be considered for a retry."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.RETRY_NEEDED
    workflow_run_id: int = 0
    queued_run_id: int = 0
    reason: str = ""  # RetryReason value
    # Free-form extra data; None is replaced by an empty dict on construction.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Apply the base timestamp default and normalise ``metadata``."""
        super().__post_init__()
        # A fresh dict per instance avoids sharing one mutable default.
        self.metadata = {} if self.metadata is None else self.metadata
|
||||
|
||||
|
||||
@dataclass
class RetryScheduledEvent(BaseCampaignEvent):
    """Signals that a retry has been scheduled for a queued run."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.RETRY_SCHEDULED
    # Original queued run and the run created for the retry.
    queued_run_id: int = 0
    retry_run_id: int = 0
    retry_count: int = 0
    scheduled_for: str = ""  # ISO timestamp
    reason: str = ""  # RetryReason value
|
||||
|
||||
|
||||
@dataclass
class RetryFailedEvent(BaseCampaignEvent):
    """Signals that the maximum number of retries has been reached."""

    # Discriminator used to pick this class when an event is decoded.
    type: str = CampaignEventType.RETRY_FAILED
    queued_run_id: int = 0
    retry_count: int = 0
    last_reason: str = ""  # RetryReason value
|
||||
|
||||
|
||||
def parse_campaign_event(data: str) -> Any:
    """Decode a JSON campaign event message into its typed event object.

    Args:
        data: JSON object string. Its ``type`` key selects the event class;
            all keys are passed to that class as keyword arguments.

    Returns:
        An instance of the matching event dataclass, or ``None`` when the
        type is not recognised (logged as a warning) or when decoding or
        construction raises (logged as an error).
    """
    try:
        payload = json.loads(data)
        event_type = payload.get("type")

        # Dispatch table: event type -> dataclass used to rebuild the event.
        event_class_map = {
            CampaignEventType.BATCH_COMPLETED: BatchCompletedEvent,
            CampaignEventType.BATCH_FAILED: BatchFailedEvent,
            CampaignEventType.SYNC_STARTED: SyncStartedEvent,
            CampaignEventType.SYNC_COMPLETED: SyncCompletedEvent,
            CampaignEventType.SYNC_FAILED: SyncFailedEvent,
            CampaignEventType.CAMPAIGN_STARTED: CampaignStartedEvent,
            CampaignEventType.CAMPAIGN_PAUSED: CampaignPausedEvent,
            CampaignEventType.CAMPAIGN_RESUMED: CampaignResumedEvent,
            CampaignEventType.CAMPAIGN_COMPLETED: CampaignCompletedEvent,
            CampaignEventType.CAMPAIGN_FAILED: CampaignFailedEvent,
            CampaignEventType.RETRY_NEEDED: RetryNeededEvent,
            CampaignEventType.RETRY_SCHEDULED: RetryScheduledEvent,
            CampaignEventType.RETRY_FAILED: RetryFailedEvent,
        }

        target_class = event_class_map.get(event_type)
        if not target_class:
            # Unknown event type: report it and give up on this message.
            from loguru import logger

            logger.warning(f"Unknown campaign event type: {event_type}")
            return None

        return target_class(**payload)

    except Exception as exc:
        # Any decoding/construction problem is logged and reported as None
        # so that one malformed message cannot break the caller.
        from loguru import logger

        logger.error(f"Failed to parse campaign event: {exc}, data: {data}")
        return None
|
||||
121
api/services/campaign/campaign_event_publisher.py
Normal file
121
api/services/campaign/campaign_event_publisher.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""Campaign event publisher for orchestrator communication.
|
||||
|
||||
Handles publishing of campaign events to Redis pub/sub channels.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
from api.enums import RedisChannel
|
||||
from api.services.campaign.campaign_event_protocol import (
|
||||
BatchCompletedEvent,
|
||||
CampaignCompletedEvent,
|
||||
RetryNeededEvent,
|
||||
SyncCompletedEvent,
|
||||
)
|
||||
|
||||
|
||||
class CampaignEventPublisher:
    """Builds typed campaign events and publishes them via Redis pub/sub.

    Each ``publish_*`` coroutine constructs the matching event dataclass,
    serialises it with ``to_json()`` and publishes the resulting string on
    the ``RedisChannel.CAMPAIGN_EVENTS`` channel.
    """

    def __init__(self, redis_client):
        # Only the awaitable ``publish(channel, message)`` method is used.
        self.redis = redis_client

    async def _emit(self, event):
        """Publish the JSON form of ``event`` on the campaign events channel."""
        await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())

    async def publish_batch_completed(
        self,
        campaign_id: int,
        processed_count: int,
        failed_count: int = 0,
        batch_size: int = 0,
        metadata: Optional[Dict] = None,
    ):
        """Publish a ``BatchCompletedEvent`` built from the given values."""
        await self._emit(
            BatchCompletedEvent(
                campaign_id=campaign_id,
                processed_count=processed_count,
                failed_count=failed_count,
                batch_size=batch_size,
                metadata=metadata,
            )
        )

    async def publish_sync_completed(
        self,
        campaign_id: int,
        total_rows: int,
        source_type: str = "",
        source_id: str = "",
        metadata: Optional[Dict] = None,
    ):
        """Publish a ``SyncCompletedEvent`` built from the given values."""
        await self._emit(
            SyncCompletedEvent(
                campaign_id=campaign_id,
                total_rows=total_rows,
                source_type=source_type,
                source_id=source_id,
                metadata=metadata,
            )
        )

    async def publish_retry_needed(
        self,
        workflow_run_id: int,
        reason: str,
        campaign_id: Optional[int] = None,
        queued_run_id: Optional[int] = None,
        metadata: Optional[Dict] = None,
    ):
        """Publish a ``RetryNeededEvent`` and log that it was sent.

        Missing (falsy) ``campaign_id`` / ``queued_run_id`` are sent as 0 and
        missing ``metadata`` as an empty dict.
        """
        retry_event = RetryNeededEvent(
            campaign_id=campaign_id or 0,
            workflow_run_id=workflow_run_id,
            queued_run_id=queued_run_id or 0,
            reason=reason,
            metadata=metadata or {},
        )
        await self._emit(retry_event)

        # The log line reports the caller's campaign_id, not the 0 fallback.
        logger.info(
            f"Published retry event for workflow_run {workflow_run_id}, "
            f"reason: {reason}, campaign: {campaign_id}"
        )

    async def publish_campaign_completed(
        self,
        campaign_id: int,
        total_rows: int,
        processed_rows: int,
        failed_rows: int,
        duration_seconds: Optional[float] = None,
    ):
        """Publish a ``CampaignCompletedEvent`` built from the given values."""
        await self._emit(
            CampaignCompletedEvent(
                campaign_id=campaign_id,
                total_rows=total_rows,
                processed_rows=processed_rows,
                failed_rows=failed_rows,
                duration_seconds=duration_seconds,
            )
        )
|
||||
|
||||
|
||||
# Global publisher instance with lazy Redis connection
|
||||
async def get_campaign_event_publisher() -> CampaignEventPublisher:
    """Return the module-wide publisher, creating it on first use.

    The publisher and its Redis client are stored in module globals that do
    not exist until the first call, so their absence marks "not created yet".
    """
    global _campaign_publisher
    global _campaign_redis_client

    try:
        # Reading an unset module global raises NameError.
        return _campaign_publisher
    except NameError:
        pass

    _campaign_redis_client = await aioredis.from_url(REDIS_URL, decode_responses=True)
    _campaign_publisher = CampaignEventPublisher(_campaign_redis_client)
    return _campaign_publisher
|
||||
563
api/services/campaign/campaign_orchestrator.py
Normal file
563
api/services/campaign/campaign_orchestrator.py
Normal file
|
|
@ -0,0 +1,563 @@
|
|||
"""Campaign Orchestrator Service.
|
||||
|
||||
This service ensures continuous campaign processing by listening to events
|
||||
and scheduling batches immediately upon completion. It also monitors campaigns
|
||||
for final completion after 1 hour of inactivity and handles retry events.
|
||||
"""
|
||||
|
||||
from api.logging_config import setup_logging
|
||||
|
||||
logging_queue_listener = setup_logging()
|
||||
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Dict
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import CampaignModel, QueuedRunModel
|
||||
from api.enums import RedisChannel
|
||||
from api.services.campaign.campaign_event_protocol import (
|
||||
CampaignCompletedEvent,
|
||||
CampaignEventType,
|
||||
RetryNeededEvent,
|
||||
parse_campaign_event,
|
||||
)
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
|
||||
class CampaignOrchestrator:
    """Orchestrates campaign processing, retry handling, and completion detection.

    ``run`` drives two concurrent coroutines:

    * ``_listen_for_events`` subscribes to the campaign events Redis channel
      and reacts to each event (scheduling the next batch, creating retries).
    * ``_monitor_completion`` periodically scans running campaigns for stuck
      batches, orphaned work and campaigns that can be marked completed.

    All bookkeeping dictionaries are in-memory and keyed by campaign id.
    """

    def __init__(self, redis_client: aioredis.Redis):
        self.redis = redis_client
        self.completion_check_interval = 60  # 1 minute
        self.completion_timeout = 3600  # 1 hour
        self._processing_locks: Dict[int, datetime] = {}  # prevent duplicate scheduling
        self._last_activity: Dict[
            int, datetime
        ] = {}  # track last activity per campaign
        self._batch_in_progress: Dict[
            int, datetime
        ] = {}  # track batches that have been scheduled but not completed
        # Strong references to fire-and-forget tasks: the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks: set = set()
        self._running = False
        self._pubsub = None

    def _spawn_background(self, coro) -> asyncio.Task:
        """Start ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def run(self):
        """Main service with two concurrent tasks.

        Runs the event listener and the completion monitor until both finish
        or one of them raises. Errors (and cancellation) are logged and
        re-raised. On every exit path a task that is still pending is
        cancelled and awaited, then ``shutdown`` is called.
        """
        self._running = True
        logger.info("Campaign Orchestrator starting...")

        tasks = []
        try:
            # Task 1: Listen for events and react immediately
            tasks.append(asyncio.create_task(self._listen_for_events()))

            # Task 2: Periodically check for stale campaigns
            tasks.append(asyncio.create_task(self._monitor_completion()))

            # Wait for both tasks
            await asyncio.gather(*tasks)

        except asyncio.CancelledError:
            logger.info("Campaign Orchestrator cancelled")
            raise
        except Exception as e:
            logger.error(f"Campaign Orchestrator error: {e}")
            raise
        finally:
            # gather() does not cancel the sibling when one task fails, so
            # stop whatever is still running instead of leaving it orphaned.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            await self.shutdown()

    async def _listen_for_events(self):
        """Listen for campaign events and react immediately.

        Subscribes to the campaign events channel and handles every pub/sub
        message of type ``"message"``. Errors raised while handling a single
        message are logged so the loop keeps running. The ``_running`` flag is
        only checked when a message arrives.
        """
        self._pubsub = self.redis.pubsub()
        await self._pubsub.subscribe(RedisChannel.CAMPAIGN_EVENTS.value)
        logger.info(f"Subscribed to {RedisChannel.CAMPAIGN_EVENTS.value} channel")

        async for message in self._pubsub.listen():
            if not self._running:
                break

            if message["type"] == "message":
                try:
                    event = parse_campaign_event(message["data"])
                    if event:
                        await self._handle_event(event)
                    else:
                        logger.error(
                            f"Failed to parse campaign event: {message['data']}"
                        )
                except Exception as e:
                    logger.error(f"Error handling campaign event: {e}")

    async def _handle_event(self, event):
        """Handle campaign events including retry events.

        ``RetryNeededEvent`` is delegated to ``_handle_retry_event``. For
        ``BATCH_COMPLETED`` and ``SYNC_COMPLETED`` the next batch is scheduled
        and the campaign's last-activity time is refreshed. Other event types
        are only logged at debug level.
        """
        # Handle RetryNeededEvent
        if isinstance(event, RetryNeededEvent):
            await self._handle_retry_event(event)
            return

        # All events should have campaign_id
        if not hasattr(event, "campaign_id") or not event.campaign_id:
            logger.warning(f"Event missing campaign_id: {type(event).__name__}")
            return

        campaign_id = event.campaign_id
        event_type = event.type

        logger.debug(f"campaign_id: {campaign_id} - Received event: {event_type}")

        if event_type == CampaignEventType.BATCH_COMPLETED:
            # Clear the batch in progress flag
            if campaign_id in self._batch_in_progress:
                del self._batch_in_progress[campaign_id]
                logger.debug(
                    f"campaign_id: {campaign_id} - Batch completed, cleared in-progress flag"
                )

            # Immediately schedule next batch
            await self._schedule_next_batch(campaign_id)
            self._last_activity[campaign_id] = datetime.now(UTC)

        elif event_type == CampaignEventType.SYNC_COMPLETED:
            # Start processing after sync
            logger.info(
                f"campaign_id: {campaign_id} - Sync completed, starting processing"
            )
            await self._schedule_next_batch(campaign_id)
            self._last_activity[campaign_id] = datetime.now(UTC)

    async def _handle_retry_event(self, event: RetryNeededEvent):
        """Process retry event and schedule if eligible (from campaign_retry_manager).

        The event is dropped when it has no campaign id, the campaign or the
        queued run cannot be loaded, retries are disabled, or the reason is
        switched off in the campaign's ``retry_config``. When the queued run
        has already used ``max_retries`` attempts it is recorded as a final
        failure; otherwise a delayed retry run is created.
        """

        # Check retry eligibility
        campaign_id = event.campaign_id
        if not campaign_id:
            logger.debug("Skipping non-campaign retry event")
            return

        # Get campaign configuration
        campaign = await db_client.get_campaign_by_id(campaign_id)
        if not campaign:
            logger.error(f"campaign_id: {campaign_id} - Campaign not found")
            return

        # Every retry_config key is optional; the defaults below apply.
        retry_config = campaign.retry_config or {}
        if not retry_config.get("enabled", True):
            logger.info(f"campaign_id: {campaign_id} - Retry disabled")
            return

        # Check if this reason should be retried
        reason = event.reason
        if reason == "busy" and not retry_config.get("retry_on_busy", True):
            logger.info(f"campaign_id: {campaign_id} - Skipping retry for busy signal")
            return
        if reason == "no_answer" and not retry_config.get("retry_on_no_answer", True):
            logger.info(f"campaign_id: {campaign_id} - Skipping retry for no-answer")
            return
        if reason == "voicemail" and not retry_config.get("retry_on_voicemail", True):
            logger.info(f"campaign_id: {campaign_id} - Skipping retry for voicemail")
            return

        # Get the original queued run
        queued_run = await db_client.get_queued_run_by_id(event.queued_run_id)
        if not queued_run:
            logger.error(
                f"campaign_id: {campaign_id} - Queued run {event.queued_run_id} not found"
            )
            return

        max_retries = retry_config.get("max_retries", 1)

        if queued_run.retry_count >= max_retries:
            await self._mark_final_failure(queued_run, reason)
            logger.info(
                f"campaign_id: {campaign_id} - Max retries ({max_retries}) reached for queued run {queued_run.id}"
            )
            return

        # Create scheduled retry entry
        retry_delay = retry_config.get("retry_delay_seconds", 120)
        await self._schedule_retry(queued_run, reason, retry_delay)

        # Update last activity
        self._last_activity[campaign_id] = datetime.now(UTC)

    async def _schedule_retry(
        self, original_run: QueuedRunModel, reason: str, delay_seconds: int
    ):
        """Create a new queued run for retry.

        The new run copies the original context variables, adds retry
        markers, increments ``retry_count``, links back to the original run
        and becomes due ``delay_seconds`` from now.
        """

        campaign_id = original_run.campaign_id

        # Create retry context
        retry_context = {
            **original_run.context_variables,
            "is_retry": True,
            "retry_attempt": original_run.retry_count + 1,
            "retry_reason": reason,
        }

        logger.debug(
            f"campaign_id: {campaign_id} - Scheduling retry for {reason} in {delay_seconds}s, "
            f"retry attempt {original_run.retry_count + 1}"
        )

        # Create retry entry with unique source_uuid
        retry_run = await db_client.create_queued_run(
            campaign_id=campaign_id,
            source_uuid=f"{original_run.source_uuid}_retry_{original_run.retry_count + 1}",
            context_variables=retry_context,
            state="queued",
            retry_count=original_run.retry_count + 1,
            parent_queued_run_id=original_run.id,
            scheduled_for=datetime.now(UTC) + timedelta(seconds=delay_seconds),
            retry_reason=reason,
        )

        logger.info(
            f"campaign_id: {campaign_id} - Scheduled retry {retry_run.id} for {reason} in {delay_seconds}s, "
            f"retry attempt {retry_run.retry_count}"
        )

    async def _mark_final_failure(self, queued_run: QueuedRunModel, reason: str):
        """Mark a queued run as finally failed after max retries.

        Increments the campaign's ``failed_rows`` counter and logs the final
        reason.
        """
        campaign_id = queued_run.campaign_id

        # Update the campaign's failed_rows counter
        # NOTE(review): this is a read-then-write of the counter; whether
        # concurrent writers can race depends on db_client - confirm.
        campaign = await db_client.get_campaign_by_id(campaign_id)
        if campaign:
            await db_client.update_campaign(
                campaign_id=campaign_id, failed_rows=campaign.failed_rows + 1
            )

        logger.info(
            f"campaign_id: {campaign_id} - Queued run {queued_run.id} finally failed after max retries, "
            f"last reason: {reason}"
        )

    async def _schedule_next_batch(self, campaign_id: int):
        """Schedule next batch immediately if work available.

        A per-campaign in-memory timestamp suppresses repeat calls made
        within 5 seconds of the previous one. A batch job is enqueued only
        when the campaign is ``running`` or ``syncing`` and has pending work.
        Errors are logged, never raised.
        """

        # Prevent duplicate scheduling with in-memory lock
        if campaign_id in self._processing_locks:
            lock_time = self._processing_locks[campaign_id]
            if (datetime.now(UTC) - lock_time).total_seconds() < 5:
                logger.debug(
                    f"campaign_id: {campaign_id} - Batch already scheduled recently"
                )
                return

        # Set lock
        self._processing_locks[campaign_id] = datetime.now(UTC)

        try:
            # Check campaign status
            campaign = await db_client.get_campaign_by_id(campaign_id)
            if not campaign:
                logger.error(f"campaign_id: {campaign_id} - Campaign not found")
                return

            if campaign.state not in ["running", "syncing"]:
                logger.info(
                    f"campaign_id: {campaign_id} - Campaign not in running state: {campaign.state}"
                )
                return

            # Check for available work (queued runs + due retries)
            has_work = await self._has_pending_work(campaign_id)

            if has_work:
                # Schedule batch immediately
                await enqueue_job(
                    FunctionNames.PROCESS_CAMPAIGN_BATCH,
                    campaign_id,
                    10,  # batch_size
                )
                logger.info(f"campaign_id: {campaign_id} - Scheduled next batch")

                # Set batch in progress flag
                self._batch_in_progress[campaign_id] = datetime.now(UTC)

                # Update database
                await db_client.update_campaign(
                    campaign_id=campaign_id,
                    last_batch_scheduled_at=datetime.now(UTC),
                    last_activity_at=datetime.now(UTC),
                )
            else:
                logger.info(
                    f"campaign_id: {campaign_id} - No pending work to process, "
                    f"campaign may complete or wait for retries"
                )

        except Exception as e:
            logger.error(f"campaign_id: {campaign_id} - Error scheduling batch: {e}")
        finally:
            # Release lock after a short delay; the task is tracked so it
            # cannot be garbage-collected before it runs.
            self._spawn_background(self._release_lock_after_delay(campaign_id, 5))

    async def _release_lock_after_delay(self, campaign_id: int, delay: int):
        """Release processing lock after delay."""
        await asyncio.sleep(delay)
        if campaign_id in self._processing_locks:
            del self._processing_locks[campaign_id]
            logger.debug(f"campaign_id: {campaign_id} - Released processing lock")

    async def _monitor_completion(self):
        """Periodically check for campaigns that should be marked complete.

        Runs ``_check_stale_campaigns`` every ``completion_check_interval``
        seconds while ``_running`` is set; a failed check is logged and the
        loop continues.
        """
        while self._running:
            try:
                await self._check_stale_campaigns()
            except Exception as e:
                logger.error(f"Completion monitoring failed: {e}")

            await asyncio.sleep(self.completion_check_interval)

    async def _check_stale_campaigns(self):
        """Check all running campaigns for completion or orphaned work.

        For each campaign in state ``running``: clear a batch flag older than
        5 minutes (rescheduling if work remains), schedule a batch when work
        exists but none is in progress, and otherwise complete the campaign
        when ``_should_mark_complete`` agrees. A failure for one campaign is
        logged and does not stop the scan.
        """
        logger.debug("Checking for stale campaigns...")

        campaigns = await db_client.get_campaigns_by_status(statuses=["running"])

        for campaign in campaigns:
            try:
                campaign_id = campaign.id

                # Check if batch is stuck (initiated > 5 minutes ago but no completion)
                if campaign_id in self._batch_in_progress:
                    batch_start_time = self._batch_in_progress[campaign_id]
                    time_since_batch_start = (
                        datetime.now(UTC) - batch_start_time
                    ).total_seconds()

                    if time_since_batch_start > 300:  # 5 minutes
                        logger.warning(
                            f"campaign_id: {campaign_id} - Batch stuck for {time_since_batch_start:.0f}s, "
                            f"clearing flag and checking for more work"
                        )
                        del self._batch_in_progress[campaign_id]

                        # Check if there's work to be done
                        if await self._has_pending_work(campaign_id):
                            logger.info(
                                f"campaign_id: {campaign_id} - Found pending work after stuck batch, "
                                f"scheduling new batch"
                            )
                            await self._schedule_next_batch(campaign_id)
                            continue

                # Check for orphaned work (e.g., newly created retries with no batch in progress)
                if campaign_id not in self._batch_in_progress:
                    has_work = await self._has_pending_work(campaign_id)
                    if has_work:
                        logger.info(
                            f"campaign_id: {campaign_id} - Found orphaned work (likely new retries), "
                            f"scheduling batch to process"
                        )
                        await self._schedule_next_batch(campaign_id)
                        continue

                # Check if campaign should be marked complete
                if await self._should_mark_complete(campaign):
                    await self._complete_campaign(campaign)
            except Exception as e:
                logger.error(
                    f"campaign_id: {campaign.id} - Completion check failed: {e}"
                )

    async def _should_mark_complete(self, campaign: CampaignModel) -> bool:
        """Check if campaign has no activity for 1 hour.

        Returns False while a batch is in progress, while pending work
        exists, or while the most recent known activity is younger than
        ``completion_timeout`` seconds. The activity time is taken from the
        in-memory map, then ``last_activity_at``, ``last_batch_scheduled_at``
        and ``started_at``. If none is set the campaign is treated as idle
        and True is returned.
        """
        campaign_id = campaign.id

        # Don't mark complete if batch is in progress
        if campaign_id in self._batch_in_progress:
            logger.debug(
                f"campaign_id: {campaign_id} - Batch in progress, not marking complete"
            )
            return False

        # Check for any pending work
        has_work = await self._has_pending_work(campaign_id)
        if has_work:
            return False

        # Check in-memory last activity
        last_activity = self._last_activity.get(campaign_id)
        if not last_activity:
            # Fall back to database
            last_activity = campaign.last_activity_at

        if not last_activity:
            # No activity recorded, use last batch scheduled time
            last_activity = campaign.last_batch_scheduled_at

        if not last_activity:
            # No activity at all, use started_at
            last_activity = campaign.started_at

        if last_activity:
            # NOTE(review): assumes database timestamps are timezone-aware;
            # subtracting a naive datetime would raise TypeError - confirm.
            time_since = datetime.now(UTC) - last_activity
            if time_since.total_seconds() < self.completion_timeout:
                return False

        logger.info(
            f"campaign_id: {campaign_id} - No activity for {self.completion_timeout}s, "
            f"marking complete"
        )
        return True

    async def _has_pending_work(self, campaign_id: int) -> bool:
        """Check if campaign has any work to do.

        True when the campaign has runs in state ``queued`` or scheduled runs
        that are due now.
        """
        # Check queued runs
        queued_count = await db_client.get_queued_runs_count(
            campaign_id=campaign_id, states=["queued"]
        )

        if queued_count > 0:
            logger.debug(f"campaign_id: {campaign_id} - Has {queued_count} queued runs")
            return True

        # Check scheduled retries that are due
        scheduled_count = await db_client.get_scheduled_runs_count(
            campaign_id=campaign_id, scheduled_before=datetime.now(UTC)
        )

        if scheduled_count > 0:
            logger.debug(
                f"campaign_id: {campaign_id} - Has {scheduled_count} scheduled retries due"
            )
            return True

        return False

    async def _complete_campaign(self, campaign: CampaignModel):
        """Mark campaign as complete.

        Re-checks for pending work, sets the campaign state to ``completed``,
        publishes a ``CampaignCompletedEvent`` and drops the campaign's
        in-memory bookkeeping. Errors are logged, never raised.
        """
        campaign_id = campaign.id

        try:
            # Double-check no pending work
            if await self._has_pending_work(campaign_id):
                logger.info(
                    f"campaign_id: {campaign_id} - Found pending work, not completing"
                )
                return

            # Update campaign status
            await db_client.update_campaign(
                campaign_id=campaign_id,
                state="completed",
                completed_at=datetime.now(UTC),
            )

            logger.info(f"campaign_id: {campaign_id} - Campaign marked as completed")

            # Publish completion event using typed event
            completion_event = CampaignCompletedEvent(
                campaign_id=campaign_id,
                total_rows=campaign.total_rows or 0,
                processed_rows=campaign.processed_rows,
                failed_rows=campaign.failed_rows,
            )

            # Calculate duration if started_at is available
            if campaign.started_at:
                duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
                completion_event.duration_seconds = duration

            await self.redis.publish(
                RedisChannel.CAMPAIGN_EVENTS.value, completion_event.to_json()
            )

            # Clean up in-memory state
            self._last_activity.pop(campaign_id, None)
            self._processing_locks.pop(campaign_id, None)
            self._batch_in_progress.pop(campaign_id, None)

        except Exception as e:
            logger.error(
                f"campaign_id: {campaign_id} - Failed to complete campaign: {e}"
            )

    async def shutdown(self):
        """Clean shutdown of the orchestrator.

        Safe to call more than once: ``run`` calls it on exit and ``main``
        calls it again, so the pubsub handle is detached before it is closed
        and a repeated call finds nothing left to close.
        """
        logger.info("Campaign Orchestrator shutting down...")
        self._running = False

        pubsub, self._pubsub = self._pubsub, None
        if pubsub:
            try:
                await pubsub.unsubscribe(RedisChannel.CAMPAIGN_EVENTS.value)
                await pubsub.close()
            except Exception as e:
                logger.error(f"Error closing pubsub: {e}")

        logger.info("Campaign Orchestrator shutdown complete")
|
||||
|
||||
|
||||
async def main():
    """Main entry point for Campaign Orchestrator service.

    Connects to Redis, runs the orchestrator and waits until either it stops
    on its own or SIGTERM/SIGINT arrives. On a signal the orchestrator gets
    5 seconds to stop before it is cancelled. If the orchestrator stops by
    itself with an error, that error is re-raised so the process does not
    exit as if it had succeeded. The orchestrator and the Redis client are
    always closed on the way out.
    """

    # Setup Redis connection
    redis = await aioredis.from_url(REDIS_URL, decode_responses=True)

    # Create and run orchestrator
    orchestrator = CampaignOrchestrator(redis)

    # Create a shutdown event for clean coordination
    shutdown_event = asyncio.Event()

    # Setup signal handlers; inside a coroutine the running loop is the
    # loop to register them on.
    loop = asyncio.get_running_loop()

    def signal_handler(signum):
        logger.info(f"Received shutdown signal {signum}")
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        # s=sig binds the current value; a plain closure would see the last one.
        loop.add_signal_handler(sig, lambda s=sig: signal_handler(s))

    # Run orchestrator with shutdown monitoring
    orchestrator_task = asyncio.create_task(orchestrator.run())
    shutdown_task = asyncio.create_task(shutdown_event.wait())

    try:
        # Wait for either orchestrator to complete or shutdown signal
        done, _ = await asyncio.wait(
            [orchestrator_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED
        )

        # If shutdown was triggered, stop the orchestrator
        if shutdown_task in done:
            logger.info("Shutdown signal received, stopping orchestrator...")
            orchestrator._running = False
            # Wait for orchestrator to finish gracefully
            try:
                await asyncio.wait_for(orchestrator_task, timeout=5.0)
            except asyncio.TimeoutError:
                logger.warning("Orchestrator shutdown timeout, cancelling...")
                orchestrator_task.cancel()
                try:
                    await orchestrator_task
                except asyncio.CancelledError:
                    pass
        else:
            # The orchestrator ended without a signal. Awaiting the finished
            # task retrieves its result, re-raising any failure instead of
            # silently dropping it.
            await orchestrator_task

    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received")
    finally:
        # The signal waiter is no longer needed; cancelling a finished task
        # is a no-op.
        shutdown_task.cancel()
        # Ensure clean shutdown
        await orchestrator.shutdown()
        await redis.aclose()

    logger.info("Campaign Orchestrator service stopped")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
256
api/services/campaign/rate_limiter.py
Normal file
256
api/services/campaign/rate_limiter.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
|
||||
|
||||
class RateLimiter:
    """Sliding window rate limiter to enforce strict per-second limits and concurrent call limits.

    State is kept in Redis under three key families:

    * ``rate_limit:{organization_id}`` - sorted set of recent call timestamps
    * ``concurrent_calls:{organization_id}`` - sorted set of active slot ids
    * ``workflow_slot_mapping:{workflow_run_id}`` - hash linking a workflow
      run to the concurrent slot it holds
    """

    # Prunes entries that left the window, then records the call if the
    # window still has room. Runs as one Lua script so the check and the
    # insert cannot interleave with another caller.
    _ACQUIRE_TOKEN_SCRIPT = """
    local key = KEYS[1]
    local now = tonumber(ARGV[1])
    local window_start = tonumber(ARGV[2])
    local max_requests = tonumber(ARGV[3])
    local member = ARGV[4]

    -- Remove timestamps older than window
    redis.call('ZREMRANGEBYSCORE', key, 0, window_start)

    -- Count requests in current window
    local current_requests = redis.call('ZCARD', key)

    if current_requests < max_requests then
        -- Add current call (unique member, scored by timestamp)
        redis.call('ZADD', key, now, member)
        redis.call('EXPIRE', key, 2)  -- Expire after 2 seconds
        return 1
    else
        return 0
    end
    """

    # Prunes stale slots, then registers a new slot if the organization is
    # still below its concurrency limit.
    _ACQUIRE_SLOT_SCRIPT = """
    local key = KEYS[1]
    local now = tonumber(ARGV[1])
    local max_concurrent = tonumber(ARGV[2])
    local stale_cutoff = tonumber(ARGV[3])
    local slot_id = ARGV[4]

    -- Remove stale entries (older than 30 minutes)
    redis.call('ZREMRANGEBYSCORE', key, 0, stale_cutoff)

    -- Get current count
    local current_count = redis.call('ZCARD', key)

    if current_count < max_concurrent then
        -- Add new slot
        redis.call('ZADD', key, now, slot_id)
        redis.call('EXPIRE', key, 3600)  -- Expire after 1 hour
        return slot_id
    else
        return nil
    end
    """

    def __init__(self):
        # Connection is created lazily on first use (see _get_redis).
        self.redis_client: Optional[aioredis.Redis] = None
        self.stale_call_timeout = 1800  # 30 minutes in seconds

    async def _get_redis(self) -> aioredis.Redis:
        """Get or create Redis connection"""
        if self.redis_client is None:
            self.redis_client = await aioredis.from_url(
                REDIS_URL, decode_responses=True
            )
        return self.redis_client

    async def acquire_token(self, organization_id: int, rate_limit: int = 1) -> bool:
        """
        Enforces strict rate limit: max N calls per rolling second window.

        Args:
            organization_id: Organization whose window is checked.
            rate_limit: Maximum number of calls allowed per rolling second.

        Returns:
            True if the call is allowed, False if rate limited (or if Redis
            raised an error - the limiter denies on failure).
        """
        redis_client = await self._get_redis()

        key = f"rate_limit:{organization_id}"
        now = time.time()
        window_start = now - 1.0  # 1 second sliding window

        # The sorted-set member must be unique per call: if the timestamp
        # itself were the member, two calls carrying an identical timestamp
        # would overwrite each other and be counted once.
        member = f"{now}:{uuid.uuid4().hex}"

        try:
            result = await redis_client.eval(
                self._ACQUIRE_TOKEN_SCRIPT,
                1,
                key,
                now,
                window_start,
                rate_limit,
                member,
            )
            return bool(result)
        except Exception as e:
            logger.error(f"Rate limiter error: {e}")
            # On error, be conservative and deny
            return False

    async def get_next_available_slot(
        self, organization_id: int, rate_limit: int = 1
    ) -> float:
        """
        Returns seconds until next available slot.

        Useful for implementing retry with backoff.

        Args:
            organization_id: Organization whose window is inspected.
            rate_limit: Maximum number of calls allowed per rolling second.

        Returns:
            0.0 when fewer than ``rate_limit`` calls are inside the current
            window, otherwise the time until enough of the oldest calls
            leave the window. Returns 1.0 if Redis raised an error.
        """
        redis_client = await self._get_redis()

        key = f"rate_limit:{organization_id}"

        try:
            now = time.time()
            window_start = now - 1.0

            # Entries are only pruned by acquire_token, so filter out the
            # ones that have already left the window before counting.
            entries = await redis_client.zrange(key, 0, -1, withscores=True)
            live_scores = [score for _, score in entries if score > window_start]

            if not live_scores or len(live_scores) < rate_limit:
                return 0.0  # Can call immediately

            # A slot opens once the window holds fewer than `rate_limit`
            # entries, i.e. when this (ascending-order) entry expires.
            blocking_index = max(
                0, min(len(live_scores) - rate_limit, len(live_scores) - 1)
            )
            next_available = live_scores[blocking_index] + 1.0
            return max(0.0, next_available - now)
        except Exception as e:
            logger.error(f"Rate limiter get_next_available_slot error: {e}")
            return 1.0  # Default wait time on error

    async def try_acquire_concurrent_slot(
        self, organization_id: int, max_concurrent: int = 20
    ) -> Optional[str]:
        """
        Try to acquire a concurrent call slot.

        Slots older than ``stale_call_timeout`` are discarded before counting.

        Returns:
            A unique slot_id if successful, None if the limit is reached or
            Redis raised an error.
        """
        redis_client = await self._get_redis()

        concurrent_key = f"concurrent_calls:{organization_id}"
        now = time.time()
        stale_cutoff = now - self.stale_call_timeout

        # Generate unique slot ID (timestamp + random component)
        slot_id = f"{int(now * 1000)}_{uuid.uuid4().hex[:8]}"

        try:
            result = await redis_client.eval(
                self._ACQUIRE_SLOT_SCRIPT,
                1,
                concurrent_key,
                now,
                max_concurrent,
                stale_cutoff,
                slot_id,
            )
            return result
        except Exception as e:
            logger.error(f"Concurrent limiter error: {e}")
            return None

    async def release_concurrent_slot(self, organization_id: int, slot_id: str) -> bool:
        """
        Release a concurrent call slot.

        Returns:
            True if the slot was removed, False if ``slot_id`` is empty, the
            slot was not present, or Redis raised an error.
        """
        if not slot_id:
            return False

        redis_client = await self._get_redis()
        concurrent_key = f"concurrent_calls:{organization_id}"

        try:
            removed = await redis_client.zrem(concurrent_key, slot_id)
            if removed:
                logger.debug(
                    f"Released concurrent slot {slot_id} for org {organization_id}"
                )
            return bool(removed)
        except Exception as e:
            logger.error(f"Error releasing concurrent slot: {e}")
            return False

    async def get_concurrent_count(self, organization_id: int) -> int:
        """
        Get current number of active concurrent calls for an organization.

        Automatically cleans up stale entries. Returns 0 if Redis raised an
        error.
        """
        redis_client = await self._get_redis()
        concurrent_key = f"concurrent_calls:{organization_id}"

        try:
            # Clean up stale entries first
            stale_cutoff = time.time() - self.stale_call_timeout
            await redis_client.zremrangebyscore(concurrent_key, 0, stale_cutoff)

            # Get current count
            count = await redis_client.zcard(concurrent_key)
            return count
        except Exception as e:
            logger.error(f"Error getting concurrent count: {e}")
            return 0

    async def store_workflow_slot_mapping(
        self, workflow_run_id: int, organization_id: int, slot_id: str
    ) -> bool:
        """
        Store the mapping between workflow_run_id and its concurrent slot.

        Used for cleanup when calls complete.

        Returns:
            True once the mapping and its expiry are written, False if Redis
            raised an error.
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_slot_mapping:{workflow_run_id}"

        try:
            # Store as a hash with TTL
            await redis_client.hset(
                mapping_key, mapping={"org_id": organization_id, "slot_id": slot_id}
            )
            # Set expiry to match stale timeout
            await redis_client.expire(mapping_key, self.stale_call_timeout)
            return True
        except Exception as e:
            logger.error(f"Error storing workflow slot mapping: {e}")
            return False

    async def get_workflow_slot_mapping(
        self, workflow_run_id: int
    ) -> Optional[tuple[int, str]]:
        """
        Get the concurrent slot mapping for a workflow run.

        Returns:
            (organization_id, slot_id) tuple, or None if the mapping is
            missing, incomplete, or Redis raised an error.
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_slot_mapping:{workflow_run_id}"

        try:
            mapping = await redis_client.hgetall(mapping_key)
            if mapping and "org_id" in mapping and "slot_id" in mapping:
                return (int(mapping["org_id"]), mapping["slot_id"])
            return None
        except Exception as e:
            logger.error(f"Error getting workflow slot mapping: {e}")
            return None

    async def delete_workflow_slot_mapping(self, workflow_run_id: int) -> bool:
        """
        Delete the workflow slot mapping after releasing the slot.

        Returns:
            True if a key was deleted, False otherwise (including on error).
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_slot_mapping:{workflow_run_id}"

        try:
            deleted = await redis_client.delete(mapping_key)
            return bool(deleted)
        except Exception as e:
            logger.error(f"Error deleting workflow slot mapping: {e}")
            return False

    async def close(self):
        """Close Redis connection"""
        if self.redis_client:
            await self.redis_client.close()
            self.redis_client = None


# Global rate limiter instance
rate_limiter = RateLimiter()
|
||||
122
api/services/campaign/runner.py
Normal file
122
api/services/campaign/runner.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
|
||||
class CampaignRunnerService:
    """Orchestrates campaign execution"""

    # Final Twilio call statuses that are counted as a failed call.
    _FAILED_CALL_STATUSES = ("failed", "busy", "no-answer")

    async def _get_campaign_or_raise(self, campaign_id: int):
        """Load a campaign by id, raising ValueError when it does not exist."""
        campaign = await db_client.get_campaign_by_id(campaign_id)
        if not campaign:
            raise ValueError(f"Campaign {campaign_id} not found")
        return campaign

    async def start_campaign(self, campaign_id: int) -> None:
        """Entry point - updates state to 'syncing' and enqueues sync task

        Raises:
            ValueError: If the campaign does not exist or is not in the
                'created' state.
        """
        campaign = await self._get_campaign_or_raise(campaign_id)

        if campaign.state != "created":
            raise ValueError(
                f"Campaign must be in 'created' state to start, current state: {campaign.state}"
            )

        # Update campaign state to syncing
        await db_client.update_campaign(
            campaign_id=campaign_id,
            state="syncing",
            started_at=datetime.now(UTC),
            source_sync_status="in_progress",
        )

        # Enqueue the sync task
        await enqueue_job(FunctionNames.SYNC_CAMPAIGN_SOURCE, campaign_id)

        logger.info(f"Campaign {campaign_id} started, syncing source data")

    async def pause_campaign(self, campaign_id: int) -> None:
        """Pauses active campaign processing

        Raises:
            ValueError: If the campaign does not exist or is neither
                'running' nor 'syncing'.
        """
        campaign = await self._get_campaign_or_raise(campaign_id)

        if campaign.state not in ["running", "syncing"]:
            raise ValueError(
                f"Campaign must be in 'running' or 'syncing' state to pause, current state: {campaign.state}"
            )

        # Update state to paused
        await db_client.update_campaign(campaign_id=campaign_id, state="paused")

        logger.info(f"Campaign {campaign_id} paused")

    async def resume_campaign(self, campaign_id: int) -> None:
        """Resumes paused campaign

        Raises:
            ValueError: If the campaign does not exist or is not 'paused'.
        """
        campaign = await self._get_campaign_or_raise(campaign_id)

        if campaign.state != "paused":
            raise ValueError(
                f"Campaign must be in 'paused' state to resume, current state: {campaign.state}"
            )

        # Update state to running
        await db_client.update_campaign(campaign_id=campaign_id, state="running")

        # Enqueue process batch task to continue processing
        await enqueue_job(FunctionNames.PROCESS_CAMPAIGN_BATCH, campaign_id)

        logger.info(f"Campaign {campaign_id} resumed")

    async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:
        """Returns detailed campaign status

        Raises:
            ValueError: If the campaign does not exist.
        """
        campaign = await self._get_campaign_or_raise(campaign_id)

        # Count failed calls from workflow runs
        failed_calls = await self._count_failed_campaign_calls(campaign_id)

        # Both counters are treated as 0 when unset so the percentage
        # never divides by (or with) None.
        total_rows = campaign.total_rows or 0
        processed_rows = campaign.processed_rows or 0
        progress_percentage = (
            (processed_rows / total_rows * 100) if total_rows > 0 else 0
        )

        return {
            "campaign_id": campaign_id,
            "state": campaign.state,
            "total_rows": total_rows,
            "processed_rows": campaign.processed_rows,
            "failed_calls": failed_calls,
            "progress_percentage": progress_percentage,
            "source_sync": {
                "status": campaign.source_sync_status,
                "last_synced_at": campaign.source_last_synced_at,
                "error": campaign.source_sync_error,
            },
            "rate_limit": campaign.rate_limit_per_second,
            "started_at": campaign.started_at,
            "completed_at": campaign.completed_at,
        }

    async def _count_failed_campaign_calls(self, campaign_id: int) -> int:
        """Count failed calls by examining workflow_run Twilio callbacks"""
        # Get all workflow runs for this campaign
        workflow_runs = await db_client.get_workflow_runs_by_campaign(campaign_id)

        failed_count = 0
        for run in workflow_runs:
            # Tolerate runs without logs and callbacks without a status.
            callbacks = (run.logs or {}).get("twilio_status_callbacks") or []
            if not callbacks:
                continue

            # Only the last callback reflects the final status
            final_status = (callbacks[-1].get("status") or "").lower()
            if final_status in self._FAILED_CALL_STATUSES:
                failed_count += 1

        return failed_count


# Global instance
campaign_runner_service = CampaignRunnerService()
|
||||
49
api/services/campaign/source_sync.py
Normal file
49
api/services/campaign/source_sync.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CampaignSourceSyncService(ABC):
    """Base class for campaign data source synchronization"""

    @abstractmethod
    async def sync_source_data(self, campaign_id: int) -> int:
        """
        Pull records from the source and create queued_runs for them.

        Every record is given a unique source_uuid that depends on the
        source type.

        Returns: number of records synced
        """

    @abstractmethod
    async def validate_source_schema(self, source_config: Dict[str, Any]) -> bool:
        """Check that the source exposes the required fields."""

    async def get_source_credentials(
        self, organization_id: int, source_type: str
    ) -> Dict[str, Any]:
        """Gets OAuth tokens or API credentials via Nango.

        Placeholder: the Nango lookup is not implemented here yet, so an
        empty dict is always returned after logging the request.
        """
        message = f"Getting credentials for org {organization_id}, source {source_type}"
        logger.info(message)
        credentials: Dict[str, Any] = {}
        return credentials
|
||||
|
||||
|
||||
def get_sync_service(source_type: str) -> CampaignSourceSyncService:
    """Build the sync service registered for *source_type*.

    Raises:
        ValueError: If no sync service is registered for *source_type*.
    """
    # Imported at call time rather than at module import time.
    from .sources.google_sheets import GoogleSheetsSyncService

    # Add more as needed: "hubspot": HubSpotSyncService,
    registry = {"google-sheet": GoogleSheetsSyncService}

    sync_cls = registry.get(source_type)
    if sync_cls:
        return sync_cls()

    raise ValueError(f"Unknown source type: {source_type}")
|
||||
5
api/services/campaign/sources/__init__.py
Normal file
5
api/services/campaign/sources/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Campaign source sync services"""
|
||||
|
||||
from .google_sheets import GoogleSheetsSyncService
|
||||
|
||||
__all__ = ["GoogleSheetsSyncService"]
|
||||
180
api/services/campaign/sources/google_sheets.py
Normal file
180
api/services/campaign/sources/google_sheets.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.campaign.source_sync import CampaignSourceSyncService
|
||||
from api.services.integrations.nango import NangoService
|
||||
|
||||
|
||||
class GoogleSheetsSyncService(CampaignSourceSyncService):
    """Implementation for Google Sheets synchronization"""

    def __init__(self):
        self.nango_service = NangoService()
        self.sheets_api_base = "https://sheets.googleapis.com/v4/spreadsheets"

    async def sync_source_data(self, campaign_id: int) -> int:
        """
        Fetches data from Google Sheets and creates queued_runs.

        The first row of the first sheet is used as the header row; every
        following row with a non-empty ``phone_number`` column becomes one
        queued run.

        Returns:
            Number of queued runs created (0 when the sheet has no data rows).

        Raises:
            ValueError: If the campaign, an active Google Sheets integration,
                or any sheet in the spreadsheet cannot be found, or if the
                campaign's source URL is not a Google Sheets URL.
        """
        # Get campaign
        campaign = await db_client.get_campaign_by_id(campaign_id)
        if not campaign:
            raise ValueError(f"Campaign {campaign_id} not found")

        # 1. Get Google Sheets integration for the organization
        integrations = await db_client.get_integrations_by_organization_id(
            campaign.organization_id
        )
        integration = next(
            (
                intg
                for intg in integrations
                if intg.provider == "google-sheet" and intg.is_active
            ),
            None,
        )
        if not integration:
            raise ValueError("Google Sheets integration not found or inactive")

        # 2. Get OAuth token via Nango using the integration_id (which is the Nango connection ID)
        token_data = await self.nango_service.get_access_token(
            connection_id=integration.integration_id, provider_config_key="google-sheet"
        )
        access_token = token_data["credentials"]["access_token"]

        # 3. Extract sheet ID from URL
        sheet_id = self._extract_sheet_id(campaign.source_id)

        # 4. Get sheet metadata (to find data range)
        metadata = await self._get_sheet_metadata(sheet_id, access_token)
        if not metadata.get("sheets"):
            raise ValueError("No sheets found in the spreadsheet")

        sheet_name = metadata["sheets"][0]["properties"]["title"]

        # 5. Fetch all data from sheet (columns A-Z). The title is quoted so
        # names with spaces or punctuation form a valid A1 range.
        sheet_data = await self._fetch_sheet_data(
            sheet_id,
            f"{self._quote_sheet_name(sheet_name)}!A:Z",
            access_token,
        )

        # 6. Convert to queued_runs
        if not sheet_data or len(sheet_data) < 2:
            logger.warning(f"No data found in sheet for campaign {campaign_id}")
            return 0

        queued_runs = self._build_queued_runs(campaign_id, sheet_id, sheet_data)

        # 7. Bulk insert
        if queued_runs:
            await db_client.bulk_create_queued_runs(queued_runs)
            logger.info(
                f"Created {len(queued_runs)} queued runs for campaign {campaign_id}"
            )

        # 8. Update campaign total_rows
        await db_client.update_campaign(
            campaign_id=campaign_id,
            total_rows=len(queued_runs),
            source_sync_status="completed",
        )

        return len(queued_runs)

    def _build_queued_runs(
        self, campaign_id: int, sheet_id: str, sheet_data: List[List[str]]
    ) -> List[Dict[str, Any]]:
        """Turn header + data rows into queued_run dicts, skipping rows without a phone number."""
        headers = sheet_data[0]  # First row is headers
        rows = sheet_data[1:]  # Rest is data

        queued_runs = []
        for idx, row_values in enumerate(rows, 1):
            # Pad short rows so every header gets a value
            padded_row = row_values + [""] * (len(headers) - len(row_values))

            # Create context variables dict
            context_vars = dict(zip(headers, padded_row))

            # Skip if no phone number
            if not context_vars.get("phone_number"):
                logger.debug(f"Skipping row {idx}: no phone_number")
                continue

            queued_runs.append(
                {
                    "campaign_id": campaign_id,
                    # Unique per sheet and data-row position
                    "source_uuid": f"sheet_{sheet_id}_row_{idx}",
                    "context_variables": context_vars,
                    "state": "queued",
                }
            )
        return queued_runs

    @staticmethod
    def _quote_sheet_name(sheet_name: str) -> str:
        """Wrap a sheet title in single quotes for A1 notation, doubling embedded quotes."""
        escaped = sheet_name.replace("'", "''")
        return f"'{escaped}'"

    async def _fetch_sheet_data(
        self, sheet_id: str, range: str, access_token: str
    ) -> List[List[str]]:
        """Fetch data from Google Sheets API

        Args:
            sheet_id: Spreadsheet id.
            range: A1-notation range (parameter name kept for callers even
                though it shadows the builtin).
            access_token: OAuth bearer token.

        Returns:
            The ``values`` list from the response, or ``[]`` when absent.

        Raises:
            httpx.HTTPStatusError: On a non-success HTTP response.
        """
        from urllib.parse import quote

        # Percent-encode the range: it is a URL path segment and may contain
        # characters such as spaces, "'", "/", "#" or "?" from the sheet title.
        url = f"{self.sheets_api_base}/{sheet_id}/values/{quote(range, safe='')}"
        headers = {"Authorization": f"Bearer {access_token}"}

        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()

            data = response.json()
            return data.get("values", [])

    async def _get_sheet_metadata(
        self, sheet_id: str, access_token: str
    ) -> Dict[str, Any]:
        """Get sheet metadata including sheet names

        Errors are logged and re-raised unchanged.
        """
        url = f"{self.sheets_api_base}/{sheet_id}"
        headers = {"Authorization": f"Bearer {access_token}"}

        logger.debug(f"Fetching sheet metadata from URL: {url}")
        logger.debug(f"Using sheet_id: {sheet_id}")

        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(url, headers=headers)
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error {e.response.status_code} for URL: {url}")
                logger.error(f"Response body: {e.response.text}")
                raise
            except Exception as e:
                logger.error(f"Error fetching sheet metadata: {e}")
                raise

    def _extract_sheet_id(self, sheet_url: str) -> str:
        """
        Extract sheet ID from various Google Sheets URL formats:
        - https://docs.google.com/spreadsheets/d/{id}/edit
        - https://docs.google.com/spreadsheets/d/{id}/edit#gid=0

        Raises:
            ValueError: If the URL contains no ``/spreadsheets/d/{id}`` part.
        """
        pattern = r"/spreadsheets/d/([a-zA-Z0-9-_]+)"
        match = re.search(pattern, sheet_url)
        if match:
            return match.group(1)
        raise ValueError(f"Invalid Google Sheets URL: {sheet_url}")

    async def validate_source_schema(self, source_config: Dict[str, Any]) -> bool:
        """Validate that required columns exist

        Reads ``source_id`` and ``access_token`` from *source_config* and
        checks the header row (A1:Z1) for the required column names.
        """
        required_columns = ["phone_number", "first_name", "last_name"]

        # Fetch just the header row
        sheet_id = self._extract_sheet_id(source_config["source_id"])
        access_token = source_config["access_token"]

        headers = await self._fetch_sheet_data(
            sheet_id,
            "A1:Z1",  # Just first row
            access_token,
        )

        if not headers:
            return False

        header_row = headers[0]
        return all(col in header_row for col in required_columns)
|
||||
0
api/services/configuration/__init__.py
Normal file
0
api/services/configuration/__init__.py
Normal file
153
api/services/configuration/check_validity.py
Normal file
153
api/services/configuration/check_validity.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
from typing import Dict, Optional, TypedDict
|
||||
|
||||
import openai
|
||||
from deepgram import (
|
||||
DeepgramClient,
|
||||
LiveOptions,
|
||||
)
|
||||
from groq import Groq
|
||||
|
||||
# try:
|
||||
# from pyneuphonic import Neuphonic
|
||||
# except ImportError:
|
||||
# Neuphonic = None
|
||||
from api.schemas.user_configuration import (
|
||||
UserConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
|
||||
|
||||
# One validation result entry.
class APIKeyStatus(TypedDict):
    # Identifier the message refers to (a service name such as "llm", or "all")
    model: str
    # Human-readable outcome, e.g. "ok" or an error description
    message: str


# Envelope returned by the validator: a list of per-service results.
class APIKeyStatusResponse(TypedDict):
    status: list[APIKeyStatus]
|
||||
|
||||
|
||||
class UserConfigurationValidator:
    """Checks that the API keys in a UserConfiguration are accepted by their providers."""

    def __init__(self):
        # Probe results keyed by (provider, api_key). The key is part of the
        # cache key so a changed key for the same provider is re-checked
        # instead of inheriting the previous key's verdict.
        self._provider_api_key_validity_status: Dict[tuple[str, str], bool] = {}
        self._validator_map = {
            ServiceProviders.OPENAI.value: self._check_openai_api_key,
            ServiceProviders.DEEPGRAM.value: self._check_deepgram_api_key,
            ServiceProviders.GROQ.value: self._check_groq_api_key,
            ServiceProviders.ELEVENLABS.value: self._validate_elevenlabs_api_key,
            ServiceProviders.GOOGLE.value: self._check_google_api_key,
            ServiceProviders.AZURE.value: self._check_azure_api_key,
            ServiceProviders.CARTESIA.value: self._check_cartesia_api_key,
            ServiceProviders.DOGRAH.value: self._check_dograh_api_key,
        }

    async def validate(self, configuration: UserConfiguration) -> APIKeyStatusResponse:
        """Validate the llm, stt and tts blocks of *configuration*.

        Returns:
            A single ``{"model": "all", "message": "ok"}`` status when every
            service passes.

        Raises:
            ValueError: With the list of error statuses as its argument when
                at least one service is missing or has an invalid key.
        """
        status_list = []

        status_list.extend(self._validate_service(configuration.llm, "llm"))
        status_list.extend(self._validate_service(configuration.stt, "stt"))
        status_list.extend(self._validate_service(configuration.tts, "tts"))

        if status_list:
            raise ValueError(status_list)

        return {"status": [{"model": "all", "message": "ok"}]}

    def _validate_service(
        self, service_config: Optional[ServiceConfig], service_name: str
    ) -> list[APIKeyStatus]:
        """Validate a service configuration and return any error statuses."""
        if not service_config:
            return [{"model": service_name, "message": "API key is missing"}]

        provider = service_config.provider
        api_key = service_config.api_key

        if not self._check_api_key(provider, api_key):
            # Use the enum's value when provider is an enum member, so the
            # message reads "openai" rather than "ServiceProviders.OPENAI".
            provider_name = getattr(provider, "value", provider)
            return [
                {"model": service_name, "message": f"Invalid {provider_name} API key"}
            ]

        return []

    def _check_api_key(self, provider: str, api_key: str) -> bool:
        """Check if an API key for a provider is valid.

        Providers without a registered validator are reported as invalid.
        """
        validator = self._validator_map.get(provider)
        if not validator:
            return False

        return validator(provider, api_key)

    def _check_openai_api_key(self, model: str, api_key: str) -> bool:
        """Probe OpenAI by listing models; an AuthenticationError means invalid."""
        cache_key = (model, api_key)
        if cache_key in self._provider_api_key_validity_status:
            return self._provider_api_key_validity_status[cache_key]

        client = openai.OpenAI(api_key=api_key)
        try:
            client.models.list()
            self._provider_api_key_validity_status[cache_key] = True
        except openai.AuthenticationError:
            self._provider_api_key_validity_status[cache_key] = False
        return self._provider_api_key_validity_status[cache_key]

    def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
        """Probe Deepgram by starting a live websocket connection and recording its result."""
        cache_key = (model, api_key)
        if cache_key in self._provider_api_key_validity_status:
            return self._provider_api_key_validity_status[cache_key]

        deepgram = DeepgramClient(api_key)
        dg_connection = deepgram.listen.websocket.v("1")

        try:
            options = LiveOptions(
                model="nova-2",
                language="en-US",
                smart_format=True,
            )

            connected = dg_connection.start(options)
            self._provider_api_key_validity_status[cache_key] = connected
        finally:
            # Always tear the probe connection down, even if start() raised
            dg_connection.finish()
        return self._provider_api_key_validity_status[cache_key]

    def _check_groq_api_key(self, model: str, api_key: str) -> bool:
        """Probe Groq by listing models; any exception is treated as invalid."""
        cache_key = (model, api_key)
        if cache_key in self._provider_api_key_validity_status:
            return self._provider_api_key_validity_status[cache_key]

        client = Groq(api_key=api_key)
        try:
            client.models.list()
            self._provider_api_key_validity_status[cache_key] = True
        except Exception:
            self._provider_api_key_validity_status[cache_key] = False
        return self._provider_api_key_validity_status[cache_key]

    # The validators below perform no remote check and accept every key.

    def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
        return True

    def _check_google_api_key(self, model: str, api_key: str) -> bool:
        return True

    def _check_azure_api_key(self, model: str, api_key: str) -> bool:
        return True

    def _check_cartesia_api_key(self, model: str, api_key: str) -> bool:
        return True

    def _check_dograh_api_key(self, model: str, api_key: str) -> bool:
        return True
|
||||
|
||||
# def _check_neuphonic_api_key(self, model: str, api_key: str) -> bool:
|
||||
# if not Neuphonic:
|
||||
# self._provider_api_key_validity_status[model] = False
|
||||
# return self._provider_api_key_validity_status[model]
|
||||
|
||||
# if model in self._provider_api_key_validity_status:
|
||||
# return self._provider_api_key_validity_status[model]
|
||||
|
||||
# client = Neuphonic(api_key=api_key)
|
||||
# try:
|
||||
# response = client.voices.list() # get's all available voices
|
||||
# voices = response.data["voices"]
|
||||
# self._provider_api_key_validity_status[model] = True
|
||||
# except Exception:
|
||||
# self._provider_api_key_validity_status[model] = False
|
||||
|
||||
# return self._provider_api_key_validity_status[model]
|
||||
34
api/services/configuration/defaults.py
Normal file
34
api/services/configuration/defaults.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from __future__ import annotations
|
||||
|
||||
"""Default service-provider choices for a new user.

The defaults follow the same provider choices exposed by `/user/configurations/defaults`.

This module only maps each service ("llm", "tts", "stt") to its default provider
and exports that mapping as ``DEFAULT_SERVICE_PROVIDERS``.

NOTE(review): an earlier version of this text said that `api_key` values are
pulled from environment variables named *{PROVIDER}_API_KEY* and that a provider
configuration is left as ``None`` when its variable is missing. No code in this
module reads environment variables; confirm where (or whether) that happens.
"""
|
||||
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAILLMService,
|
||||
ServiceProviders,
|
||||
)
|
||||
|
||||
# Mapping of service to (provider enum, configuration class)
|
||||
# Mapping of service to (provider enum, configuration class)
_DEFAULTS = dict(
    llm=(ServiceProviders.OPENAI, OpenAILLMService),
    tts=(ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
    stt=(ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
)

# Public mapping of service name -> default provider
# (first element of each _DEFAULTS pair)
DEFAULT_SERVICE_PROVIDERS = {
    service_name: default_pair[0] for service_name, default_pair in _DEFAULTS.items()
}

__all__ = ["DEFAULT_SERVICE_PROVIDERS"]
|
||||
69
api/services/configuration/masking.py
Normal file
69
api/services/configuration/masking.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from __future__ import annotations
|
||||
|
||||
"""Utilities for masking API keys before they are sent to the client.
|
||||
|
||||
The rules are simple:
|
||||
1. Only expose the last *visible* characters (default 4) of a key.
|
||||
2. Incoming masked keys are considered a placeholder – if they equal the mask of
|
||||
the already-stored key, we treat them as *unchanged* and keep the real value
|
||||
in storage.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
|
||||
VISIBLE_CHARS = 4  # number of trailing characters to reveal
MASK_CHAR = "*"


def mask_key(real_key: Optional[str], visible: int = VISIBLE_CHARS) -> str:
    """Return a masked representation of *real_key*.

    Every character except the last *visible* ones is replaced by
    ``MASK_CHAR``; the result has the same length as the input.

    Args:
        real_key: The secret to mask. ``None`` yields an empty string.
        visible: Number of trailing characters left readable. When it is
            not positive, or the key is not longer than it, the whole key
            is masked so short keys are never revealed.

    Example:
        >>> mask_key("sk-1234567890abcdef")
        '***************cdef'
        >>> mask_key("abcd")
        '****'
    """
    if real_key is None:
        return ""

    if visible <= 0 or visible >= len(real_key):
        # mask entire key or nothing to mask – edge-cases
        return MASK_CHAR * len(real_key)

    masked_part = MASK_CHAR * (len(real_key) - visible)
    return f"{masked_part}{real_key[-visible:]}"
|
||||
|
||||
|
||||
def is_mask_of(masked: str, real_key: str) -> bool:
    """Tell whether *masked* is exactly what masking *real_key* would produce.

    Uses the default masking rules of ``mask_key``.
    """
    expected_mask = mask_key(real_key)
    return expected_mask == masked
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level helpers for UserConfiguration objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, Any]]:
    """Dump *service_cfg* to a dict with its api_key masked.

    Returns ``None`` when no configuration is given. An absent or empty
    api_key is left as it is.
    """
    if service_cfg is None:
        return None

    # model_dump() gives a fresh dict, so the original model is not mutated
    dumped = service_cfg.model_dump()
    raw_key = dumped.get("api_key")
    if raw_key:
        dumped["api_key"] = mask_key(raw_key)
    return dumped
|
||||
|
||||
|
||||
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
    """Build a JSON-serialisable dict from *config* in which each service api_key is masked."""
    # Service blocks first, in the same order as the plain fields follow them
    masked: Dict[str, Any] = {
        service: _mask_service(getattr(config, service))
        for service in ("llm", "tts", "stt")
    }
    # Non-secret fields are passed through unchanged
    masked["test_phone_number"] = config.test_phone_number
    masked["timezone"] = config.timezone
    return masked
|
||||
75
api/services/configuration/merge.py
Normal file
75
api/services/configuration/merge.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
from __future__ import annotations
|
||||
|
||||
"""Helpers for merging incoming user-configuration updates with what is already
|
||||
stored, while honouring masked API keys.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt")
|
||||
|
||||
|
||||
def merge_user_configurations(
    existing: UserConfiguration, incoming_partial: Dict[str, dict]
) -> UserConfiguration:
    """Merge *incoming_partial* onto *existing* and return a new UserConfiguration.

    *incoming_partial* is the body of the PUT request as a plain dict
    (e.g. obtained via Pydantic ``model_dump``). It is not modified.

    Rules:
    1. If a service block is absent in the request, keep the existing one.
    2. If provider unchanged and the api_key field is either missing or equal to
       the masked placeholder, preserve the existing real key.
    3. If provider changes, the incoming api_key is used verbatim (validation
       will fail later if it is missing).
    4. Non-service top-level fields (e.g. `test_phone_number`) are overwritten
       when supplied.

    Raises:
        pydantic.ValidationError: If the merged data is not a valid
            UserConfiguration (raised by ``model_validate``).
    """
    merged = existing.model_dump(exclude_none=True)

    for service_name in SERVICE_FIELDS:
        incoming_cfg = incoming_partial.get(service_name)
        if incoming_cfg is None:
            continue  # rule 1: nothing supplied, keep what is stored

        # Work on a copy so the caller's request payload is never mutated
        # (the real key must not leak back into the incoming dict).
        new_cfg = dict(incoming_cfg)
        old_cfg = merged.get(service_name, {})

        old_provider = old_cfg.get("provider")
        new_provider = new_cfg.get("provider")
        provider_changed = (
            old_provider is not None
            and new_provider is not None
            and new_provider != old_provider
        )

        # rule 2: same provider and the client sent no key, or sent back
        # the masked placeholder -> carry over the stored real key
        if not provider_changed and "api_key" in old_cfg:
            incoming_api_key = new_cfg.get("api_key")
            if incoming_api_key is None or is_mask_of(
                incoming_api_key, old_cfg["api_key"]
            ):
                new_cfg["api_key"] = old_cfg["api_key"]

        merged[service_name] = new_cfg

    # other simple fields (rule 4)
    if "test_phone_number" in incoming_partial:
        merged["test_phone_number"] = incoming_partial["test_phone_number"]

    if "timezone" in incoming_partial:
        merged["timezone"] = incoming_partial["timezone"]

    return UserConfiguration.model_validate(merged)
|
||||
356
api/services/configuration/registry.py
Normal file
356
api/services/configuration/registry.py
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
from enum import Enum, auto
|
||||
from typing import Annotated, Dict, Literal, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
|
||||
# Kinds of pluggable service. Members are the top-level keys of REGISTRY;
# their numeric values come from auto() and carry no meaning of their own.
class ServiceType(Enum):
    LLM = auto()
    TTS = auto()
    STT = auto()
|
||||
|
||||
|
||||
# Vendor identifiers used as the ``provider`` value of every service
# configuration. The enum also subclasses ``str``, so each member compares
# equal to its lowercase string value.
class ServiceProviders(str, Enum):
    OPENAI = "openai"
    DEEPGRAM = "deepgram"
    GROQ = "groq"
    CARTESIA = "cartesia"
    # NEUPHONIC = "neuphonic"  (disabled, like the commented-out Neuphonic TTS service)
    ELEVENLABS = "elevenlabs"
    GOOGLE = "google"
    AZURE = "azure"
    DOGRAH = "dograh"
|
||||
|
||||
|
||||
# Fields shared by every service configuration: the vendor to talk to and the
# credential for it. Concrete configurations narrow ``provider`` to a single
# member (with a default); the discriminated unions below key on that value.
class BaseServiceConfiguration(BaseModel):
    # Every active member of ServiceProviders is listed so that each concrete
    # subclass's narrowed ``provider`` is a subset of this declaration.
    # CARTESIA was previously missing even though CartesiaSTTConfiguration
    # narrows ``provider`` to it.
    provider: Literal[
        ServiceProviders.OPENAI,
        ServiceProviders.DEEPGRAM,
        ServiceProviders.GROQ,
        ServiceProviders.CARTESIA,
        ServiceProviders.ELEVENLABS,
        ServiceProviders.GOOGLE,
        ServiceProviders.AZURE,
        ServiceProviders.DOGRAH,
    ]
    # Credential for the selected provider; required, no default.
    api_key: str
|
||||
|
||||
|
||||
# Base for LLM configurations: adds a required ``model`` name to the shared
# provider/api_key fields. Concrete services re-declare ``model`` with a
# provider-specific enum and default.
class BaseLLMConfiguration(BaseServiceConfiguration):
    model: str


# Base for TTS configurations; same shape as the LLM base. Note that the
# Deepgram and ElevenLabs TTS configurations in this file derive from
# BaseServiceConfiguration directly rather than from this class.
class BaseTTSConfiguration(BaseServiceConfiguration):
    model: str


# Base for STT configurations; same shape as the LLM base.
class BaseSTTConfiguration(BaseServiceConfiguration):
    model: str
|
||||
|
||||
|
||||
# Unified registry for all service types: for each ServiceType, maps a
# provider key to the configuration class registered for it. The inner
# dictionaries start empty and are filled by the ``register_*`` decorators
# when the decorated classes in this module are defined.
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
    ServiceType.LLM: {},
    ServiceType.TTS: {},
    ServiceType.STT: {},
}

# Type variable so that ``register_service`` can hand back the decorated
# class with its own (sub)type rather than the base type.
T = TypeVar("T", bound=BaseServiceConfiguration)
|
||||
|
||||
|
||||
def register_service(service_type: ServiceType):
    """Build a class decorator that records a configuration class in ``REGISTRY``.

    The decorated class is stored under ``REGISTRY[service_type][provider]``,
    where ``provider`` is taken from a plain class attribute if one exists and
    otherwise from the default declared on the class's pydantic ``provider``
    field. The class itself is returned unchanged.

    Args:
        service_type: Registry section (LLM, TTS or STT) to register under.

    Returns:
        A decorator accepting and returning the configuration class.

    Raises:
        ValueError: (from the decorator) when the class declares no provider
            default to register it under.
    """

    def decorator(cls: Type[T]) -> Type[T]:
        provider = getattr(cls, "provider", None)
        if provider is None:
            field = cls.model_fields.get("provider")
            # A required field carries pydantic's "undefined" sentinel as its
            # default, not None. Reading ``.default`` blindly would let that
            # sentinel through the check below and register the class under a
            # meaningless key, so only fields that really have a default
            # (or default factory) are consulted.
            if field is not None and not field.is_required():
                provider = field.get_default(call_default_factory=True)
        if provider is None:
            raise ValueError(f"Provider not specified for {cls.__name__}")

        REGISTRY[service_type][provider] = cls
        return cls

    return decorator
|
||||
|
||||
|
||||
# Convenience decorators
|
||||
def register_llm(cls: Type[BaseLLMConfiguration]):
    """Class decorator: register ``cls`` in the LLM section of the registry.

    Returns whatever ``register_service(ServiceType.LLM)`` returns for ``cls``.
    """
    llm_decorator = register_service(ServiceType.LLM)
    return llm_decorator(cls)
|
||||
|
||||
|
||||
def register_tts(cls: Type[BaseTTSConfiguration]):
    """Class decorator: register ``cls`` in the TTS section of the registry.

    Returns whatever ``register_service(ServiceType.TTS)`` returns for ``cls``.
    """
    tts_decorator = register_service(ServiceType.TTS)
    return tts_decorator(cls)
|
||||
|
||||
|
||||
def register_stt(cls: Type[BaseSTTConfiguration]):
    """Class decorator: register ``cls`` in the STT section of the registry.

    Returns whatever ``register_service(ServiceType.STT)`` returns for ``cls``.
    """
    stt_decorator = register_service(ServiceType.STT)
    return stt_decorator(cls)
|
||||
|
||||
|
||||
###################################################### LLM ########################################################################
|
||||
|
||||
|
||||
# Model names selectable for the OpenAI LLM service.
class OpenAIModel(str, Enum):
    GPT3_5_TURBO = "gpt-3.5-turbo"
    GPT4_1 = "gpt-4.1"
    GPT4_1_MINI = "gpt-4.1-mini"
    GPT4_1_NANO = "gpt-4.1-nano"
    GPT5 = "gpt-5"
    GPT5_MINI = "gpt-5-mini"
    GPT5_NANO = "gpt-5-nano"


# OpenAI LLM configuration, registered under ServiceType.LLM.
# ``provider`` is pinned to OPENAI (the discriminator value for LLMConfig),
# ``model`` defaults to gpt-4.1, and ``api_key`` must be supplied.
@register_llm
class OpenAILLMService(BaseLLMConfiguration):
    provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
    model: OpenAIModel = OpenAIModel.GPT4_1
    api_key: str
|
||||
|
||||
|
||||
# Model names selectable for the Google LLM service.
class GoogleModel(str, Enum):
    GEMINI_2_0_FLASH = "gemini-2.0-flash"
    GEMINI_2_0_FLASH_LITE = "gemini-2.0-flash-lite"
    GEMINI_2_5_FLASH = "gemini-2.5-flash"
    GEMINI_2_5_FLASH_LITE = "gemini-2.5-flash-lite"


# Google LLM configuration, registered under ServiceType.LLM.
# ``provider`` is pinned to GOOGLE, ``model`` defaults to gemini-2.0-flash,
# and ``api_key`` must be supplied.
@register_llm
class GoogleLLMService(BaseLLMConfiguration):
    provider: Literal[ServiceProviders.GOOGLE] = ServiceProviders.GOOGLE
    model: GoogleModel = GoogleModel.GEMINI_2_0_FLASH
    api_key: str
|
||||
|
||||
|
||||
# Model names selectable for the Groq LLM service.
# NOTE(review): the member name QUEN_QWQ_32B looks like a typo for "QWEN";
# it is left untouched because other modules may reference it by name.
class GroqModel(str, Enum):
    LLAMA_3_3_70B = "llama-3.3-70b-versatile"
    DEEPSEEK_R1_DISTILL_LLAMA_70B = "deepseek-r1-distill-llama-70b"
    QUEN_QWQ_32B = "qwen-qwq-32b"
    LLAMA_4_SCOUT_17B_16E_INSTRUCT = "meta-llama/llama-4-scout-17b-16e-instruct"
    LLAMA_4_MAVERICK_17B_128E_INSTRUCT = "meta-llama/llama-4-maverick-17b-128e-instruct"
    GEMMA2_9B_IT = "gemma2-9b-it"
    LLAMA_3_1_8B_INSTANT = "llama-3.1-8b-instant"
    OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"


# Groq LLM configuration, registered under ServiceType.LLM.
# ``provider`` is pinned to GROQ, ``model`` defaults to
# llama-3.3-70b-versatile, and ``api_key`` must be supplied.
@register_llm
class GroqLLMService(BaseLLMConfiguration):
    provider: Literal[ServiceProviders.GROQ] = ServiceProviders.GROQ
    model: GroqModel = GroqModel.LLAMA_3_3_70B
    api_key: str
|
||||
|
||||
|
||||
# Model names selectable for the Azure LLM service (currently a single entry).
class AzureModel(str, Enum):
    GPT4_1_MINI = "gpt-4.1-mini"


# Azure LLM configuration, registered under ServiceType.LLM.
# Besides the usual provider/model/api_key fields it requires ``endpoint``,
# which has no default.
# NOTE(review): presumably ``endpoint`` is the Azure resource URL — confirm
# against the code that builds the client.
@register_llm
class AzureLLMService(BaseLLMConfiguration):
    provider: Literal[ServiceProviders.AZURE] = ServiceProviders.AZURE
    model: AzureModel = AzureModel.GPT4_1_MINI
    api_key: str
    endpoint: str
|
||||
|
||||
|
||||
# Dograh LLM Service
# The only selectable model is the literal string "default".
class DograhLLMModel(str, Enum):
    DEFAULT = "default"


# Dograh LLM configuration, registered under ServiceType.LLM.
# ``provider`` is pinned to DOGRAH, ``model`` defaults to DEFAULT, and
# ``api_key`` must be supplied.
@register_llm
class DograhLLMService(BaseLLMConfiguration):
    provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
    model: DograhLLMModel = DograhLLMModel.DEFAULT
    api_key: str
|
||||
|
||||
|
||||
# Discriminated union of every LLM configuration: the ``provider`` value of
# the incoming data selects which of the listed classes validates it.
LLMConfig = Annotated[
    Union[
        OpenAILLMService,
        GroqLLMService,
        GoogleLLMService,
        AzureLLMService,
        DograhLLMService,
    ],
    Field(discriminator="provider"),
]
|
||||
|
||||
###################################################### TTS ########################################################################
|
||||
|
||||
|
||||
# Voices selectable for Deepgram TTS. The value embeds the model family
# ("aura-2"), which DeepgramTTSConfiguration.model relies on.
class DeepgramVoice(str, Enum):
    HELENA = "aura-2-helena-en"
    THALIA = "aura-2-thalia-en"
|
||||
|
||||
|
||||
# Deepgram TTS configuration, registered under ServiceType.TTS.
# It derives from BaseServiceConfiguration (not BaseTTSConfiguration) because
# ``model`` is not a stored field here: it is computed from ``voice``.
@register_tts
class DeepgramTTSConfiguration(BaseServiceConfiguration):
    provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
    voice: DeepgramVoice = DeepgramVoice.HELENA
    api_key: str

    @computed_field
    @property
    def model(self) -> str:
        # The model family is read off the voice name. "aura-2" is tested
        # before "aura-1", and is also the fallback when neither appears.
        for family in ("aura-2", "aura-1"):
            if family in self.voice:
                return family
        return "aura-2"
|
||||
|
||||
|
||||
# Voices selectable for ElevenLabs TTS. Each value has the form
# "<display name> - <identifier>".
# NOTE(review): the identifier part is presumably the ElevenLabs voice id that
# the consumer splits off — confirm where the value is parsed.
class ElevenlabsVoice(str, Enum):
    ALEXANDRA = "Alexandra - 3dzJXoCYueSQiptQ6euE"
    AMY = "Amy - oGn4Ha2pe2vSJkmIJgLQ"
    ANGELA = "Angela - FUfBrNit0NNZAwb58KWH"
    ARIA = "Aria - 9BWtsMINqrJLrRacOk9x"
    CHELSEA = "Chelsea - NHRgOEwqx5WZNClv5sat"
    CHRISTINA = "Christina - X03mvPuTfprif8QBAVeJ"
    CLARA = "Clara - ZIlrSGI4jZqobxRKprJz"
    CLYDE = "Clyde - 2EiwWnXFnvU5JabPnv8n"
    DAVE = "Dave - CYw3kZ02Hs0563khs1Fj"
    DOMI = "Domi - AZnzlk1XvdvUeBnXmlld"
    DREW = "Drew - 29vD33N1CtxCmqQRPOHJ"
    EVE = "Eve - BZgkqPqms7Kj9ulSkVzn"
    FIN = "Fin - D38z5RcWu1voky8WS1ja"
    HOPE_BESTIE = "Hope_Bestie - uYXf8XasLslADfZ2MB4u"
    HOPE_NATURAL = "Hope_Natural - OYTbf65OHHFELVut7v2H"
    JARNATHAN = "Jarnathan - c6SfcYrb2t09NHXiT80T"
    JENNA = "Jenna - C2BkQxlGNzBn7WD2bqfR"
    JESSICA = "Jessica - cgSgspJ2msm6clMCkdW9"
    JUNIPER = "Juniper - aMSt68OGf4xUZAnLpTU8"
    LAUREN = "Lauren - 3liN8q8YoeB9Hk6AboKe"
    LINA = "Lina - oWjuL7HSoaEJRMDMP3HD"
    OLIVIA = "Olivia - 1rviaVF7GGGkTU36HNpz"
    PAUL = "Paul - 5Q0t7uMcjvnagumLfvZi"
    RACHEL = "Rachel - 21m00Tcm4TlvDq8ikWAM"
    ROGER = "Roger - CwhRBWXzGAHq8TQ4Fs17"
    SAMI_REAL = "Sami_Real - O4cGUVdAocn0z4EpQ9yF"
    SARAH = "Sarah - EXAVITQu4vr4xnSDxMaL"


# Model names selectable for ElevenLabs TTS (currently a single entry).
class ElevenlabsModel(str, Enum):
    FLASH_2 = "eleven_flash_v2_5"


# ElevenLabs TTS configuration, registered under ServiceType.TTS.
# Derives from BaseServiceConfiguration directly and declares ``model`` itself.
@register_tts
class ElevenlabsTTSConfiguration(BaseServiceConfiguration):
    provider: Literal[ServiceProviders.ELEVENLABS] = ServiceProviders.ELEVENLABS
    voice: ElevenlabsVoice = ElevenlabsVoice.RACHEL
    # Validation rejects values outside the inclusive range 0.1 .. 2.0.
    speed: float = Field(default=1.0, ge=0.1, le=2.0, description="Speed of the voice")
    model: ElevenlabsModel = ElevenlabsModel.FLASH_2
    api_key: str
|
||||
|
||||
|
||||
# Voices selectable for OpenAI TTS (currently a single entry).
# NOTE(review): the member is named ALLY while its value is "alloy"; the name
# is kept as-is because other modules may reference it.
class OpenAIVoice(str, Enum):
    ALLY = "alloy"


# Model names selectable for OpenAI TTS (currently a single entry).
class OpenAITTSModel(str, Enum):
    GPT_4o_MINI = "gpt-4o-mini-tts"


# OpenAI TTS configuration, registered under ServiceType.TTS.
# ``provider`` is pinned to OPENAI; ``model`` and ``voice`` default to the
# single available choices; ``api_key`` must be supplied.
@register_tts
class OpenAITTSService(BaseTTSConfiguration):
    provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
    model: OpenAITTSModel = OpenAITTSModel.GPT_4o_MINI
    voice: OpenAIVoice = OpenAIVoice.ALLY
    api_key: str
|
||||
|
||||
|
||||
# class NeuphonicVoice(str, Enum):
|
||||
# EMILY = "Emily - fc854436-2dac-4d21-aa69-ae17b54e98eb"
|
||||
|
||||
|
||||
# @register_tts
|
||||
# class NeuphonicTTSService(BaseTTSConfiguration):
|
||||
# provider: Literal[ServiceProviders.NEUPHONIC] = ServiceProviders.NEUPHONIC
|
||||
# voice: NeuphonicVoice = NeuphonicVoice.EMILY
|
||||
# model: str = "NA"
|
||||
# api_key: str
|
||||
|
||||
|
||||
# Dograh TTS Service
# The only selectable voice is the literal string "default".
class DograhVoice(str, Enum):
    DEFAULT = "default"


# The only selectable TTS model is the literal string "default".
class DograhTTSModel(str, Enum):
    DEFAULT = "default"


# Dograh TTS configuration, registered under ServiceType.TTS.
# ``provider`` is pinned to DOGRAH; ``model`` and ``voice`` default to
# DEFAULT; ``api_key`` must be supplied.
@register_tts
class DograhTTSService(BaseTTSConfiguration):
    provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
    model: DograhTTSModel = DograhTTSModel.DEFAULT
    voice: DograhVoice = DograhVoice.DEFAULT
    api_key: str
|
||||
|
||||
|
||||
# Discriminated union of every TTS configuration: the ``provider`` value of
# the incoming data selects which of the listed classes validates it.
TTSConfig = Annotated[
    Union[
        DeepgramTTSConfiguration,
        OpenAITTSService,
        ElevenlabsTTSConfiguration,
        DograhTTSService,
    ],
    Field(discriminator="provider"),
]
|
||||
|
||||
###################################################### STT ########################################################################
|
||||
|
||||
|
||||
# Model names selectable for Deepgram STT (currently a single entry).
class DeepgramSTTModel(str, Enum):
    NOVA_3_GENERAL = "nova-3-general"


# Deepgram STT configuration, registered under ServiceType.STT.
# ``provider`` is pinned to DEEPGRAM, ``model`` defaults to nova-3-general,
# and ``api_key`` must be supplied.
@register_stt
class DeepgramSTTConfiguration(BaseSTTConfiguration):
    provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
    model: DeepgramSTTModel = DeepgramSTTModel.NOVA_3_GENERAL
    api_key: str
|
||||
|
||||
|
||||
# Cartesia STT configuration, registered under ServiceType.STT.
# Unlike the other STT services it declares no model enum, so the inherited
# ``model: str`` stays a required field without a default.
# NOTE(review): this class is registered in REGISTRY but is not listed in the
# STTConfig union in this file — confirm whether Cartesia STT is meant to be
# selectable through validated user configuration.
@register_stt
class CartesiaSTTConfiguration(BaseSTTConfiguration):
    provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
    api_key: str
|
||||
|
||||
|
||||
# Model names selectable for OpenAI STT (currently a single entry).
class OpenAISTTModel(str, Enum):
    GPT_4o_TRANSCRIBE = "gpt-4o-transcribe"


# OpenAI STT configuration, registered under ServiceType.STT.
# ``provider`` is pinned to OPENAI, ``model`` defaults to gpt-4o-transcribe,
# and ``api_key`` must be supplied.
@register_stt
class OpenAISTTConfiguration(BaseSTTConfiguration):
    provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
    model: OpenAISTTModel = OpenAISTTModel.GPT_4o_TRANSCRIBE
    api_key: str
|
||||
|
||||
|
||||
# Dograh STT Service
# The only selectable STT model is the literal string "default".
class DograhSTTModel(str, Enum):
    DEFAULT = "default"


# Dograh STT configuration, registered under ServiceType.STT.
# ``provider`` is pinned to DOGRAH, ``model`` defaults to DEFAULT, and
# ``api_key`` must be supplied.
@register_stt
class DograhSTTService(BaseSTTConfiguration):
    provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
    model: DograhSTTModel = DograhSTTModel.DEFAULT
    api_key: str
|
||||
|
||||
|
||||
# Discriminated union of the STT configurations: the ``provider`` value of
# the incoming data selects which of the listed classes validates it.
# CartesiaSTTConfiguration is not part of this union.
STTConfig = Annotated[
    Union[DeepgramSTTConfiguration, OpenAISTTConfiguration, DograhSTTService],
    Field(discriminator="provider"),
]

# Union over all three service families, again discriminated on ``provider``.
# NOTE(review): the nested unions reuse provider values (for example "openai"
# and "dograh" occur in LLMConfig, TTSConfig and STTConfig alike), so a
# provider value alone does not identify one class — confirm that pydantic
# accepts this alias where it is actually used for validation.
ServiceConfig = Annotated[
    Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
]
|
||||
9
api/services/filesystem/__init__.py
Normal file
9
api/services/filesystem/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from .base import BaseFileSystem
|
||||
from .minio import MinioFileSystem
|
||||
from .s3 import S3FileSystem
|
||||
|
||||
__all__ = [
|
||||
"BaseFileSystem",
|
||||
"S3FileSystem",
|
||||
"MinioFileSystem",
|
||||
]
|
||||
60
api/services/filesystem/base.py
Normal file
60
api/services/filesystem/base.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, BinaryIO, Dict, Optional
|
||||
|
||||
|
||||
class BaseFileSystem(ABC):
    """Interface every storage backend has to implement.

    All operations are coroutines and report failure through their return
    value (``False`` or ``None``) as described per method.
    """

    @abstractmethod
    async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
        """Write ``content`` to a new file at ``file_path``.

        Args:
            file_path: Destination path of the file to create.
            content: Binary stream holding the file's bytes.

        Returns:
            bool: ``True`` when the file was created, ``False`` otherwise.
        """

    @abstractmethod
    async def aupload_file(self, local_path: str, destination_path: str) -> bool:
        """Copy an existing local file to ``destination_path`` in the backend.

        Args:
            local_path: Location of the file on the local machine.
            destination_path: Path the file should have in the backend.

        Returns:
            bool: ``True`` when the upload succeeded, ``False`` otherwise.
        """

    @abstractmethod
    async def aget_signed_url(
        self, file_path: str, expiration: int = 3600
    ) -> Optional[str]:
        """Produce a temporary, signed URL granting access to a file.

        Args:
            file_path: Path of the file to expose.
            expiration: Lifetime of the URL in seconds (one hour by default).

        Returns:
            Optional[str]: The signed URL, or ``None`` when it could not be
            generated.
        """

    @abstractmethod
    async def aget_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
        """Look up metadata describing a stored file.

        Args:
            file_path: Path of the file to inspect.

        Returns:
            Optional[Dict[str, Any]]: Metadata such as size, created_at,
            modified_at and etag, or ``None`` when it could not be retrieved.
        """
|
||||
95
api/services/filesystem/local.py
Normal file
95
api/services/filesystem/local.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import BinaryIO, Optional
|
||||
|
||||
import aiofiles
|
||||
|
||||
from .base import BaseFileSystem
|
||||
|
||||
|
||||
class LocalFileSystem(BaseFileSystem):
    """Filesystem backend that keeps files in a directory on local disk.

    Every ``file_path`` given to a method is joined onto ``base_path``.
    Failures are reported through the return value (``False`` / ``None``)
    instead of being raised.
    """

    def __init__(self, base_path: str):
        """Initialize local filesystem.

        Args:
            base_path: Base directory path for file operations; created if it
                does not exist yet.
        """
        self.base_path = base_path
        # Strong references to pending symlink-cleanup tasks. The event loop
        # only holds weak references to tasks, so a task nobody references
        # may be garbage-collected before it has run.
        self._cleanup_tasks: set = set()
        os.makedirs(base_path, exist_ok=True)

    def _get_full_path(self, file_path: str) -> str:
        """Get the full path by joining with base path."""
        # NOTE(review): os.path.join lets an absolute or ".."-containing
        # file_path resolve outside base_path — confirm callers only pass
        # trusted, relative paths.
        return os.path.join(self.base_path, file_path)

    async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
        """Write the bytes read from ``content`` to ``file_path``.

        ``content.read()`` is awaited, so the stream has to expose an
        awaitable ``read``. Returns ``True`` on success and ``False`` on any
        error.
        """
        try:
            full_path = self._get_full_path(file_path)
            os.makedirs(os.path.dirname(full_path), exist_ok=True)

            async with aiofiles.open(full_path, "wb") as f:
                await f.write(await content.read())
            return True
        except Exception:
            return False

    async def create_temp_file(self, file_path: str) -> bool:
        """Ensure the parent directory of ``file_path`` exists.

        Despite the name, no file is created. Returns ``True`` on success and
        ``False`` on any error.
        """
        try:
            full_path = self._get_full_path(file_path)
            os.makedirs(os.path.dirname(full_path), exist_ok=True)

            return True
        except Exception:
            return False

    async def aupload_file(self, local_path: str, destination_path: str) -> bool:
        """Copy ``local_path`` to ``destination_path`` below the base path.

        Returns ``True`` on success and ``False`` on any error.
        """
        try:
            full_dest_path = self._get_full_path(destination_path)
            os.makedirs(os.path.dirname(full_dest_path), exist_ok=True)

            async with (
                aiofiles.open(local_path, "rb") as src,
                aiofiles.open(full_dest_path, "wb") as dst,
            ):
                await dst.write(await src.read())
            return True
        except Exception:
            return False

    async def aget_signed_url(
        self, file_path: str, expiration: int = 3600, force_inline: bool = False
    ) -> Optional[str]:
        """Expose a file through a temporary symlink and return its URL path.

        A symlink to the file is created in ``<base_path>/.temp_links`` and a
        background task removes it again after ``expiration`` seconds.

        Args:
            file_path: Path of the file, relative to the base path.
            expiration: Seconds until the symlink is removed.
            force_inline: Accepted so the call signature matches the MinIO and
                S3 backends; it has no effect here.

        Returns:
            ``/files/<link name>``, or ``None`` if the file does not exist or
            the link could not be created.
        """
        try:
            full_path = self._get_full_path(file_path)
            if not os.path.exists(full_path):
                return None

            # Create a temporary directory for symlinks
            temp_dir = os.path.join(self.base_path, ".temp_links")
            os.makedirs(temp_dir, exist_ok=True)

            # Generate a unique temporary filename
            temp_filename = (
                f"{datetime.now().timestamp()}_{os.path.basename(file_path)}"
            )
            temp_path = os.path.join(temp_dir, temp_filename)

            # Create symlink
            os.symlink(full_path, temp_path)

            # Schedule deletion after expiration
            async def delete_after_expiration():
                await asyncio.sleep(expiration)
                try:
                    os.remove(temp_path)
                except Exception:
                    # Best effort: the link may already have been removed.
                    pass

            cleanup = asyncio.create_task(delete_after_expiration())
            # Hold the task until it finishes, then drop the reference.
            self._cleanup_tasks.add(cleanup)
            cleanup.add_done_callback(self._cleanup_tasks.discard)

            return f"/files/{temp_filename}"
        except Exception:
            return None

    async def aget_file_metadata(self, file_path: str) -> Optional[dict]:
        """Get metadata for a file stored below the base path.

        Implementing this abstract method is what makes the class
        instantiable. The dictionary carries the same keys as the MinIO and
        S3 backends; like them, ``created_at`` mirrors the modification time.
        Fields with no local equivalent are ``None``.

        Returns:
            The metadata dictionary, or ``None`` if the file cannot be
            stat'ed.
        """
        try:
            info = os.stat(self._get_full_path(file_path))
        except OSError:
            return None

        modified = datetime.fromtimestamp(info.st_mtime).astimezone()
        return {
            "size": info.st_size,
            "created_at": modified,
            "modified_at": modified,
            "etag": None,
            "content_type": None,
            "storage_class": None,
        }
|
||||
137
api/services/filesystem/minio.py
Normal file
137
api/services/filesystem/minio.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import Any, BinaryIO, Dict, Optional
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
from .base import BaseFileSystem
|
||||
|
||||
|
||||
class MinioFileSystem(BaseFileSystem):
    """MinIO implementation of the filesystem interface for OSS users.

    Two endpoints are tracked:
    - endpoint: used by the MinIO client for every API operation
      (bucket check, uploads, stat, presigning).
    - public_endpoint: host substituted into presigned URLs so that they are
      reachable from a browser; falls back to ``endpoint`` when not given.

    NOTE(review): earlier documentation described automatic detection of the
    public endpoint (MINIO_PUBLIC_ENDPOINT, "minio:9000" -> "localhost:9000").
    This class contains no such logic; presumably the caller resolves the
    value and passes it as ``public_endpoint`` — confirm.
    """

    def __init__(
        self,
        endpoint: str = "localhost:9000",
        access_key: str = "minioadmin",
        secret_key: str = "minioadmin",
        bucket_name: str = "voice-audio",
        secure: bool = False,
        public_endpoint: Optional[str] = None,
    ):
        """Create the client and make sure the bucket exists (best effort).

        Args:
            endpoint: ``host:port`` the client connects to.
            access_key: MinIO access key.
            secret_key: MinIO secret key.
            bucket_name: Bucket all objects are stored in.
            secure: Passed to the client as its ``secure`` flag.
            public_endpoint: ``host:port`` to show in presigned URLs;
                defaults to ``endpoint``.
        """
        self.bucket_name = bucket_name
        self.endpoint = endpoint
        self.public_endpoint = public_endpoint or endpoint
        self.secure = secure
        self.access_key = access_key
        self.secret_key = secret_key

        # Client for internal operations (uploads, etc.)
        self.client = Minio(
            endpoint, access_key=access_key, secret_key=secret_key, secure=secure
        )

        # Ensure bucket exists (using internal client)
        try:
            if not self.client.bucket_exists(self.bucket_name):
                self.client.make_bucket(self.bucket_name)
        except Exception:
            # Deliberately best effort: the bucket might already exist or we
            # might be in a restricted environment. Later operations surface
            # any real problem.
            pass

    async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
        """Store the bytes read from ``content`` as object ``file_path``.

        ``content.read()` is awaited, so the stream has to expose an
        awaitable ``read``. Returns ``True`` on success and ``False`` when
        MinIO reports an ``S3Error``.
        """
        import io  # local import: only this method needs an in-memory stream

        try:
            payload = bytes(await content.read())

            def _put():
                # put_object expects an object with a callable read(), not raw
                # bytes, so the payload is wrapped in a BytesIO stream.
                self.client.put_object(
                    self.bucket_name,
                    file_path,
                    data=io.BytesIO(payload),
                    length=len(payload),
                )

            # The MinIO client is synchronous; run it off the event loop.
            await asyncio.to_thread(_put)
            return True
        except S3Error:
            return False

    async def aupload_file(self, local_path: str, destination_path: str) -> bool:
        """Upload the local file ``local_path`` as object ``destination_path``.

        Returns ``True`` on success and ``False`` when MinIO reports an
        ``S3Error``.
        """
        try:

            def _fput():
                self.client.fput_object(self.bucket_name, destination_path, local_path)

            await asyncio.to_thread(_fput)
            return True
        except S3Error:
            return False

    async def aget_signed_url(
        self, file_path: str, expiration: int = 3600, force_inline: bool = False
    ) -> Optional[str]:
        """Generate a presigned GET URL for an object.

        Args:
            file_path: Object name inside the bucket.
            expiration: Lifetime of the URL in seconds.
            force_inline: When true and the object name ends in ``.txt``, ask
                the server to answer with ``text/plain`` / ``inline`` headers.

        Returns:
            The URL (with the public endpoint substituted when it differs
            from the internal one), or ``None`` on ``S3Error``.
        """
        try:

            def _presign():
                response_headers = None
                if force_inline and file_path.endswith(".txt"):
                    response_headers = {
                        "response-content-type": "text/plain",
                        "response-content-disposition": "inline",
                    }

                # Generate URL with the main client
                url = self.client.presigned_get_object(
                    self.bucket_name,
                    file_path,
                    expires=timedelta(seconds=expiration),
                    response_headers=response_headers,
                )

                # If we have different public endpoint, replace it in the URL
                if self.endpoint != self.public_endpoint:
                    # NOTE(review): presigned URLs are normally signed over
                    # the Host header, so swapping the host after signing may
                    # invalidate the signature — confirm against a deployment
                    # where the two endpoints differ.
                    url = url.replace(
                        f"://{self.endpoint}/", f"://{self.public_endpoint}/"
                    )
                    url = url.replace(
                        f"Host={self.endpoint}", f"Host={self.public_endpoint}"
                    )

                return url

            url = await asyncio.to_thread(_presign)
            return url
        except S3Error:
            return None

    async def aget_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
        """Get MinIO object metadata.

        Returns a dictionary with size, created_at, modified_at, etag,
        content_type and storage_class, or ``None`` on ``S3Error``.
        ``created_at`` mirrors ``modified_at`` (both come from
        ``last_modified``).
        """
        try:

            def _stat():
                return self.client.stat_object(self.bucket_name, file_path)

            stat = await asyncio.to_thread(_stat)
            return {
                "size": stat.size,
                "created_at": stat.last_modified,
                "modified_at": stat.last_modified,
                "etag": stat.etag.strip('"') if stat.etag else None,
                "content_type": stat.content_type,
                "storage_class": None,  # MinIO doesn't have storage classes like S3
            }
        except S3Error:
            return None
|
||||
99
api/services/filesystem/s3.py
Normal file
99
api/services/filesystem/s3.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from typing import Any, BinaryIO, Dict, Optional
|
||||
|
||||
import aioboto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from .base import BaseFileSystem
|
||||
|
||||
|
||||
class S3FileSystem(BaseFileSystem):
    """Filesystem backend that stores objects in an S3 bucket via aioboto3."""

    def __init__(self, bucket_name: str, region_name: str = "us-east-1"):
        """Remember the bucket/region and create the aioboto3 session.

        Args:
            bucket_name: Name of the S3 bucket
            region_name: AWS region name
        """
        self.bucket_name = bucket_name
        self.region_name = region_name
        self.session = aioboto3.Session()

    def _s3(self):
        """Return a new async context manager yielding an S3 client for the region."""
        return self.session.client("s3", region_name=self.region_name)

    async def acreate_file(self, file_path: str, content: BinaryIO) -> bool:
        """Store the bytes read from ``content`` under key ``file_path``.

        Returns ``True`` on success and ``False`` on ``ClientError``.
        """
        try:
            async with self._s3() as s3:
                body = await content.read()
                await s3.put_object(Bucket=self.bucket_name, Key=file_path, Body=body)
        except ClientError:
            return False
        return True

    async def aupload_file(self, local_path: str, destination_path: str) -> bool:
        """Upload the local file ``local_path`` under key ``destination_path``.

        Returns ``True`` on success and ``False`` on ``ClientError``.
        """
        try:
            async with self._s3() as s3:
                await s3.upload_file(local_path, self.bucket_name, destination_path)
        except ClientError:
            return False
        return True

    async def aget_signed_url(
        self, file_path: str, expiration: int = 3600, force_inline: bool = False
    ) -> Optional[str]:
        """Generate a presigned GET url for the given object.

        When ``force_inline`` is set and the key ends in ``.txt``, S3 is asked
        to override the response content type and disposition so that the
        browser renders the text **inline** instead of downloading it.

        Returns the URL, or ``None`` on ``ClientError``.
        """
        try:
            async with self._s3() as s3:
                request = {"Bucket": self.bucket_name, "Key": file_path}

                # Make transcripts viewable inline in the browser when requested
                wants_inline_text = force_inline and file_path.endswith(".txt")
                if wants_inline_text:
                    request["ResponseContentType"] = "text/plain"
                    request["ResponseContentDisposition"] = "inline"

                return await s3.generate_presigned_url(
                    "get_object",
                    Params=request,
                    ExpiresIn=expiration,
                )
        except ClientError:
            return None

    async def aget_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]:
        """Get S3 object metadata.

        Returns a dictionary built from the HEAD response (size, created_at,
        modified_at, etag, content_type, storage_class), or ``None`` on
        ``ClientError``. ``created_at`` mirrors ``modified_at``.
        """
        try:
            async with self._s3() as s3:
                head = await s3.head_object(Bucket=self.bucket_name, Key=file_path)
                last_modified = head.get("LastModified")
                return {
                    "size": head.get("ContentLength"),
                    "created_at": last_modified,
                    "modified_at": last_modified,
                    "etag": head.get("ETag", "").strip('"'),
                    "content_type": head.get("ContentType"),
                    "storage_class": head.get("StorageClass"),
                }
        except ClientError:
            return None
|
||||
219
api/services/gender/README.md
Normal file
219
api/services/gender/README.md
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
# Gender Prediction Service
|
||||
|
||||
An internal service for predicting gender from first names using SSA (Social Security Administration) baby names data with GenderAPI fallback for uncertain predictions.
|
||||
|
||||
## Overview
|
||||
|
||||
This service provides gender prediction with:
|
||||
- **Local model** built from 145 years of SSA data (1880-2024)
|
||||
- **104,819 unique names** with confidence scores
|
||||
- **Compressed storage** (2.21 MB model file)
|
||||
- **GenderAPI fallback** for unknown or low-confidence names
|
||||
|
||||
## Data Source
|
||||
|
||||
The SSA baby names data is already downloaded in the `names/` directory from:
|
||||
https://catalog.data.gov/dataset/baby-names-from-social-security-card-applications-national-data
|
||||
|
||||
## Building the Model
|
||||
|
||||
### Prerequisites
|
||||
- Python 3.11+
|
||||
- SSA data files in `names/` directory (already included)
|
||||
|
||||
### Build Steps
|
||||
|
||||
```bash
|
||||
# Navigate to the gender service directory
|
||||
cd dograh/api/services/gender/
|
||||
|
||||
# Run the model builder
|
||||
python build_model.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. Process all 145 year files (yob1880.txt to yob2024.txt)
|
||||
2. Aggregate name counts across all years
|
||||
3. Calculate confidence scores based on gender ratios
|
||||
4. Generate a compressed `model.txt` file (~2.21 MB)
|
||||
|
||||
### Model Output
|
||||
|
||||
The builder generates `model.txt` with:
|
||||
- **Version**: Model version number
|
||||
- **Metadata**: Build date, statistics, thresholds
|
||||
- **Names**: Compressed array format `[male_count, female_count, confidence]`
|
||||
|
||||
Example output:
|
||||
```
|
||||
Model saved to: .../services/gender/model.txt
|
||||
File size: 2.21 MB
|
||||
|
||||
Model statistics:
|
||||
Total names: 104,819
|
||||
High confidence names (≥0.85): 1,711
|
||||
Confidence percentage: 1.6%
|
||||
```
|
||||
|
||||
## Using the Service
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from services.gender.gender_service import GenderService
|
||||
|
||||
# Initialize the service
|
||||
service = GenderService()
|
||||
|
||||
# Predict gender for a single name
|
||||
result = await service.predict("John")
|
||||
print(f"Gender: {result.gender}") # "male"
|
||||
print(f"Confidence: {result.confidence}") # 0.996
|
||||
print(f"Source: {result.source}") # "model"
|
||||
|
||||
# Get salutation for a name
|
||||
greeting = await service.get_salutation("John")
|
||||
print(f"Salutation: {greeting}") # "Mr."
|
||||
|
||||
greeting = await service.get_salutation("Mary")
|
||||
print(f"Salutation: {greeting}") # "Ms."
|
||||
|
||||
greeting = await service.get_salutation("Unknown")
|
||||
print(f"Salutation: {greeting}") # "Dear"
|
||||
|
||||
# Clean up
|
||||
await service.close()
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```python
|
||||
# Custom configuration
|
||||
service = GenderService(
|
||||
model_path="custom/path/to/model.txt", # Default: ./model.txt
|
||||
confidence_threshold=0.85, # Default: 0.85
|
||||
gender_api_key="your-api-key", # Default: from GENDER_API_KEY env
|
||||
gender_api_url="https://..." # Default: GenderAPI v2 endpoint
|
||||
)
|
||||
```
|
||||
|
||||
### Salutation Generation
|
||||
|
||||
```python
|
||||
# Get appropriate salutation based on gender
|
||||
salutation = await service.get_salutation("John") # "Mr."
|
||||
salutation = await service.get_salutation("Mary") # "Ms."
|
||||
salutation = await service.get_salutation("Unknown") # "Dear"
|
||||
|
||||
# Custom confidence threshold for salutation
|
||||
salutation = await service.get_salutation(
|
||||
"Taylor", # Ambiguous name
|
||||
confidence_threshold=0.9 # Higher threshold
|
||||
) # Returns "Dear" due to low confidence
|
||||
|
||||
# Salutation logic:
|
||||
# - "Mr." for male with confidence >= threshold
|
||||
# - "Ms." for female with confidence >= threshold
|
||||
# - "Dear" for unknown gender or low confidence
|
||||
```
|
||||
|
||||
### Batch Predictions
|
||||
|
||||
```python
|
||||
# Predict multiple names at once
|
||||
names = ["Alice", "Bob", "Charlie", "Diana"]
|
||||
results = await service.batch_predict(names)
|
||||
|
||||
for name, result in zip(names, results):
|
||||
print(f"{name}: {result.gender} ({result.confidence:.2f})")
|
||||
```
|
||||
|
||||
### Response Format
|
||||
|
||||
```python
|
||||
class GenderPrediction:
|
||||
gender: "male" | "female" | "unknown" # Predicted gender
|
||||
confidence: float # 0.0 to 1.0
|
||||
source: "model" | "genderapi" # Prediction source
|
||||
```
|
||||
|
||||
### Service Statistics
|
||||
|
||||
```python
|
||||
# Get service statistics
|
||||
stats = await service.get_stats()
|
||||
print(f"Total names: {stats['model']['total_names']:,}")
|
||||
print(f"High confidence: {stats['model']['high_confidence_names']:,}")
|
||||
print(f"Cached names in Redis: {stats['cache']['cached_names']}")
|
||||
print(f"Cache TTL: {stats['cache']['ttl_seconds']} seconds")
|
||||
print(f"API enabled: {stats['api']['enabled']}")
|
||||
```
|
||||
|
||||
### Cache Management
|
||||
|
||||
```python
|
||||
# Clear Redis cache
|
||||
await service.clear_cache()
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
```bash
|
||||
# Required: Redis connection URL
|
||||
export REDIS_URL=redis://localhost:6379
|
||||
|
||||
# Optional: Set GenderAPI key for fallback
|
||||
export GENDERAPI_API_KEY=your-api-key-here
|
||||
|
||||
# Optional: Override confidence threshold (default: 0.85)
|
||||
export CONFIDENCE_THRESHOLD=0.85
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Name normalization**: Converts to lowercase, strips whitespace
|
||||
2. **Local model check**: Looks up name in pre-built model
|
||||
3. **Confidence evaluation**: If the model's confidence is at or above the configured threshold (default 0.85), returns the local prediction
|
||||
4. **Redis cache check**: Checks Redis for previously fetched API results
|
||||
5. **API fallback**: For unknown/low-confidence names, calls GenderAPI
|
||||
6. **Redis caching**: Stores API responses in Redis with 30-day TTL
|
||||
|
||||
## Testing
|
||||
|
||||
Run the test suite to verify the service:
|
||||
|
||||
```bash
|
||||
python test_service.py
|
||||
```
|
||||
|
||||
This tests:
|
||||
- High-confidence predictions
|
||||
- Ambiguous names
|
||||
- Edge cases (empty strings, special characters)
|
||||
- International names (with API key)
|
||||
- Batch predictions
|
||||
|
||||
## Model Updates
|
||||
|
||||
The model should be rebuilt annually when new SSA data is released:
|
||||
|
||||
1. Download new year file (e.g., yob2025.txt) to `names/` directory
|
||||
2. Run `python build_model.py` to rebuild
|
||||
3. Test with `python test_service.py`
|
||||
4. Commit the updated `model.txt`
|
||||
|
||||
## Performance
|
||||
|
||||
- **Model size**: 2.21 MB (compressed JSON)
|
||||
- **Load time**: < 100ms
|
||||
- **Prediction time**: < 1ms (local), < 5ms (Redis cache), < 500ms (API)
|
||||
- **Memory usage**: ~10 MB for model in memory
|
||||
- **Cache**: Redis-based with 30-day TTL
|
||||
- **Scalability**: Shared cache across all service instances
|
||||
|
||||
## Limitations
|
||||
|
||||
- Based on US SSA data (may not work well for non-US names)
|
||||
- Historical bias in older data
|
||||
- Unisex names have lower confidence
|
||||
- Requires GenderAPI key for comprehensive coverage
|
||||
0
api/services/gender/__init__.py
Normal file
0
api/services/gender/__init__.py
Normal file
164
api/services/gender/build_model.py
Normal file
164
api/services/gender/build_model.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""
|
||||
Build gender prediction model from SSA baby names data.
|
||||
Generates a compressed JSON model file.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from math import log10
|
||||
from pathlib import Path
|
||||
|
||||
from api.services.gender.constants import CONFIDENCE_THRESHOLD
|
||||
|
||||
|
||||
def calculate_confidence(male_count: int, female_count: int) -> float:
    """Score how reliably a name maps to a single gender.

    The score is the majority gender's share of all occurrences, scaled by a
    sample-size weight of ``log10(total) / 5`` capped at 1.0 (so the weight
    saturates at 100,000 occurrences). Names with fewer than 100 total
    occurrences always score 0.0.

    Args:
        male_count: Number of male occurrences of the name.
        female_count: Number of female occurrences of the name.

    Returns:
        The confidence score, rounded to four decimal places.
    """
    occurrences = male_count + female_count

    # Too few samples to trust any ratio.
    if occurrences < 100:
        return 0.0

    majority_share = max(male_count, female_count) / occurrences

    # log10(100_000) == 5, hence the divisor; min() caps the weight at 1.0.
    size_weight = min(1.0, log10(occurrences) / 5)

    return round(majority_share * size_weight, 4)
|
||||
|
||||
|
||||
def build_model():
    """Build the gender prediction model from SSA baby-name files.

    Reads every ``yob*.txt`` file in the ``names`` directory next to this
    module. Each data row is expected to look like ``name,M|F,count``;
    counts are summed per lower-cased name across all files, with no year
    weighting.

    Rows that cannot be interpreted (wrong number of fields, empty name, a
    gender code other than ``M``/``F``, or a non-integer count) are skipped
    and reported, instead of aborting the whole build.

    Returns:
        A dict with ``version``, ``metadata`` and ``names`` keys, where
        ``names`` maps each name to ``[male_count, female_count, confidence]``.

    Raises:
        FileNotFoundError: If the ``names`` directory does not exist.
    """
    # Per-name counters, keyed by gender code.
    name_stats = defaultdict(lambda: {"M": 0, "F": 0})

    names_dir = Path(__file__).parent / "names"

    if not names_dir.exists():
        raise FileNotFoundError(f"Names directory not found: {names_dir}")

    file_count = 0
    skipped_rows = 0

    # sorted() keeps the processing order deterministic across platforms.
    for year_file in sorted(names_dir.glob("yob*.txt")):
        file_count += 1
        with open(year_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                parts = [part.strip() for part in line.split(",")]
                if len(parts) != 3:
                    skipped_rows += 1
                    continue

                name, gender, count_text = parts
                if not name or gender not in ("M", "F"):
                    # An unexpected gender code would otherwise raise KeyError.
                    skipped_rows += 1
                    continue

                try:
                    count = int(count_text)
                except ValueError:
                    skipped_rows += 1
                    continue

                # Simple aggregation - no year weighting
                name_stats[name.lower()][gender] += count

    print(f"Processed {file_count} year files")
    if skipped_rows:
        print(f"Skipped {skipped_rows} malformed rows")

    # Build the compact per-name representation.
    names_data = {}
    high_confidence_count = 0

    for name, stats in name_stats.items():
        male_count = stats["M"]
        female_count = stats["F"]

        if male_count == 0 and female_count == 0:
            continue

        confidence = calculate_confidence(male_count, female_count)

        # Store as compact array: [male_count, female_count, confidence]
        names_data[name] = [male_count, female_count, confidence]

        if confidence >= CONFIDENCE_THRESHOLD:
            high_confidence_count += 1

    # Final model structure with build metadata.
    model = {
        "version": "1.0",
        "metadata": {
            "confidence_threshold": CONFIDENCE_THRESHOLD,
            "total_names": len(names_data),
            "high_confidence_names": high_confidence_count,
            "build_date": datetime.now().isoformat(),
            "source_files": file_count,
        },
        "names": names_data,
    }

    return model
|
||||
|
||||
|
||||
def save_model(model, output_path="model.txt"):
    """Write the model as compact JSON next to this module and print a summary.

    Args:
        model: Model dict as returned by ``build_model`` (must contain
            ``metadata.total_names`` and ``metadata.high_confidence_names``).
        output_path: File name, resolved relative to this module's directory.
    """
    output_file = Path(__file__).parent / output_path

    # Compact separators and no indentation keep the file small.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(model, f, separators=(",", ":"))

    file_size_mb = output_file.stat().st_size / (1024 * 1024)

    metadata = model["metadata"]
    total_names = metadata["total_names"]
    high_confidence_names = metadata["high_confidence_names"]

    # An empty model (e.g. no input files) must not crash the summary with a
    # ZeroDivisionError after the file has already been written.
    confidence_pct = (
        high_confidence_names / total_names * 100 if total_names else 0.0
    )

    print(f"\nModel saved to: {output_file}")
    print(f"File size: {file_size_mb:.2f} MB")
    print("\nModel statistics:")
    print(f" Total names: {total_names:,}")
    print(
        f" High confidence names (≥{CONFIDENCE_THRESHOLD}): {high_confidence_names:,}"
    )
    print(f" Confidence percentage: {confidence_pct:.1f}%")
|
||||
|
||||
|
||||
def test_model(model):
    """Print the model's verdict for a handful of well-known sample names.

    Args:
        model: Model dict whose ``names`` entry maps a lower-cased name to
            ``[male_count, female_count, confidence]``.
    """
    print("\nSample predictions:")

    sample_names = (
        "john",
        "mary",
        "alex",
        "taylor",
        "michael",
        "sarah",
        "jordan",
        "casey",
    )

    for sample in sample_names:
        label = sample.capitalize()

        if sample not in model["names"]:
            print(f" {label:10} -> not found in dataset")
            continue

        male_count, female_count, confidence = model["names"][sample]
        # Ties go to "female", mirroring the strict ">" comparison.
        gender = "female" if male_count <= female_count else "male"
        print(
            f" {label:10} -> {gender:6} (confidence: {confidence:.3f}, M:{male_count:,} F:{female_count:,})"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the model, persist it, then print samples.
    print("Building gender prediction model from SSA data...")
    print("=" * 50)

    try:
        built_model = build_model()
        save_model(built_model)
        test_model(built_model)
        print("\n✓ Model build complete!")
    except Exception as build_error:
        # Report the failure, then re-raise so the traceback and a non-zero
        # exit status are preserved.
        print(f"\n✗ Error building model: {build_error}")
        raise
|
||||
10
api/services/gender/constants.py
Normal file
10
api/services/gender/constants.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""
|
||||
Hyperparameters and configuration for gender prediction service.
|
||||
"""
|
||||
|
||||
# Minimum local-model confidence required to answer without consulting the
# Redis cache or GenderAPI.
CONFIDENCE_THRESHOLD = 0.85

# Redis cache settings for stored GenderAPI responses.
REDIS_CACHE_TTL = 30 * 24 * 60 * 60  # 30 days, expressed in seconds
REDIS_KEY_PREFIX = "genderservice:"
|
||||
391
api/services/gender/gender_service.py
Normal file
391
api/services/gender/gender_service.py
Normal file
|
|
@ -0,0 +1,391 @@
|
|||
"""
|
||||
Gender prediction service with local model and GenderAPI fallback.
|
||||
Internal service for use within Dograh platform.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
import httpx
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
from api.services.gender.constants import (
|
||||
CONFIDENCE_THRESHOLD,
|
||||
REDIS_CACHE_TTL,
|
||||
REDIS_KEY_PREFIX,
|
||||
)
|
||||
|
||||
|
||||
class GenderPrediction(BaseModel):
    """Gender prediction result."""

    # "unknown" is used when no source could classify the name.
    gender: Literal["male", "female", "unknown"] = Field(
        default=..., description="Predicted gender"
    )
    # Validated to lie within [0, 1].
    confidence: float = Field(
        default=..., ge=0, le=1, description="Confidence score (0-1)"
    )
    # "model" = local SSA model; "genderapi" = GenderAPI (live or cached).
    source: Literal["model", "genderapi"] = Field(
        default=..., description="Source of prediction"
    )
|
||||
|
||||
|
||||
class GenderService:
    """
    Internal service for predicting gender from names.

    Lookup order used by ``predict``:

    1. Local SSA-based model, if its confidence meets the threshold.
    2. Redis cache of earlier GenderAPI answers.
    3. Live GenderAPI call (only when an API key is configured).
    4. The local model's low-confidence guess, or ``"unknown"``.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        confidence_threshold: float = CONFIDENCE_THRESHOLD,
        gender_api_key: Optional[str] = None,
        gender_api_url: str = "https://gender-api.com/v2/gender",
    ):
        """
        Initialize the gender service.

        Args:
            model_path: Path to the model file (default: ./model.txt)
            confidence_threshold: Minimum confidence to use local model
            gender_api_key: API key for GenderAPI (falls back to the
                GENDERAPI_API_KEY environment variable)
            gender_api_url: GenderAPI endpoint URL
        """
        self.confidence_threshold = confidence_threshold
        self.gender_api_key = gender_api_key or os.getenv("GENDERAPI_API_KEY")
        self.gender_api_url = gender_api_url

        # Resolve the model location; the default sits next to this module.
        if model_path is None:
            model_path = Path(__file__).parent / "model.txt"
        else:
            model_path = Path(model_path)

        self.model = self._load_model(model_path)

        # Network clients are created lazily on first use.
        self._http_client = None
        self._redis_client: Optional[aioredis.Redis] = None

    def _load_model(self, model_path: Path) -> dict:
        """Load the model file, returning an empty model if it is unusable.

        A missing, unreadable or structurally invalid file is logged and
        replaced by ``{"metadata": {}, "names": {}}`` so the service can still
        run using the API fallback only.
        """
        if not model_path.exists():
            logger.warning(f"Warning: Model file not found at {model_path}")
            return {"metadata": {}, "names": {}}

        try:
            with open(model_path, "r", encoding="utf-8") as f:
                model = json.load(f)

            # Validate model structure
            if "names" not in model or "metadata" not in model:
                raise ValueError("Invalid model format")

            logger.debug(
                f"Loaded gender prediction model with {model['metadata'].get('total_names', 0):,} names"
            )

            return model

        except Exception as e:
            logger.error(f"Error loading gender prediction model: {e}")
            return {"metadata": {}, "names": {}}

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get or create HTTP client for API calls."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(30.0),
                limits=httpx.Limits(max_keepalive_connections=5),
            )
        return self._http_client

    async def _get_redis(self) -> aioredis.Redis:
        """Get or create Redis connection."""
        if self._redis_client is None:
            self._redis_client = await aioredis.from_url(
                REDIS_URL, decode_responses=True
            )
        return self._redis_client

    def _model_prediction(self, normalized_name: str) -> Optional[GenderPrediction]:
        """Return the local model's verdict for a normalized name, or None."""
        entry = self.model["names"].get(normalized_name)
        if entry is None:
            return None

        male_count, female_count, confidence = entry
        gender = "male" if male_count > female_count else "female"
        return GenderPrediction(gender=gender, confidence=confidence, source="model")

    async def predict(
        self, first_name: str, last_name: Optional[str] = None
    ) -> GenderPrediction:
        """
        Predict gender for a given name.

        Args:
            first_name: First name to predict gender for
            last_name: Last name (optional, not used in v1.0)

        Returns:
            GenderPrediction with gender, confidence, and source
        """
        # Normalize before the emptiness check so whitespace-only input is
        # answered locally instead of reaching Redis and the external API.
        normalized_name = first_name.lower().strip() if first_name else ""
        if not normalized_name:
            return GenderPrediction(gender="unknown", confidence=0.0, source="model")

        # Step 1: Check local model
        local = self._model_prediction(normalized_name)
        if local is not None:
            if local.confidence >= self.confidence_threshold:
                logger.debug(
                    f"GenderService: Local Prediction {first_name} - {local.gender} with confidence: {local.confidence}"
                )
                return local

            logger.debug(
                f"GenderService: Low Confidence Local Prediction {first_name} - with confidence: {local.confidence}"
            )

        cache_key = f"{REDIS_KEY_PREFIX}{normalized_name}"

        # Step 2: Check Redis cache for previous API responses. Cache trouble
        # is logged and ignored; the cache is only an optimization.
        try:
            redis_client = await self._get_redis()
            cached_data = await redis_client.get(cache_key)

            if cached_data:
                cached_result = json.loads(cached_data)
                logger.debug(
                    f"GenderService: Redis Cache Hit {first_name} - {cached_result['gender']} with confidence: {cached_result['confidence']}"
                )
                return GenderPrediction(**cached_result, source="genderapi")
        except Exception as e:
            logger.warning(f"Redis cache check failed: {e}")

        # Step 3: Fallback to GenderAPI
        if self.gender_api_key:
            try:
                result = await self._call_gender_api(first_name.strip())
            except Exception:
                # Deliberate best-effort: the failure was already logged (with
                # timing) in _call_gender_api, so fall through to step 4.
                result = None

            if result is not None:
                # Cache the result in Redis
                try:
                    redis_client = await self._get_redis()
                    cache_data = json.dumps(
                        {"gender": result.gender, "confidence": result.confidence}
                    )
                    await redis_client.setex(cache_key, REDIS_CACHE_TTL, cache_data)
                except Exception as e:
                    logger.warning(f"Failed to cache result in Redis: {e}")

                return result

        # Step 4: Return best guess from model or unknown
        if local is not None:
            return local

        return GenderPrediction(gender="unknown", confidence=0.0, source="model")

    async def _call_gender_api(self, first_name: str) -> GenderPrediction:
        """
        Call GenderAPI for gender prediction.

        Args:
            first_name: First name to predict

        Returns:
            GenderPrediction from API response

        Raises:
            ValueError: On an HTTP error status or a timeout.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        headers = {
            "Authorization": f"Bearer {self.gender_api_key}",
            "Content-Type": "application/json",
        }

        payload = {"first_name": first_name}

        # Start the clock before the try block so every handler can report the
        # elapsed time without probing locals().
        start_time = time.perf_counter()

        def elapsed_ms() -> float:
            return (time.perf_counter() - start_time) * 1000

        try:
            response = await self.http_client.post(
                self.gender_api_url, headers=headers, json=payload
            )
            response.raise_for_status()

            elapsed_time = elapsed_ms()

            data = response.json()

            # Map GenderAPI response format; a missing or null gender is
            # treated as "unknown" instead of raising AttributeError.
            gender = str(data.get("gender") or "unknown").lower()
            if gender not in ("male", "female"):
                gender = "unknown"

            # The response's "probability" field is used as the confidence;
            # a missing or null value counts as 0.
            confidence = float(data.get("probability") or 0)

            logger.info(
                f"GenderAPI call for '{first_name}': {gender} with confidence {confidence:.2f} "
                f"(took {elapsed_time:.2f}ms)"
            )

            return GenderPrediction(
                gender=gender, confidence=confidence, source="genderapi"
            )

        except httpx.HTTPStatusError as e:
            status_code = e.response.status_code
            logger.error(
                f"GenderAPI HTTP error for '{first_name}': {status_code} "
                f"(took {elapsed_ms():.2f}ms)"
            )

            if status_code == 401:
                raise ValueError("Invalid GenderAPI key") from e
            elif status_code == 429:
                raise ValueError("GenderAPI rate limit exceeded") from e
            else:
                raise ValueError(f"GenderAPI HTTP error: {status_code}") from e

        except httpx.TimeoutException as e:
            logger.error(
                f"GenderAPI timeout for '{first_name}' after {elapsed_ms():.2f}ms"
            )
            raise ValueError("GenderAPI request timed out") from e

        except Exception as e:
            logger.error(
                f"GenderAPI unexpected error for '{first_name}': {str(e)} "
                f"(took {elapsed_ms():.2f}ms)"
            )
            raise

    async def get_salutation(
        self,
        first_name: str,
        last_name: Optional[str] = None,
        confidence_threshold: Optional[float] = None,
    ) -> str:
        """
        Get appropriate salutation based on gender prediction.

        Args:
            first_name: First name to predict gender for
            last_name: Last name (optional, not used in v1.0)
            confidence_threshold: Optional minimum confidence; when given,
                predictions below it yield "Dear". When omitted, any
                non-unknown prediction is used.

        Returns:
            "Mr." for male, "Ms." for female, "Dear" for unknown/low confidence
        """
        if not first_name:
            return "Dear"

        prediction = await self.predict(first_name, last_name)

        if prediction.gender == "unknown":
            return "Dear"

        # Honour the per-call override (previously accepted but ignored).
        if (
            confidence_threshold is not None
            and prediction.confidence < confidence_threshold
        ):
            return "Dear"

        return "Mr." if prediction.gender == "male" else "Ms."

    async def batch_predict(self, names: list[str]) -> list[GenderPrediction]:
        """
        Predict gender for multiple names.

        Names are processed one after another, in input order.

        Args:
            names: List of first names

        Returns:
            List of GenderPrediction results
        """
        return [await self.predict(name) for name in names]

    async def close(self):
        """Close HTTP and Redis clients and cleanup resources."""
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
        if self._redis_client:
            await self._redis_client.close()
            self._redis_client = None

    async def get_stats(self) -> dict:
        """Get statistics about the service and model.

        Returns:
            Dict with "model", "cache" and "api" sections. If Redis cannot be
            queried, the cache section reports zero names plus an "error".
        """
        metadata = self.model.get("metadata", {})

        # Get Redis cache stats
        cache_stats = {}
        try:
            redis_client = await self._get_redis()
            # SCAN walks the keyspace incrementally, unlike KEYS which blocks
            # the server; a set absorbs the duplicates SCAN may return.
            cached_keys = set()
            async for key in redis_client.scan_iter(match=f"{REDIS_KEY_PREFIX}*"):
                cached_keys.add(key)
            cache_stats = {
                "cached_names": len(cached_keys),
                "cache_type": "redis",
                "ttl_seconds": REDIS_CACHE_TTL,
            }
        except Exception as e:
            logger.warning(f"Failed to get Redis stats: {e}")
            cache_stats = {"cached_names": 0, "cache_type": "redis", "error": str(e)}

        return {
            "model": {
                "version": self.model.get("version", "unknown"),
                "total_names": metadata.get("total_names", 0),
                "high_confidence_names": metadata.get("high_confidence_names", 0),
                "confidence_threshold": self.confidence_threshold,
                "build_date": metadata.get("build_date", "unknown"),
            },
            "cache": cache_stats,
            "api": {"enabled": bool(self.gender_api_key), "url": self.gender_api_url},
        }

    async def clear_cache(self):
        """Clear the Redis cache for gender predictions.

        Failures are logged rather than raised.
        """
        batch_size = 500
        try:
            redis_client = await self._get_redis()
            deleted = 0
            batch = []

            # Iterate with SCAN and delete in bounded batches so a large cache
            # neither blocks Redis nor builds one huge DEL command.
            async for key in redis_client.scan_iter(
                match=f"{REDIS_KEY_PREFIX}*", count=batch_size
            ):
                batch.append(key)
                if len(batch) >= batch_size:
                    deleted += await redis_client.delete(*batch)
                    batch = []
            if batch:
                deleted += await redis_client.delete(*batch)

            if deleted:
                logger.info(f"Cleared {deleted} entries from Redis cache")
            else:
                logger.debug("No cache entries to clear")
        except Exception as e:
            logger.error(f"Failed to clear Redis cache: {e}")
|
||||
1
api/services/gender/model.txt
Normal file
1
api/services/gender/model.txt
Normal file
File diff suppressed because one or more lines are too long
248
api/services/gender/test_service.py
Normal file
248
api/services/gender/test_service.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""
|
||||
Test script for the gender prediction service.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from api.services.gender.gender_service import GenderService
|
||||
|
||||
|
||||
async def test_local_model():
    """Print predictions for well-known, ambiguous and unknown names.

    The service is constructed without an explicit API key; note that
    GenderService still falls back to the GENDERAPI_API_KEY environment
    variable, so results are local-only just when that variable is unset.
    """
    print("\n" + "=" * 60)
    print("Testing Local Model Predictions")
    print("=" * 60)

    service = GenderService()

    # try/finally guarantees the HTTP/Redis clients are released even when a
    # prediction raises part-way through.
    try:
        high_confidence_names = [
            "John",
            "Mary",
            "Michael",
            "Sarah",
            "Robert",
            "Lisa",
            "William",
            "Jennifer",
            "David",
            "Patricia",
        ]

        print("\nHigh-confidence predictions (should use local model):")
        print("-" * 50)
        for name in high_confidence_names:
            result = await service.predict(name)
            print(
                f" {name:12} -> {result.gender:6} (conf: {result.confidence:.3f}, source: {result.source})"
            )

        ambiguous_names = [
            "Taylor",
            "Jordan",
            "Casey",
            "Alex",
            "Morgan",
            "Blake",
            "Avery",
            "Riley",
            "Quinn",
            "Sage",
        ]

        print("\nAmbiguous names (lower confidence):")
        print("-" * 50)
        for name in ambiguous_names:
            result = await service.predict(name)
            # Compare against the service's own threshold instead of a
            # duplicated literal, so the marker stays correct if it changes.
            is_confident_local = (
                result.source == "model"
                and result.confidence >= service.confidence_threshold
            )
            status = "✓" if is_confident_local else "○"
            print(
                f" {status} {name:12} -> {result.gender:6} (conf: {result.confidence:.3f}, source: {result.source})"
            )

        unknown_names = ["Xyzabc", "Qwerty", "Abcdef"]

        print("\nUnknown names (not in dataset):")
        print("-" * 50)
        for name in unknown_names:
            result = await service.predict(name)
            print(
                f" {name:12} -> {result.gender:7} (conf: {result.confidence:.3f}, source: {result.source})"
            )

        stats = await service.get_stats()
        print("\nService Statistics:")
        print("-" * 50)
        print(f" Model version: {stats['model']['version']}")
        print(f" Total names: {stats['model']['total_names']:,}")
        print(f" High confidence: {stats['model']['high_confidence_names']:,}")
        print(f" Threshold: {stats['model']['confidence_threshold']}")
        print(f" Cache type: {stats['cache'].get('cache_type', 'unknown')}")
        print(f" Cached names: {stats['cache'].get('cached_names', 0)}")
        print(f" API enabled: {stats['api']['enabled']}")
    finally:
        await service.close()
|
||||
|
||||
|
||||
async def test_with_api():
    """Exercise the GenderAPI fallback; skipped when no API key is available.

    The key is read from GENDERAPI_API_KEY (the variable GenderService itself
    reads) or, for backward compatibility, from GENDER_API_KEY.
    """
    print("\n" + "=" * 60)
    print("Testing with GenderAPI Integration")
    print("=" * 60)

    import os

    # Accept both spellings: previously only GENDER_API_KEY was checked, so a
    # correctly configured service environment still skipped these tests.
    api_key = os.getenv("GENDERAPI_API_KEY") or os.getenv("GENDER_API_KEY")

    if not api_key:
        print("\n⚠️ No GENDERAPI_API_KEY or GENDER_API_KEY found in environment")
        print(" Skipping API integration tests")
        print(" To test API fallback, set: export GENDERAPI_API_KEY=your_key")
        return

    service = GenderService(gender_api_key=api_key)

    # try/finally releases the HTTP/Redis clients even if an API call raises.
    try:
        # Names that might need API fallback
        test_names = [
            "Priya",  # Indian name, might not be in SSA data
            "Hiroshi",  # Japanese name
            "Giovanni",  # Italian name
            "Olga",  # Russian name
            "Chen",  # Chinese name
        ]

        print("\nInternational names (may use API fallback):")
        print("-" * 50)
        for name in test_names:
            result = await service.predict(name)
            print(
                f" {name:12} -> {result.gender:6} (conf: {result.confidence:.3f}, source: {result.source})"
            )

        print("\nBatch prediction test:")
        print("-" * 50)
        batch_names = ["Alice", "Bob", "Charlie", "Diana", "Eve"]
        results = await service.batch_predict(batch_names)
        for name, result in zip(batch_names, results):
            print(f" {name:12} -> {result.gender:6} (conf: {result.confidence:.3f})")
    finally:
        await service.close()
|
||||
|
||||
|
||||
async def test_salutation():
    """Print salutations for known, unknown and threshold-override inputs."""
    rule = "=" * 60
    divider = "-" * 50

    print("\n" + rule)
    print("Testing Salutation Generation")
    print(rule)

    service = GenderService()

    # (name, salutation the author expects) pairs for unambiguous names.
    expectations = (
        ("John", "Mr."),
        ("Mary", "Ms."),
        ("Michael", "Mr."),
        ("Sarah", "Ms."),
        ("Robert", "Mr."),
        ("Jennifer", "Ms."),
    )

    print("\nHigh-confidence salutations:")
    print(divider)
    for name, expected in expectations:
        salutation = await service.get_salutation(name)
        status = "✗" if salutation != expected else "✓"
        print(f" {status} {name:12} -> {salutation:4} (expected: {expected})")

    # Inputs the author expects to produce the neutral "Dear".
    neutral_inputs = (
        "Xyzabc",  # Unknown name
        "Qwerty",  # Unknown name
        "",  # Empty string
        " ",  # Whitespace
        "123",  # Numbers
    )

    print("\nUnknown/ambiguous names (should return 'Dear'):")
    print(divider)
    for name in neutral_inputs:
        salutation = await service.get_salutation(name)
        if name:
            display_name = f"'{name}'"
        else:
            display_name = "(empty)"
        status = "✗" if salutation != "Dear" else "✓"
        print(f" {status} {display_name:12} -> {salutation}")

    print("\nCustom confidence threshold test:")
    print(divider)
    # Taylor is an ambiguous name (the original author notes a local
    # confidence of about 0.744), so a 0.9 threshold is expected to give "Dear".
    salutation_default = await service.get_salutation("Taylor")
    salutation_high = await service.get_salutation("Taylor", confidence_threshold=0.9)
    print(f" Taylor (default threshold): {salutation_default}")
    print(f" Taylor (0.9 threshold): {salutation_high}")

    await service.close()
|
||||
|
||||
|
||||
async def test_edge_cases():
    """Print predictions for unusual inputs and for case variants of a name."""
    rule = "=" * 60
    divider = "-" * 50

    print("\n" + rule)
    print("Testing Edge Cases")
    print(rule)

    service = GenderService()

    # Empty, non-alphabetic and punctuated inputs.
    unusual_inputs = (
        "",  # Empty string
        " ",  # Whitespace
        "123",  # Numbers
        "John-Paul",  # Hyphenated
        "Mary Ann",  # Space in name
        "O'Brien",  # Apostrophe
        "José",  # Accented
    )

    print("\nEdge case inputs:")
    print(divider)
    for raw in unusual_inputs:
        result = await service.predict(raw)
        if raw:
            display_name = f"'{raw}'"
        else:
            display_name = "(empty)"
        print(
            f" {display_name:12} -> {result.gender:7} (conf: {result.confidence:.3f})"
        )

    # The same name in different casings should be treated alike.
    print("\nCase insensitivity test:")
    print(divider)
    for variant in ("john", "JOHN", "John", "JoHn"):
        result = await service.predict(variant)
        print(f" {variant:12} -> {result.gender:6} (conf: {result.confidence:.3f})")

    await service.close()
|
||||
|
||||
|
||||
async def main():
    """Run every test routine in sequence, framed by a banner."""
    rule = "=" * 60

    print("\n" + rule)
    print("Gender Prediction Service Test Suite")
    print(rule)

    # Order matters only for the printed output; the API-backed suite is last.
    await test_local_model()
    await test_salutation()
    await test_edge_cases()
    await test_with_api()

    print("\n" + rule)
    print("✓ All tests completed!")
    print(rule)


if __name__ == "__main__":
    # Script entry point: drive the async suite on a fresh event loop.
    asyncio.run(main())
|
||||
0
api/services/integrations/__init__.py
Normal file
0
api/services/integrations/__init__.py
Normal file
253
api/services/integrations/nango.py
Normal file
253
api/services/integrations/nango.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
# Integrations offered in Nango connect sessions, read from the
# comma-separated NANGO_ALLOWED_INTEGRATIONS environment variable
# (default: "slack"). Blank entries, e.g. from a trailing comma or an empty
# variable, are dropped so no empty integration name is sent to Nango.
NANGO_ALLOWED_INTEGRATIONS = [
    name.strip()
    for name in os.environ.get("NANGO_ALLOWED_INTEGRATIONS", "slack").split(",")
    if name.strip()
]
|
||||
|
||||
|
||||
class NangoWebhookRequest(BaseModel):
    # Shape of the JSON body accepted from Nango webhooks. Field names are
    # camelCase because they are populated directly from the payload keys.

    # Event kind; "auth" events trigger a connection-details lookup.
    type: str
    # Stored as the integration ID when the integration is created.
    connectionId: str
    providerConfigKey: str
    authMode: str
    provider: str
    environment: str
    operation: str
    endUser: dict  # Contains endUserId and organizationId
    success: bool
|
||||
|
||||
|
||||
class NangoService:
|
||||
def __init__(self):
|
||||
self.base_url = "https://api.nango.dev"
|
||||
self.secret_key = os.getenv("NANGO_API_KEY")
|
||||
|
||||
def _verify_webhook_signature(
|
||||
self, request_body: str, signature: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify the webhook signature using SHA256 hash.
|
||||
|
||||
Args:
|
||||
request_body: The raw request body as string
|
||||
signature: The signature from request headers (optional for now)
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
"""
|
||||
expected_signature = self.secret_key + request_body
|
||||
expected_hash = hashlib.sha256(expected_signature.encode("utf-8")).hexdigest()
|
||||
return expected_hash == signature
|
||||
|
||||
async def create_session(
    self, user_id: str, organization_id: int
) -> Dict[str, Any]:
    """
    Create a Nango session for the given user and organization.

    Args:
        user_id: The end user ID
        organization_id: The organization ID (sent to Nango as a string)

    Returns:
        Response from Nango API

    Raises:
        ValueError: If no Nango API key is configured.
        httpx.HTTPStatusError: If Nango answers with any status other
            than 201.
    """
    if not self.secret_key:
        # The key is read from NANGO_API_KEY; the message previously named a
        # variable this service never reads.
        raise ValueError("NANGO_API_KEY environment variable is not set")

    headers = {
        "Authorization": f"Bearer {self.secret_key}",
        "Content-Type": "application/json",
    }

    payload = {
        "end_user": {"id": user_id},
        "organization": {"id": str(organization_id)},
        "allowed_integrations": NANGO_ALLOWED_INTEGRATIONS,
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.base_url}/connect/sessions", headers=headers, json=payload
        )

        # Only 201 counts as success here; anything else is surfaced to the
        # caller as an HTTPStatusError.
        if response.status_code != 201:
            raise httpx.HTTPStatusError(
                f"Nango API error: {response.status_code}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
async def process_webhook(
    self, raw_body: bytes, signature: str = None
) -> Dict[str, str]:
    """
    Process incoming Nango webhook request.

    The signature is checked before the payload is parsed or logged, so
    unauthenticated content is rejected as early as possible.

    Args:
        raw_body: The raw request body as bytes
        signature: Signature from request headers

    Returns:
        Dict with status and message

    Raises:
        HTTPException: 400 for an undecodable body, invalid JSON, an
            unexpected payload shape or non-integer IDs; 401 for an invalid
            signature.
    """
    # Decode the request body; undecodable bytes are a client error.
    try:
        body_text = raw_body.decode("utf-8")
    except UnicodeDecodeError as e:
        logger.error(f"Webhook body is not valid UTF-8: {e}")
        raise HTTPException(
            status_code=400, detail="Invalid request body encoding"
        )

    # Verify webhook signature before touching the content
    if not self._verify_webhook_signature(body_text, signature):
        raise HTTPException(status_code=401, detail="Invalid webhook signature")

    # Parse the request body
    try:
        webhook_json = json.loads(body_text) if body_text else {}
        logger.debug(f"received webhook from nango: {webhook_json}")
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error: {e} body_text: {body_text}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")

    # Parse webhook data
    try:
        webhook_data = NangoWebhookRequest(**webhook_json)
    except Exception as e:
        logger.error(f"Failed to parse webhook data: {e}")
        raise HTTPException(
            status_code=400, detail=f"Invalid webhook format: {str(e)}"
        )

    # Extract user and organization IDs from the webhook payload
    end_user = webhook_data.endUser
    if (
        not end_user
        or "endUserId" not in end_user
        or "organizationId" not in end_user
    ):
        raise HTTPException(
            status_code=400, detail="Missing endUser information in webhook payload"
        )

    # Non-integer IDs are a malformed payload, not a server fault.
    try:
        user_id = int(end_user["endUserId"])
        organization_id = int(end_user["organizationId"])
    except (TypeError, ValueError) as e:
        logger.error(f"Invalid endUser identifiers in webhook payload: {e}")
        raise HTTPException(
            status_code=400,
            detail="Invalid endUser identifiers in webhook payload",
        )

    # Use the connectionId as the integration_id since it's unique per integration
    integration_id = webhook_data.connectionId

    connection_details = {}

    # Fetch connection details for "auth" events (for every provider)
    if webhook_data.type == "auth":
        connection_details = await self._fetch_connection_details(
            integration_id, webhook_data.provider
        )

    # Create the integration in the database
    integration = await db_client.create_integration(
        integration_id=integration_id,
        organisation_id=organization_id,
        provider=webhook_data.provider,
        created_by=user_id,
        is_active=True,
        connection_details=connection_details,
    )

    return {
        "status": "success",
        "message": f"Integration created successfully with ID: {integration.id}",
    }
|
||||
|
||||
async def _fetch_connection_details(
    self, connection_id: str, provider_key: str
) -> Dict[str, Any]:
    """
    Fetch connection details from Nango API for a given connection ID.

    This lookup is best-effort: a transport error, a non-200 response, or a
    response body that is not valid JSON is logged and results in an empty
    dict instead of an exception.

    Args:
        connection_id: The connection ID from the webhook
        provider_key: Value sent as the ``provider_config_key`` query parameter

    Returns:
        Connection details as a dictionary, or ``{}`` if the lookup failed
    """

    headers = {
        "Authorization": f"Bearer {self.secret_key}",
        "Content-Type": "application/json",
    }

    url = f"{self.base_url}/connection/{connection_id}/?provider_config_key={provider_key}"

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(url, headers=headers)

            if response.status_code != 200:
                logger.error(
                    f"Failed to fetch connection details: {response.status_code} - {response.text}"
                )
                raise httpx.HTTPStatusError(
                    f"Nango API error while fetching connection: {response.status_code}",
                    request=response.request,
                    response=response,
                )

            connection_details = response.json()
            return connection_details

        except httpx.HTTPError as e:
            logger.error(f"HTTP error while fetching connection details: {e}")
            # Return empty dict if API call fails, but log the error
            return {}
        except ValueError as e:
            # response.json() raises a ValueError subclass when the body is
            # not valid JSON; treat it like any other failed lookup so the
            # caller still gets the documented empty-dict fallback.
            logger.error(f"Invalid JSON in connection details response: {e}")
            return {}
|
||||
|
||||
async def get_access_token(
    self, connection_id: str, provider_config_key: str
) -> Dict[str, Any]:
    """
    Retrieve the current connection record (including access token) from Nango.

    Args:
        connection_id: The connection ID
        provider_config_key: The provider config key (e.g., 'google-sheet')

    Returns:
        The decoded JSON body of the Nango response

    Raises:
        httpx.HTTPError: On transport failure or any non-200 response; the
            error is logged before being re-raised.
    """
    auth_headers = {
        "Authorization": f"Bearer {self.secret_key}",
        "Content-Type": "application/json",
    }
    endpoint = f"{self.base_url}/connection/{connection_id}?provider_config_key={provider_config_key}"

    async with httpx.AsyncClient() as http:
        try:
            reply = await http.get(endpoint, headers=auth_headers)

            if reply.status_code == 200:
                return reply.json()

            # Anything other than 200 is logged and surfaced as an HTTP error.
            logger.error(
                f"Failed to get access token: {reply.status_code} - {reply.text}"
            )
            raise httpx.HTTPStatusError(
                f"Nango API error: {reply.status_code}",
                request=reply.request,
                response=reply,
            )
        except httpx.HTTPError as err:
            logger.error(f"HTTP error while getting access token: {err}")
            raise
|
||||
|
||||
|
||||
# Create a singleton instance (constructed once, when this module is imported)
nango_service = NangoService()
|
||||
3
api/services/looptalk/__init__.py
Normal file
3
api/services/looptalk/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .orchestrator import LoopTalkTestOrchestrator
|
||||
|
||||
__all__ = ["LoopTalkTestOrchestrator"]
|
||||
220
api/services/looptalk/audio_streamer.py
Normal file
220
api/services/looptalk/audio_streamer.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""
|
||||
Audio streaming processor for LoopTalk real-time audio monitoring.
|
||||
|
||||
This processor captures audio from both actor and adversary agents and streams
|
||||
it to connected WebRTC clients for real-time monitoring.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Set
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.audio.utils import mix_audio
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LoopTalkAudioStreamer(FrameProcessor):
    """
    Pipeline tap that copies raw audio frames to monitoring queues.

    Each ``InputAudioRawFrame`` / ``OutputAudioRawFrame`` passing through is
    wrapped in a metadata packet and offered to every registered listener
    queue. Every frame, audio or not, is forwarded down the pipeline unchanged.
    """

    def __init__(
        self,
        test_session_id: str,
        role: str,  # "actor" or "adversary"
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._test_session_id = test_session_id
        self._role = role
        # Queues that receive a packet for every audio frame seen.
        self._listeners: Set[asyncio.Queue] = set()
        # Last observed audio format; defaults apply until a frame reports one.
        self._sample_rate = 16000
        self._num_channels = 1

    def add_listener(self, queue: asyncio.Queue):
        """Register a queue that should receive audio packets."""
        self._listeners.add(queue)

    def remove_listener(self, queue: asyncio.Queue):
        """Unregister a queue; unknown queues are ignored."""
        self._listeners.discard(queue)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Fan audio frames out to listeners, then forward the frame."""
        await super().process_frame(frame, direction)

        # Non-audio frames are passed straight through.
        if not isinstance(frame, (InputAudioRawFrame, OutputAudioRawFrame)):
            await self.push_frame(frame, direction)
            return

        pcm = frame.audio
        rate = frame.sample_rate
        channels = frame.num_channels

        # Remember the most recent non-zero format values.
        if rate:
            self._sample_rate = rate
        if channels:
            self._num_channels = channels

        # With no listeners there is nothing to do; that is expected early in
        # the session before a monitoring client has connected.
        if self._listeners:
            if pcm:
                packet = {
                    "test_session_id": self._test_session_id,
                    "role": self._role,
                    "audio": pcm,
                    "sample_rate": rate,
                    "num_channels": channels,
                    "is_input": isinstance(frame, InputAudioRawFrame),
                }

                # Iterate over a snapshot so a failing listener can be dropped
                # while looping; put_nowait keeps the pipeline from blocking.
                for listener in tuple(self._listeners):
                    try:
                        listener.put_nowait(packet)
                    except asyncio.QueueFull:
                        logger.warning(
                            f"Audio queue full for session {self._test_session_id}"
                        )
                    except Exception as e:
                        logger.error(f"Error streaming audio: {e}")
                        self._listeners.discard(listener)
            else:
                logger.warning(
                    f"Audio streamer {self._role} received frame with no audio data"
                )

        # Always forward the frame
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class LoopTalkAudioMixer:
    """
    Combines the actor and adversary audio streams into one monitoring stream.

    Audio from each role is accumulated in its own byte buffer. Whenever both
    buffers hold at least one full chunk, a chunk is taken from each, mixed,
    and offered to every registered listener queue.
    """

    def __init__(self, test_session_id: str):
        self._test_session_id = test_session_id
        self._listeners: Set[asyncio.Queue] = set()
        self._actor_buffer = bytearray()
        self._adversary_buffer = bytearray()
        self._sample_rate = 16000
        self._num_channels = 1
        # Chunk size in bytes taken from each buffer per mix.
        # NOTE(review): this was described as "0.5 seconds at 16kHz"; for
        # 16-bit mono PCM 8000 bytes would be 0.25 s - confirm sample format.
        self._buffer_size = 8000

    def add_listener(self, queue: asyncio.Queue):
        """Register a queue that should receive mixed audio packets."""
        self._listeners.add(queue)

    def remove_listener(self, queue: asyncio.Queue):
        """Unregister a queue; unknown queues are ignored."""
        self._listeners.discard(queue)

    async def add_audio(
        self, role: str, audio_data: bytes, sample_rate: int, num_channels: int
    ):
        """Append audio for ``role`` and mix whatever is now ready.

        Audio for roles other than "actor" / "adversary" is not buffered, but
        the sample rate and channel count are still recorded.
        """
        buffers = {"actor": self._actor_buffer, "adversary": self._adversary_buffer}
        target = buffers.get(role)
        # Compare against None explicitly: an empty bytearray is falsy.
        if target is not None:
            target.extend(audio_data)

        # The most recent caller's format is used for outgoing packets.
        self._sample_rate = sample_rate
        self._num_channels = num_channels

        await self._check_and_mix()

    async def _check_and_mix(self):
        """Mix and publish chunks while both buffers hold a full chunk."""
        # NOTE(review): if only one role supplies audio, its buffer is never
        # drained here - confirm both roles always produce audio.
        size = self._buffer_size
        while min(len(self._actor_buffer), len(self._adversary_buffer)) >= size:
            # Take one chunk from the front of each buffer.
            actor_chunk = bytes(self._actor_buffer[:size])
            adversary_chunk = bytes(self._adversary_buffer[:size])
            del self._actor_buffer[:size]
            del self._adversary_buffer[:size]

            mixed_audio = mix_audio(actor_chunk, adversary_chunk)

            # Chunks are consumed even when nobody is listening.
            if not (self._listeners and mixed_audio):
                continue

            packet = {
                "test_session_id": self._test_session_id,
                "role": "mixed",
                "audio": mixed_audio,
                "sample_rate": self._sample_rate,
                "num_channels": self._num_channels,
                "is_input": False,
            }

            # Snapshot the set so a failing listener can be dropped in-loop.
            for listener in tuple(self._listeners):
                try:
                    listener.put_nowait(packet)
                except asyncio.QueueFull:
                    logger.warning(
                        f"Mixed audio queue full for session {self._test_session_id}"
                    )
                except Exception as e:
                    logger.error(f"Error streaming mixed audio: {e}")
                    self._listeners.discard(listener)
|
||||
|
||||
|
||||
# Global registry for audio streamers and mixers
# _audio_streamers: test_session_id -> role -> streamer
# _audio_mixers: test_session_id -> mixer
_audio_streamers: Dict[str, Dict[str, LoopTalkAudioStreamer]] = {}
_audio_mixers: Dict[str, LoopTalkAudioMixer] = {}
|
||||
|
||||
|
||||
def get_or_create_audio_streamer(
    test_session_id: str, role: str
) -> LoopTalkAudioStreamer:
    """Return the streamer registered for (session, role), creating it on first use."""
    per_session = _audio_streamers.setdefault(test_session_id, {})

    streamer = per_session.get(role)
    if streamer is None:
        # Construct lazily so an existing streamer is never replaced.
        streamer = LoopTalkAudioStreamer(test_session_id=test_session_id, role=role)
        per_session[role] = streamer

    return streamer
|
||||
|
||||
|
||||
def get_or_create_audio_mixer(test_session_id: str) -> LoopTalkAudioMixer:
    """Return the mixer registered for a test session, creating it on first use."""
    mixer = _audio_mixers.get(test_session_id)
    if mixer is None:
        mixer = LoopTalkAudioMixer(test_session_id)
        _audio_mixers[test_session_id] = mixer
    return mixer
|
||||
|
||||
|
||||
def cleanup_audio_streamers(test_session_id: str):
    """Drop the registered streamers and mixer for a test session, if any."""
    # pop() with a default is a no-op for sessions that were never registered.
    _audio_streamers.pop(test_session_id, None)
    _audio_mixers.pop(test_session_id, None)

    logger.info(f"Cleaned up audio streamers for test session {test_session_id}")
|
||||
1
api/services/looptalk/core/__init__.py
Normal file
1
api/services/looptalk/core/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Core modules for LoopTalk orchestration."""
|
||||
167
api/services/looptalk/core/pipeline_builder.py
Normal file
167
api/services/looptalk/core/pipeline_builder.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""Pipeline building logic for LoopTalk agents."""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.filters.stt_mute_filter import (
|
||||
STTMuteConfig,
|
||||
STTMuteFilter,
|
||||
STTMuteStrategy,
|
||||
)
|
||||
from pipecat.transports import InternalTransport
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.services.looptalk.audio_streamer import get_or_create_audio_streamer
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.pipeline_builder import (
|
||||
create_pipeline_components,
|
||||
create_pipeline_task,
|
||||
)
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
class LoopTalkPipelineBuilder:
    """Assembles the processing pipeline for a single LoopTalk agent."""

    def __init__(self, db_client: DBClient):
        """Store the database client used to look up user configurations.

        Args:
            db_client: Database client for fetching user configurations
        """
        self.db_client = db_client

    async def create_agent_pipeline(
        self,
        transport: InternalTransport,
        workflow: Any,
        test_session_id: int,
        agent_id: str,
        role: str,
    ) -> Dict[str, Any]:
        """Build the services, engine, pipeline and task for one agent.

        Args:
            transport: Internal transport for the agent
            workflow: Workflow model from database
            test_session_id: ID of the test session
            agent_id: Unique identifier for the agent
            role: Either "actor" or "adversary"

        Returns:
            Dictionary with the keys "task", "engine", "audio_buffer",
            "audio_synchronizer", "transcript",
            "assistant_context_aggregator" and "audio_streamer"
        """
        # Services are configured from the workflow owner's settings.
        user_config = await self.db_client.get_user_configurations(workflow.user_id)

        # Every stage of a LoopTalk pipeline runs at the same sample rate.
        sample_rate = 16000
        audio_config = AudioConfig(
            transport_in_sample_rate=sample_rate,
            transport_out_sample_rate=sample_rate,
            vad_sample_rate=sample_rate,
            pipeline_sample_rate=sample_rate,
        )

        stt = create_stt_service(user_config)
        llm = create_llm_service(user_config)
        tts = create_tts_service(user_config, audio_config)

        logger.debug(f"Created services for {role}: STT={stt}, LLM={llm}, TTS={tts}")

        components = create_pipeline_components(audio_config)
        audio_buffer, audio_synchronizer, transcript, context = components
        context_aggregator = llm.create_context_aggregator(context)

        definition = ReactFlowDTO.model_validate(
            workflow.workflow_definition_with_fallback
        )
        workflow_graph = WorkflowGraph(definition)

        # The task does not exist yet; it is attached to the engine below.
        engine = PipecatEngine(
            task=None,
            llm=llm,
            context=context,
            tts=tts,
            workflow=workflow_graph,
            call_context_vars={},
            audio_buffer=audio_buffer,
            workflow_run_id=None,  # LoopTalk doesn't have workflow runs
        )

        mute_config = STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH})
        stt_mute_filter = STTMuteFilter(config=mute_config)

        callbacks_processor = PipelineEngineCallbacksProcessor(
            max_call_duration_seconds=300,
            max_duration_end_task_callback=engine.create_max_duration_callback(),
            llm_generated_text_callback=engine.create_llm_generated_text_callback(),
            generation_started_callback=engine.create_generation_started_callback(),
        )

        user_context_aggregator = context_aggregator.user()
        assistant_context_aggregator = context_aggregator.assistant()

        # Register processors with synchronizer for merged audio
        audio_synchronizer.register_processors(
            audio_buffer.input(), audio_buffer.output()
        )

        # The streamer registry is keyed by the session id as a string.
        audio_streamer = get_or_create_audio_streamer(str(test_session_id), role)

        processors = [
            transport.input(),
            audio_buffer.input(),  # Record input audio
            audio_streamer,  # Stream audio to connected clients
            stt_mute_filter,
            stt,
            transcript.user(),
            user_context_aggregator,
            llm,
            callbacks_processor,
            tts,
            transport.output(),
            audio_buffer.output(),  # Record output audio
            transcript.assistant(),
            assistant_context_aggregator,
        ]
        pipeline = Pipeline(processors)

        # The conversation id is unique per session, role and agent (for tracing).
        conversation_id = f"{test_session_id}-{role}-{agent_id}"
        task = create_pipeline_task(pipeline, conversation_id, audio_config)

        engine.task = task

        return {
            "task": task,
            "engine": engine,
            "audio_buffer": audio_buffer,
            "audio_synchronizer": audio_synchronizer,
            "transcript": transcript,
            "assistant_context_aggregator": assistant_context_aggregator,
            "audio_streamer": audio_streamer,
        }
|
||||
216
api/services/looptalk/core/recording_manager.py
Normal file
216
api/services/looptalk/core/recording_manager.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""Recording management for LoopTalk sessions."""
|
||||
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import StorageBackend
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
||||
class RecordingManager:
    """Manages audio recording and transcript files for LoopTalk sessions."""

    # Roles whose recordings are reported by get_recording_info().
    _ROLES = ("actor", "adversary")

    def __init__(self, base_dir: Path):
        """Initialize the recording manager.

        The base directory is created if it does not exist yet.

        Args:
            base_dir: Base directory for temporary recordings
        """
        self.base_dir = base_dir
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _session_dir(self, test_session_id: int) -> Path:
        """Return the directory for a session without touching the filesystem."""
        return self.base_dir / f"session_{test_session_id}"

    @staticmethod
    def _paths_in(session_dir: Path, role: str) -> Dict[str, Path]:
        """Return the per-role file paths inside ``session_dir``."""
        return {
            "audio": session_dir / f"{role}_audio.wav",
            "transcript": session_dir / f"{role}_transcript.txt",
            "temp_audio": session_dir / f"{role}_audio_temp.pcm",
        }

    def get_recording_paths(self, test_session_id: int, role: str) -> Dict[str, Path]:
        """Get file paths for recordings, creating the session directory.

        Args:
            test_session_id: ID of the test session
            role: Either "actor" or "adversary"

        Returns:
            Dictionary with paths for audio, transcript, and temp audio files
        """
        session_dir = self._session_dir(test_session_id)
        session_dir.mkdir(parents=True, exist_ok=True)
        return self._paths_in(session_dir, role)

    def convert_pcm_to_wav(
        self,
        test_session_id: int,
        role: str,
        sample_rate: int = 16000,
        num_channels: int = 1,
    ) -> Optional[Path]:
        """Convert PCM audio file to WAV format.

        The temporary PCM file is deleted after a successful conversion.

        Args:
            test_session_id: ID of the test session
            role: Either "actor" or "adversary"
            sample_rate: Sample rate of the audio
            num_channels: Number of audio channels

        Returns:
            Path to the WAV file if successful, None otherwise
        """
        paths = self.get_recording_paths(test_session_id, role)

        if not paths["temp_audio"].exists():
            logger.warning(f"No audio recorded for {role} in session {test_session_id}")
            return None

        try:
            pcm_data = paths["temp_audio"].read_bytes()

            with wave.open(str(paths["audio"]), "wb") as wav_file:
                wav_file.setnchannels(num_channels)
                wav_file.setsampwidth(2)  # 16-bit audio
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(pcm_data)

            # Remove temporary PCM file
            paths["temp_audio"].unlink()

            logger.info(
                f"Converted audio to WAV for {role} in session {test_session_id}: {paths['audio']}"
            )
            return paths["audio"]

        except Exception as e:
            logger.error(
                f"Failed to convert audio to WAV for {role} in session {test_session_id}: {e}"
            )
            return None

    async def upload_recording_to_s3(
        self, test_session_id: int, role: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Upload recording and transcript through the configured storage backend.

        Each file is uploaded independently; a failure is logged and leaves
        the corresponding element of the result as None.

        Args:
            test_session_id: ID of the test session
            role: Either "actor" or "adversary"

        Returns:
            Tuple of (audio_key, transcript_key); an element is None when that
            file does not exist or its upload failed
        """
        paths = self.get_recording_paths(test_session_id, role)
        audio_url = None
        transcript_url = None

        current_backend = StorageBackend.get_current_backend()
        logger.info(
            f"LOOPTALK UPLOAD: Using {current_backend.label} (code: {current_backend.code}) for session {test_session_id}, role: {role}"
        )

        # Upload audio if exists
        if paths["audio"].exists():
            audio_key = f"looptalk/recordings/{test_session_id}/{role}_audio.wav"
            try:
                success = await storage_fs.aupload_file(str(paths["audio"]), audio_key)
                if success:
                    audio_url = audio_key
                    logger.info(
                        f"Uploaded {role} audio to {current_backend.label}: {audio_key}"
                    )
                else:
                    logger.error(
                        f"Failed to upload {role} audio to {current_backend.label}"
                    )
            except Exception as e:
                logger.error(
                    f"Error uploading {role} audio to {current_backend.label}: {e}"
                )

        # Upload transcript if exists
        if paths["transcript"].exists():
            transcript_key = (
                f"looptalk/transcripts/{test_session_id}/{role}_transcript.txt"
            )
            try:
                success = await storage_fs.aupload_file(
                    str(paths["transcript"]), transcript_key
                )
                if success:
                    transcript_url = transcript_key
                    logger.info(
                        f"Uploaded {role} transcript to {current_backend.label}: {transcript_key}"
                    )
                else:
                    logger.error(
                        f"Failed to upload {role} transcript to {current_backend.label}"
                    )
            except Exception as e:
                logger.error(
                    f"Error uploading {role} transcript to {current_backend.label}: {e}"
                )

        return audio_url, transcript_url

    def cleanup_session_files(self, test_session_id: int):
        """Clean up local files for a session.

        Failures are logged rather than raised.

        Args:
            test_session_id: ID of the test session
        """
        session_dir = self._session_dir(test_session_id)
        if session_dir.exists():
            try:
                import shutil

                shutil.rmtree(session_dir)
                logger.debug(f"Cleaned up local files for session {test_session_id}")
            except Exception as e:
                logger.error(f"Failed to clean up session files: {e}")

    def get_recording_info(self, test_session_id: int) -> Dict[str, object]:
        """Get information about recordings for a test session.

        This is a read-only query: it does not create the session directory,
        so asking about an unknown or already cleaned-up session leaves the
        filesystem untouched.

        Args:
            test_session_id: ID of the test session

        Returns:
            Dictionary with "test_session_id", "recording_dir" and "files";
            "files" maps each role that has recordings to the path and size
            of its audio and/or transcript file
        """
        session_dir = self._session_dir(test_session_id)
        files: Dict[str, Dict[str, Dict[str, object]]] = {}

        for role in self._ROLES:
            paths = self._paths_in(session_dir, role)
            role_info: Dict[str, Dict[str, object]] = {}

            for kind in ("audio", "transcript"):
                path = paths[kind]
                if path.exists():
                    role_info[kind] = {
                        "path": str(path),
                        "size_bytes": path.stat().st_size,
                    }

            # Roles without any files are omitted from the result.
            if role_info:
                files[role] = role_info

        return {
            "test_session_id": test_session_id,
            "recording_dir": str(session_dir),
            "files": files,
        }
|
||||
184
api/services/looptalk/core/session_manager.py
Normal file
184
api/services/looptalk/core/session_manager.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
"""Session management for LoopTalk test sessions."""
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages running LoopTalk test sessions."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the session manager."""
|
||||
self._running_sessions: Dict[int, Dict[str, Any]] = {}
|
||||
self._disconnect_handlers: Dict[int, asyncio.Task] = {}
|
||||
|
||||
def add_session(self, test_session_id: int, session_info: Dict[str, Any]):
|
||||
"""Add a new session to the manager.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
session_info: Dictionary containing session information
|
||||
"""
|
||||
self._running_sessions[test_session_id] = session_info
|
||||
|
||||
def get_session(self, test_session_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get session information.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
|
||||
Returns:
|
||||
Session information dictionary or None if not found
|
||||
"""
|
||||
return self._running_sessions.get(test_session_id)
|
||||
|
||||
def remove_session(self, test_session_id: int):
|
||||
"""Remove a session from the manager.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
"""
|
||||
if test_session_id in self._running_sessions:
|
||||
del self._running_sessions[test_session_id]
|
||||
|
||||
# Cancel any disconnect handler for this session
|
||||
if test_session_id in self._disconnect_handlers:
|
||||
handler = self._disconnect_handlers.pop(test_session_id)
|
||||
if not handler.done():
|
||||
handler.cancel()
|
||||
|
||||
def get_active_count(self) -> int:
|
||||
"""Get the number of currently active sessions."""
|
||||
return len(self._running_sessions)
|
||||
|
||||
def get_active_info(self) -> Dict[str, Any]:
|
||||
"""Get information about all active sessions."""
|
||||
return {
|
||||
"count": len(self._running_sessions),
|
||||
"sessions": [
|
||||
{
|
||||
"test_session_id": session_id,
|
||||
"conversation_id": info["conversation"].id,
|
||||
"start_time": info["start_time"],
|
||||
"duration_seconds": int(
|
||||
(datetime.now(UTC) - info["start_time"]).total_seconds()
|
||||
),
|
||||
}
|
||||
for session_id, info in self._running_sessions.items()
|
||||
],
|
||||
}
|
||||
|
||||
async def handle_agent_disconnect(
|
||||
self, test_session_id: int, disconnected_role: str, stop_callback: callable
|
||||
):
|
||||
"""Handle when one agent disconnects.
|
||||
|
||||
This will cancel the other agent as well to ensure clean shutdown.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
disconnected_role: Role that disconnected ("actor" or "adversary")
|
||||
stop_callback: Callback to stop the session
|
||||
"""
|
||||
logger.info(
|
||||
f"Handling {disconnected_role} disconnect for session {test_session_id}"
|
||||
)
|
||||
|
||||
# Check if we already have a disconnect handler running
|
||||
if test_session_id in self._disconnect_handlers:
|
||||
logger.debug(
|
||||
f"Disconnect handler already running for session {test_session_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Create a task to handle the disconnect
|
||||
async def _handle_disconnect():
|
||||
try:
|
||||
# Wait a short time to avoid race conditions
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Check if session still exists
|
||||
session_info = self.get_session(test_session_id)
|
||||
if not session_info:
|
||||
logger.debug(f"Session {test_session_id} already stopped")
|
||||
return
|
||||
|
||||
# Stop the session (which will cancel both agents)
|
||||
logger.info(
|
||||
f"Stopping session {test_session_id} due to {disconnected_role} disconnect"
|
||||
)
|
||||
await stop_callback(test_session_id)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(
|
||||
f"Disconnect handler cancelled for session {test_session_id}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error handling disconnect for session {test_session_id}: {e}"
|
||||
)
|
||||
|
||||
# Store the task so we can cancel it if needed
|
||||
self._disconnect_handlers[test_session_id] = asyncio.create_task(
|
||||
_handle_disconnect()
|
||||
)
|
||||
|
||||
def update_audio_metadata(
|
||||
self,
|
||||
test_session_id: int,
|
||||
role: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
num_channels: Optional[int] = None,
|
||||
):
|
||||
"""Update audio metadata for a role in a session.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
sample_rate: Sample rate of the audio
|
||||
num_channels: Number of audio channels
|
||||
"""
|
||||
if test_session_id not in self._running_sessions:
|
||||
return
|
||||
|
||||
if "audio_metadata" not in self._running_sessions[test_session_id]:
|
||||
self._running_sessions[test_session_id]["audio_metadata"] = {}
|
||||
|
||||
if role not in self._running_sessions[test_session_id]["audio_metadata"]:
|
||||
self._running_sessions[test_session_id]["audio_metadata"][role] = {}
|
||||
|
||||
metadata = self._running_sessions[test_session_id]["audio_metadata"][role]
|
||||
if sample_rate is not None:
|
||||
metadata["sample_rate"] = sample_rate
|
||||
if num_channels is not None:
|
||||
metadata["num_channels"] = num_channels
|
||||
|
||||
def get_audio_metadata(self, test_session_id: int, role: str) -> Dict[str, Any]:
|
||||
"""Get audio metadata for a role in a session.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
|
||||
Returns:
|
||||
Dictionary with sample_rate and num_channels
|
||||
"""
|
||||
default = {"sample_rate": 16000, "num_channels": 1}
|
||||
|
||||
if test_session_id not in self._running_sessions:
|
||||
return default
|
||||
|
||||
metadata = (
|
||||
self._running_sessions.get(test_session_id, {})
|
||||
.get("audio_metadata", {})
|
||||
.get(role, {})
|
||||
)
|
||||
|
||||
return {
|
||||
"sample_rate": metadata.get("sample_rate", 16000),
|
||||
"num_channels": metadata.get("num_channels", 1),
|
||||
}
|
||||
553
api/services/looptalk/orchestrator.py
Normal file
553
api/services/looptalk/orchestrator.py
Normal file
|
|
@ -0,0 +1,553 @@
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.transports import (
|
||||
InternalTransport,
|
||||
InternalTransportManager,
|
||||
)
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.services.pipecat.transport_setup import create_internal_transport
|
||||
|
||||
from .core.pipeline_builder import LoopTalkPipelineBuilder
|
||||
from .core.recording_manager import RecordingManager
|
||||
from .core.session_manager import SessionManager
|
||||
|
||||
|
||||
class LoopTalkTestOrchestrator:
    """Orchestrates LoopTalk testing sessions with agent-to-agent conversations.

    A test session wires two agent pipelines (an "actor" and an "adversary")
    together through a pair of connected internal transports, records the
    audio and transcript of each side to local files, and uploads the
    recordings when the session is stopped.
    """

    def __init__(
        self, db_client: DBClient, network_latency_seconds: Optional[float] = None
    ):
        """Create the orchestrator and its helper managers.

        Args:
            db_client: Database client used to read and update test sessions
                and conversations.
            network_latency_seconds: Default simulated network latency. When
                omitted, ``LOOPTALK_NETWORK_LATENCY_MS`` is consulted and
                100ms is used as the final fallback.
        """
        self.db_client = db_client
        self.transport_manager = InternalTransportManager()
        self.session_manager = SessionManager()
        self.pipeline_builder = LoopTalkPipelineBuilder(db_client)
        self.recording_manager = RecordingManager(Path("/tmp/looptalk_recordings"))

        # Default network latency (can be overridden per session)
        # Priority: constructor param > env var > default (100ms)
        self._default_network_latency = self._resolve_default_latency(
            network_latency_seconds
        )

    @staticmethod
    def _resolve_default_latency(network_latency_seconds: Optional[float]) -> float:
        """Pick the default latency: explicit value, then env var, then 100ms."""
        if network_latency_seconds is not None:
            return network_latency_seconds

        env_latency = os.environ.get("LOOPTALK_NETWORK_LATENCY_MS")
        if not env_latency:
            return 0.1  # 100ms default

        try:
            return float(env_latency) / 1000.0  # Convert ms to seconds
        except ValueError:
            logger.warning(
                f"Invalid LOOPTALK_NETWORK_LATENCY_MS value: {env_latency}, using default 100ms"
            )
            return 0.1

    async def start_test_session(
        self,
        test_session_id: int,
        organization_id: int,
        network_latency_seconds: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Start a LoopTalk test session.

        Args:
            test_session_id: ID of the test session to start.
            organization_id: Organization the test session is looked up under.
            network_latency_seconds: Per-session latency override; the
                orchestrator default is used when this is None.

        Returns:
            Dict with ``test_session_id``, ``conversation_id`` and ``status``.

        Raises:
            ValueError: If the session does not exist or is not pending.
            Exception: Any failure during start-up is re-raised after the
                session has been marked as failed in the database.
        """
        # Get test session details
        test_session = await self.db_client.get_test_session(
            test_session_id=test_session_id, organization_id=organization_id
        )

        if not test_session:
            raise ValueError(f"Test session {test_session_id} not found")

        if test_session.status != "pending":
            raise ValueError(f"Test session {test_session_id} is not in pending state")

        try:
            # Update status to running
            await self.db_client.update_test_session_status(
                test_session_id=test_session_id, status="running"
            )

            # Create conversation record
            conversation = await self.db_client.create_conversation(
                test_session_id=test_session_id
            )

            # Create audio configuration for LoopTalk
            from api.services.pipecat.audio_config import AudioConfig

            audio_config = AudioConfig(
                transport_in_sample_rate=16000,
                transport_out_sample_rate=16000,
                pipeline_sample_rate=16000,
            )

            # Use provided latency or fall back to default
            latency = (
                network_latency_seconds
                if network_latency_seconds is not None
                else self._default_network_latency
            )
            logger.info(
                f"Using network latency of {latency}s for test session {test_session_id}"
            )

            # Generate unique workflow run IDs for each agent by appending a
            # role digit to the session ID (e.g. session 42 -> 421 and 422).
            actor_workflow_run_id = int(str(test_session_id) + "1")
            adversary_workflow_run_id = int(str(test_session_id) + "2")

            # Create transports using the new method with turn analyzer
            actor_transport = create_internal_transport(
                workflow_run_id=actor_workflow_run_id,
                audio_config=audio_config,
                latency_seconds=latency,
            )
            adversary_transport = create_internal_transport(
                workflow_run_id=adversary_workflow_run_id,
                audio_config=audio_config,
                latency_seconds=latency,
            )

            # Connect the transports
            actor_transport.connect_partner(adversary_transport)

            # Store the transport pair in the manager
            self.transport_manager._transport_pairs[str(test_session_id)] = (
                actor_transport,
                adversary_transport,
            )

            # Generate unique identifiers for actor and adversary
            actor_id = f"actor_{test_session_id}_{str(uuid.uuid4())[:8]}"
            adversary_id = f"adversary_{test_session_id}_{str(uuid.uuid4())[:8]}"

            # Create pipelines for both agents
            actor_pipeline_info = await self.pipeline_builder.create_agent_pipeline(
                transport=actor_transport,
                workflow=test_session.actor_workflow,
                test_session_id=test_session_id,
                agent_id=actor_id,
                role="actor",
            )
            actor_pipeline_task = actor_pipeline_info["task"]

            adversary_pipeline_info = await self.pipeline_builder.create_agent_pipeline(
                transport=adversary_transport,
                workflow=test_session.adversary_workflow,
                test_session_id=test_session_id,
                agent_id=adversary_id,
                role="adversary",
            )
            adversary_pipeline_task = adversary_pipeline_info["task"]

            # Register event handlers for both pipelines
            await self._register_transport_handlers(
                actor_transport, actor_pipeline_info, test_session_id, "actor"
            )
            await self._register_transport_handlers(
                adversary_transport,
                adversary_pipeline_info,
                test_session_id,
                "adversary",
            )

            # Store session info
            session_info = {
                "test_session": test_session,
                "conversation": conversation,
                "actor_task": actor_pipeline_task,
                "adversary_task": adversary_pipeline_task,
                "actor_transport": actor_transport,
                "adversary_transport": adversary_transport,
                "start_time": datetime.now(UTC),
            }
            self.session_manager.add_session(test_session_id, session_info)

            # Start both pipelines in background tasks
            from pipecat.pipeline.base_task import PipelineTaskParams

            # We are inside a coroutine, so the running loop is the one the
            # pipelines must be bound to.
            params = PipelineTaskParams(loop=asyncio.get_running_loop())

            # Start the pipelines - this will trigger initialization through the normal pipeline start process
            # The workflow engines will be initialized when the pipeline starts

            # Create conversation IDs for tracing
            actor_conversation_id = f"{test_session_id}-actor-{actor_id}"
            adversary_conversation_id = f"{test_session_id}-adversary-{adversary_id}"

            # Create tasks but don't await them - they'll run in the background
            logger.debug(f"Running actor task with ID: {actor_id}")
            actor_task_future = asyncio.create_task(
                self._run_pipeline_with_context(
                    actor_pipeline_task,
                    params,
                    actor_id,
                    actor_conversation_id,
                    "actor",
                )
            )

            logger.debug(f"Running adversary task with ID: {adversary_id}")
            adversary_task_future = asyncio.create_task(
                self._run_pipeline_with_context(
                    adversary_pipeline_task,
                    params,
                    adversary_id,
                    adversary_conversation_id,
                    "adversary",
                )
            )

            # Store the futures so we can monitor them
            session_info["actor_task_future"] = actor_task_future
            session_info["adversary_task_future"] = adversary_task_future

            logger.info(f"Started LoopTalk test session {test_session_id}")

            return {
                "test_session_id": test_session_id,
                "conversation_id": conversation.id,
                "status": "running",
            }

        except Exception as e:
            logger.error(f"Failed to start test session {test_session_id}: {e}")
            await self.db_client.update_test_session_status(
                test_session_id=test_session_id, status="failed", error=str(e)
            )
            raise

    async def _register_transport_handlers(
        self,
        transport: InternalTransport,
        pipeline_info: Dict[str, Any],
        test_session_id: int,
        role: str,
    ):
        """Register transport event handlers for a pipeline.

        Args:
            transport: The transport to register handlers on
            pipeline_info: Dictionary containing pipeline components
            test_session_id: ID of the test session
            role: Either "actor" or "adversary"
        """
        engine = pipeline_info["engine"]
        task = pipeline_info["task"]
        audio_buffer = pipeline_info["audio_buffer"]
        audio_synchronizer = pipeline_info["audio_synchronizer"]
        transcript = pipeline_info["transcript"]
        assistant_context_aggregator = pipeline_info["assistant_context_aggregator"]

        # Register transport event handlers
        @transport.event_handler("on_client_connected")
        async def on_client_connected(transport, participant):
            logger.debug(f"LoopTalk {role} client connected - initializing workflow")
            # Start audio recording
            await audio_buffer.start_recording()
            await audio_synchronizer.start_recording()
            await engine.initialize()

        @transport.event_handler("on_client_disconnected")
        async def on_client_disconnected(transport, participant):
            logger.debug(f"LoopTalk {role} client disconnected")
            # Stop audio recording
            await audio_buffer.stop_recording()
            await audio_synchronizer.stop_recording()

            # Handle disconnect propagation - stop the other agent too
            await self.session_manager.handle_agent_disconnect(
                test_session_id, role, self.stop_test_session
            )

            await task.cancel()

        # Connect the context aggregator events to engine
        @assistant_context_aggregator.event_handler("on_push_aggregation")
        async def on_assistant_aggregator_push_context(_aggregator):
            logger.debug(
                "Assistant aggregator push context – flushing pending transitions"
            )
            await engine.flush_pending_transitions()

        # Register custom audio and transcript handlers for LoopTalk
        await self._register_looptalk_handlers(
            audio_synchronizer, transcript, test_session_id, role
        )

    async def _register_looptalk_handlers(
        self, audio_synchronizer, transcript, test_session_id: int, role: str
    ):
        """Register LoopTalk-specific handlers for audio and transcript recording.

        Audio chunks are appended to a temporary PCM file and transcript
        updates to a text file. The audio format used later for WAV
        conversion starts at 16000 Hz / 1 channel and is replaced by the
        format reported with the first non-empty audio chunk.

        Args:
            audio_synchronizer: Emits ``on_merged_audio`` events.
            transcript: Emits ``on_transcript_update`` events.
            test_session_id: ID of the test session.
            role: Either "actor" or "adversary".
        """
        paths = self.recording_manager.get_recording_paths(test_session_id, role)

        # Flipped once the real audio format has been pushed to the session
        # manager, so the update happens only for the first chunk.
        format_recorded = False

        # Audio handler - writes directly to PCM file
        @audio_synchronizer.event_handler("on_merged_audio")
        async def on_merged_audio(_, pcm, sample_rate, num_channels):
            nonlocal format_recorded

            if not pcm:
                return

            # Store the real format on first write. Previously the defaults
            # registered below were never replaced, so the WAV header always
            # used 16000 Hz / mono regardless of what was recorded.
            if not format_recorded:
                format_recorded = True
                try:
                    self.session_manager.update_audio_metadata(
                        test_session_id,
                        role,
                        sample_rate=sample_rate,
                        num_channels=num_channels,
                    )
                except Exception as e:
                    # Best effort: a metadata failure must not drop audio.
                    logger.error(
                        f"Failed to store audio metadata for {role} in session {test_session_id}: {e}"
                    )

            # Append PCM data to temporary file
            try:
                with open(paths["temp_audio"], "ab") as f:
                    f.write(pcm)
            except Exception as e:
                logger.error(
                    f"Failed to write audio for {role} in session {test_session_id}: {e}"
                )

        # Transcript handler - writes directly to text file
        @transcript.event_handler("on_transcript_update")
        async def on_transcript_update(processor, frame):
            transcript_text = "".join(
                f"{f'[{msg.timestamp}] ' if msg.timestamp else ''}{msg.role}: {msg.content}\n"
                for msg in frame.messages
            )

            # Append transcript to file
            try:
                with open(paths["transcript"], "a") as f:
                    f.write(transcript_text)
            except Exception as e:
                logger.error(
                    f"Failed to write transcript for {role} in session {test_session_id}: {e}"
                )

        # Register default metadata so WAV conversion has values even if no
        # audio chunk ever arrives.
        self.session_manager.update_audio_metadata(
            test_session_id,
            role,
            sample_rate=16000,  # Default sample rate
            num_channels=1,  # Default channels
        )

    async def _run_pipeline_with_context(
        self,
        pipeline_task: PipelineTask,
        params,
        agent_id: str,
        conversation_id: str,
        role: str,
    ):
        """Run a pipeline task with the agent_id set in context.

        ``conversation_id`` and ``role`` are accepted for callers but are not
        used by the current implementation.
        """
        set_current_run_id(agent_id)
        return await pipeline_task.run(params)

    async def stop_test_session(self, test_session_id: int) -> Dict[str, Any]:
        """Stop a running test session.

        Cancels both pipelines, records the duration, converts and uploads
        the recordings, and removes the session's in-memory state.

        Args:
            test_session_id: ID of the session to stop.

        Returns:
            Dict with ``test_session_id``, ``status`` and ``duration_seconds``.

        Raises:
            ValueError: If the session is not currently tracked as running.
            Exception: Any failure while stopping is re-raised after the
                session has been marked as failed in the database.
        """
        session_info = self.session_manager.get_session(test_session_id)
        if not session_info:
            raise ValueError(f"Test session {test_session_id} is not running")

        try:
            # Cancel both pipeline tasks
            await session_info["actor_task"].cancel()
            await session_info["adversary_task"].cancel()

            # Also cancel the task futures if they exist
            if "actor_task_future" in session_info:
                session_info["actor_task_future"].cancel()
            if "adversary_task_future" in session_info:
                session_info["adversary_task_future"].cancel()

            # Calculate duration from a single timestamp so the stored
            # duration and end time agree with each other.
            ended_at = datetime.now(UTC)
            duration_seconds = int(
                (ended_at - session_info["start_time"]).total_seconds()
            )
            conversation_id = session_info["conversation"].id

            # Update conversation
            await self.db_client.update_conversation(
                conversation_id=conversation_id,
                duration_seconds=duration_seconds,
                ended_at=ended_at,
            )

            # Update test session status
            await self.db_client.update_test_session_status(
                test_session_id=test_session_id,
                status="completed",
                results={
                    "duration_seconds": duration_seconds,
                    "conversation_id": conversation_id,
                },
            )

            # Finalize recordings for both actor and adversary
            # Convert PCM files to WAV
            actor_metadata = self.session_manager.get_audio_metadata(
                test_session_id, "actor"
            )
            adversary_metadata = self.session_manager.get_audio_metadata(
                test_session_id, "adversary"
            )

            self.recording_manager.convert_pcm_to_wav(
                test_session_id,
                "actor",
                sample_rate=actor_metadata["sample_rate"],
                num_channels=actor_metadata["num_channels"],
            )
            self.recording_manager.convert_pcm_to_wav(
                test_session_id,
                "adversary",
                sample_rate=adversary_metadata["sample_rate"],
                num_channels=adversary_metadata["num_channels"],
            )

            # Upload recordings to S3 (synchronously for load testing)
            (
                actor_audio_url,
                actor_transcript_url,
            ) = await self.recording_manager.upload_recording_to_s3(
                test_session_id, "actor"
            )
            (
                adversary_audio_url,
                adversary_transcript_url,
            ) = await self.recording_manager.upload_recording_to_s3(
                test_session_id, "adversary"
            )

            # Update conversation with recording URLs
            await self.db_client.update_conversation(
                conversation_id=conversation_id,
                actor_recording_url=actor_audio_url,
                adversary_recording_url=adversary_audio_url,
                transcript={
                    "actor_transcript_url": actor_transcript_url,
                    "adversary_transcript_url": adversary_transcript_url,
                },
            )

            # Log recording locations
            logger.info("LoopTalk recordings uploaded to S3:")
            if actor_audio_url:
                logger.info(f" - Actor audio: {actor_audio_url}")
            if actor_transcript_url:
                logger.info(f" - Actor transcript: {actor_transcript_url}")
            if adversary_audio_url:
                logger.info(f" - Adversary audio: {adversary_audio_url}")
            if adversary_transcript_url:
                logger.info(f" - Adversary transcript: {adversary_transcript_url}")

            # Clean up local files after successful upload
            self.recording_manager.cleanup_session_files(test_session_id)

            # Clean up
            self.transport_manager.remove_transport_pair(str(test_session_id))
            self.session_manager.remove_session(test_session_id)

            # Clean up audio streamers
            from api.services.looptalk.audio_streamer import cleanup_audio_streamers

            cleanup_audio_streamers(str(test_session_id))

            logger.info(f"Stopped LoopTalk test session {test_session_id}")

            return {
                "test_session_id": test_session_id,
                "status": "completed",
                "duration_seconds": duration_seconds,
            }

        except Exception as e:
            logger.error(f"Failed to stop test session {test_session_id}: {e}")
            await self.db_client.update_test_session_status(
                test_session_id=test_session_id, status="failed", error=str(e)
            )
            raise

    async def start_load_test(
        self,
        organization_id: int,
        name_prefix: str,
        actor_workflow_id: int,
        adversary_workflow_id: int,
        config: Dict[str, Any],
        test_count: int,
    ) -> Dict[str, Any]:
        """Start a load test with multiple concurrent test sessions.

        Args:
            organization_id: Organization the sessions belong to.
            name_prefix: Prefix passed to the load test group creation.
            actor_workflow_id: Workflow used for the actor agents.
            adversary_workflow_id: Workflow used for the adversary agents.
            config: Configuration passed to the load test group creation.
            test_count: Number of sessions to start (1 to 10 inclusive).

        Returns:
            Dict with the group ID, the total/started/failed counts and the
            IDs of all created test sessions.

        Raises:
            ValueError: If ``test_count`` is outside 1..10.
        """
        # Validate test count
        if test_count < 1 or test_count > 10:
            raise ValueError("Test count must be between 1 and 10")

        # Create test sessions
        test_sessions = await self.db_client.create_load_test_group(
            organization_id=organization_id,
            name_prefix=name_prefix,
            actor_workflow_id=actor_workflow_id,
            adversary_workflow_id=adversary_workflow_id,
            config=config,
            test_count=test_count,
        )

        # Start all test sessions concurrently
        tasks = [
            asyncio.create_task(
                self.start_test_session(
                    test_session_id=test_session.id, organization_id=organization_id
                )
            )
            for test_session in test_sessions
        ]

        # Wait for all to start
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Count successes and failures. BaseException is checked because a
        # cancelled task is reported as CancelledError, which is not an
        # Exception subclass and would otherwise be counted as started.
        failed = 0
        for test_session, result in zip(test_sessions, results):
            if isinstance(result, BaseException):
                failed += 1
                logger.error(
                    f"Load test session {test_session.id} failed to start: {result!r}"
                )
        started = len(results) - failed

        load_test_group_id = test_sessions[0].load_test_group_id

        logger.info(
            f"Started load test {load_test_group_id}: "
            f"{started} started, {failed} failed out of {test_count}"
        )

        return {
            "load_test_group_id": load_test_group_id,
            "total": test_count,
            "started": started,
            "failed": failed,
            "test_session_ids": [ts.id for ts in test_sessions],
        }

    def get_active_test_count(self) -> int:
        """Get the number of currently active test sessions."""
        return self.session_manager.get_active_count()

    def get_active_test_info(self) -> Dict[str, Any]:
        """Get information about all active test sessions."""
        return self.session_manager.get_active_info()

    def get_recording_info(self, test_session_id: int) -> Dict[str, Any]:
        """Get information about recordings for a test session."""
        return self.recording_manager.get_recording_info(test_session_id)
|
||||
283
api/services/mps_service_key_client.py
Normal file
283
api/services/mps_service_key_client.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
"""
|
||||
MPS Service Key HTTP Client
|
||||
This client communicates with the Model Proxy Service (MPS) for service key management.
|
||||
Service keys are stored and managed entirely in MPS, not in the local database.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
||||
|
||||
|
||||
class MPSServiceKeyClient:
    """HTTP client for managing service keys via MPS API."""

    def __init__(self):
        """Configure the MPS base URL and the default request timeout."""
        self.base_url = MPS_API_URL
        self.timeout = httpx.Timeout(10.0)

    def _get_headers(self) -> dict:
        """Get headers for MPS API requests."""
        headers = {"Content-Type": "application/json"}

        # Add authentication for non-OSS mode
        if DEPLOYMENT_MODE != "oss" and DOGRAH_MPS_SECRET_KEY:
            headers["X-Secret-Key"] = DOGRAH_MPS_SECRET_KEY

        return headers

    @staticmethod
    def _summarize_key(key: dict) -> dict:
        """Reduce an MPS key payload to the fields this service exposes."""
        return {
            "id": key.get("id"),
            "name": key.get("name"),
            "key_prefix": key.get("key_prefix", ""),
            "is_active": key.get("is_active", True),
            "created_at": key.get("created_at"),
            "last_used_at": key.get("last_used_at"),
            "expires_at": key.get("expires_at"),
            "archived_at": key.get("archived_at"),
            "created_by": key.get("created_by"),
        }

    async def create_service_key(
        self,
        name: str,
        organization_id: Optional[int] = None,
        created_by: Optional[str] = None,
        expires_in_days: int = 90,
        description: Optional[str] = None,
    ) -> dict:
        """
        Create a new service key via MPS API.

        For OSS mode: organization_id should be None
        For authenticated mode: organization_id should be provided

        Args:
            name: Name of the key.
            organization_id: Owning organization (sent only outside OSS mode).
            created_by: Identifier of the creating user.
            expires_in_days: Requested lifetime of the key in days.
            description: Optional description; defaults to
                ``"Service key: <name>"``.

        Returns:
            Dict describing the key. ``service_key`` holds the full key and
            is only returned by MPS on creation.

        Raises:
            httpx.HTTPStatusError: If MPS does not answer with status 200.
        """
        request_body = {
            "name": name,
            "description": description or f"Service key: {name}",
            "expires_in_days": expires_in_days,
            "created_by": created_by,
        }

        # Only add organization_id for non-OSS mode
        if DEPLOYMENT_MODE != "oss" and organization_id:
            request_body["organization_id"] = organization_id

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                f"{self.base_url}/api/v1/service-keys/",
                json=request_body,
                headers=self._get_headers(),
            )

        if response.status_code != 200:
            logger.error(
                f"Failed to create service key: {response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to create service key: {response.text}",
                request=response.request,
                response=response,
            )

        data = response.json()
        service_key = data.get("service_key")  # Only returned on creation

        # Transform the response to match our expected format
        return {
            "id": data.get("id"),
            "name": data.get("name"),
            "service_key": service_key,
            "key_prefix": service_key[:8] if service_key else "",
            "expires_at": data.get("expires_at"),
            "created_at": data.get("created_at"),
            "is_active": data.get("is_active", True),
            "created_by": data.get("created_by"),
        }

    async def get_service_keys(
        self,
        organization_id: Optional[int] = None,
        created_by: Optional[str] = None,
        include_archived: bool = False,
    ) -> List[dict]:
        """
        Get service keys from MPS.

        For OSS mode: Use created_by to filter keys
        For authenticated mode: Use organization_id to filter keys

        Returns:
            List of key summaries; an empty list when MPS does not answer
            with status 200 (the failure is logged).
        """
        params = {}

        if DEPLOYMENT_MODE == "oss":
            # In OSS mode, filter by created_by
            if created_by:
                params["created_by"] = created_by
        elif organization_id:
            # In authenticated mode, filter by organization_id
            params["organization_id"] = organization_id

        if include_archived:
            params["include_archived"] = "true"

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(
                f"{self.base_url}/api/v1/service-keys/",
                params=params,
                headers=self._get_headers(),
            )

        if response.status_code != 200:
            logger.error(
                f"Failed to get service keys: {response.status_code} - {response.text}"
            )
            return []

        # Transform the response to match our expected format
        return [self._summarize_key(key) for key in response.json()]

    async def get_service_key_by_id(
        self,
        key_id: int,
        organization_id: Optional[int] = None,
        created_by: Optional[str] = None,
    ) -> Optional[dict]:
        """Get a specific service key by ID.

        Ownership is only checked when the relevant identifier is supplied:
        ``created_by`` in OSS mode, ``organization_id`` otherwise.

        Returns:
            The key summary, or None when the key cannot be fetched or the
            ownership check fails.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(
                f"{self.base_url}/api/v1/service-keys/{key_id}",
                headers=self._get_headers(),
            )

        if response.status_code != 200:
            # A 404 is an ordinary "not found"; anything else is worth a log
            # line because the caller only ever sees None.
            if response.status_code != 404:
                logger.error(
                    f"Failed to get service key {key_id}: {response.status_code} - {response.text}"
                )
            return None

        key = response.json()

        # Validate ownership for OSS mode
        if DEPLOYMENT_MODE == "oss" and created_by:
            if key.get("created_by") != created_by:
                logger.warning(
                    f"Access denied: User {created_by} tried to access key created by {key.get('created_by')}"
                )
                return None

        # Validate organization for authenticated mode
        if DEPLOYMENT_MODE != "oss" and organization_id:
            if key.get("organization_id") != organization_id:
                logger.warning(
                    f"Access denied: Org {organization_id} tried to access key for org {key.get('organization_id')}"
                )
                return None

        return self._summarize_key(key)

    async def archive_service_key(
        self,
        key_id: int,
        organization_id: Optional[int] = None,
        created_by: Optional[str] = None,
    ) -> bool:
        """
        Archive (soft delete) a service key.

        For OSS mode: Validates that created_by matches the key creator
        For authenticated mode: Validates organization_id matches

        Returns:
            True when MPS answers 200 or 204, False otherwise.
        """
        # First, verify ownership
        key = await self.get_service_key_by_id(key_id, organization_id, created_by)
        if not key:
            logger.error(f"Service key {key_id} not found or access denied")
            return False

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.delete(
                f"{self.base_url}/api/v1/service-keys/{key_id}",
                headers=self._get_headers(),
            )

        if response.status_code in (200, 204):
            return True

        logger.error(
            f"Failed to archive service key: {response.status_code} - {response.text}"
        )
        return False

    async def call_workflow_api(
        self,
        call_type: str,
        use_case: str,
        activity_description: str,
        organization_id: Optional[int] = None,
        created_by: Optional[str] = None,
    ) -> dict:
        """
        Call the MPS workflow creation API using secret key authentication.

        For OSS mode: Pass created_by in headers
        For authenticated mode: Pass organization_id in headers

        Args:
            call_type: INBOUND or OUTBOUND
            use_case: Description of the use case
            activity_description: Description of what the agent should do
            organization_id: Organization ID (for authenticated mode)
            created_by: User provider ID (for OSS mode)

        Returns:
            Workflow data from MPS API

        Raises:
            httpx.HTTPStatusError: If the API call fails
        """
        # Content type plus, outside OSS mode, the secret key.
        headers = self._get_headers()

        if DEPLOYMENT_MODE != "oss":
            # The organization header is only sent alongside the secret key.
            if DOGRAH_MPS_SECRET_KEY and organization_id:
                headers["X-Organization-Id"] = str(organization_id)
        elif created_by:
            headers["X-Created-By"] = created_by

        # This call gets a longer timeout (30s) than the key endpoints.
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
            response = await client.post(
                f"{self.base_url}/api/v1/workflow/create-workflow",
                json={
                    "call_type": call_type,
                    "use_case": use_case,
                    "activity_description": activity_description,
                },
                headers=headers,
            )

        if response.status_code == 200:
            return response.json()

        logger.error(
            f"Failed to create workflow: {response.status_code} - {response.text}"
        )
        raise httpx.HTTPStatusError(
            f"Failed to create workflow: {response.text}",
            request=response.request,
            response=response,
        )


# Module-level singleton shared by importers of this module
mps_service_key_client = MPSServiceKeyClient()
|
||||
0
api/services/pipecat/__init__.py
Normal file
0
api/services/pipecat/__init__.py
Normal file
120
api/services/pipecat/audio_config.py
Normal file
120
api/services/pipecat/audio_config.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
Audio configuration for pipeline components.
|
||||
|
||||
This module provides centralized audio configuration to ensure consistent
|
||||
sample rates across all pipeline components and proper coordination between
|
||||
transport serializers, VAD, and audio buffers.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
|
||||
|
||||
@dataclass
class AudioConfig:
    """Single source of truth for the sample rates used by a pipeline.

    Note: Pipeline is limited to 16kHz maximum to support VAD.
    Transports handle resampling from/to higher rates (24kHz, 48kHz).

    Attributes:
        transport_in_sample_rate: Sample rate of incoming audio from transport (after resampling)
        transport_out_sample_rate: Sample rate of outgoing audio to transport (before resampling)
        vad_sample_rate: Sample rate for VAD processing (8000 or 16000)
        pipeline_sample_rate: Internal pipeline processing sample rate (max 16000)
        buffer_size_seconds: Audio buffer size in seconds
    """

    transport_in_sample_rate: int
    transport_out_sample_rate: int
    # VAD typically resamples internally
    vad_sample_rate: int = 16000
    # When left as None it is derived from the transport output rate
    pipeline_sample_rate: Optional[int] = None
    # How frequently merge_audio will be called
    buffer_size_seconds: float = 1.0

    def __post_init__(self):
        """Validate the VAD rate, settle the pipeline rate, and log the result."""
        # Only the two rates listed here are accepted for VAD
        if self.vad_sample_rate not in (8000, 16000):
            raise ValueError(
                f"VAD sample rate must be 8000 or 16000, got {self.vad_sample_rate}"
            )

        if self.pipeline_sample_rate is None:
            # Derive from the transport output rate, never above 16kHz
            self.pipeline_sample_rate = min(self.transport_out_sample_rate, 16000)
        elif self.pipeline_sample_rate > 16000:
            # An explicit rate above the 16kHz VAD limit is capped
            logger.warning(
                f"Pipeline sample rate {self.pipeline_sample_rate} exceeds 16kHz limit, "
                f"capping at 16kHz. Transport will handle resampling."
            )
            self.pipeline_sample_rate = 16000

        # Log configuration for auditing
        summary = (
            f"AudioConfig initialized: "
            f"transport_in={self.transport_in_sample_rate}Hz, "
            f"transport_out={self.transport_out_sample_rate}Hz, "
            f"vad={self.vad_sample_rate}Hz, "
            f"pipeline={self.pipeline_sample_rate}Hz, "
            f"buffer={self.buffer_size_seconds}s"
        )
        logger.info(summary)

    @property
    def buffer_size_bytes(self) -> int:
        """Buffer size in bytes at the pipeline sample rate."""
        # 2 bytes per sample (16-bit PCM)
        byte_count = self.pipeline_sample_rate * 2 * self.buffer_size_seconds
        return int(byte_count)

    @property
    def buffer_size_samples(self) -> int:
        """Buffer size in samples at the pipeline sample rate."""
        sample_count = self.pipeline_sample_rate * self.buffer_size_seconds
        return int(sample_count)
|
||||
|
||||
|
||||
def create_audio_config(transport_type: str) -> AudioConfig:
    """Build the AudioConfig that matches a transport type.

    Stasis and Twilio transports run everything at 8kHz; WebRTC variants and
    any unrecognised transport run at 16kHz. An unrecognised transport is
    additionally reported with a warning.

    Args:
        transport_type: Type of transport ("webrtc", "twilio", "stasis")

    Returns:
        AudioConfig instance with appropriate settings
    """
    telephony_modes = (WorkflowRunMode.STASIS.value, WorkflowRunMode.TWILIO.value)
    webrtc_modes = [
        WorkflowRunMode.WEBRTC.value,
        WorkflowRunMode.SMALLWEBRTC.value,
    ]

    if transport_type in telephony_modes:
        # Keep every stage at 8kHz so no resampling is needed
        rate = 8000
    else:
        if transport_type not in webrtc_modes:
            # Default configuration
            logger.warning(
                f"Unknown transport type: {transport_type}, using default config"
            )
        # WebRTC typically uses 24kHz or 48kHz, but we limit pipeline to 16kHz
        # The transport will handle resampling between 24kHz and 16kHz
        rate = 16000

    return AudioConfig(
        transport_in_sample_rate=rate,
        transport_out_sample_rate=rate,
        vad_sample_rate=rate,
        pipeline_sample_rate=rate,
        buffer_size_seconds=1.0,
    )
|
||||
122
api/services/pipecat/audio_transcript_buffers.py
Normal file
122
api/services/pipecat/audio_transcript_buffers.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import asyncio
|
||||
import re
|
||||
import tempfile
|
||||
import wave
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class InMemoryAudioBuffer:
    """Buffer audio data in memory during a call, then write to temp file on disconnect.

    PCM chunks are appended under an asyncio lock and the running byte count
    is capped at 100MB. ``write_to_temp_file`` dumps everything buffered so
    far into a 16-bit WAV file and returns its path.
    """

    def __init__(self, workflow_run_id: int, sample_rate: int, num_channels: int = 1):
        """Create an empty buffer.

        Args:
            workflow_run_id: Identifier used in log messages.
            sample_rate: Frame rate written into the WAV header.
            num_channels: Channel count written into the WAV header.
        """
        self._workflow_run_id = workflow_run_id
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._chunks: List[bytes] = []
        self._lock = asyncio.Lock()
        self._total_size = 0
        self._max_size = 100 * 1024 * 1024  # 100MB limit

    async def append(self, pcm_data: bytes):
        """Append PCM audio data to the buffer.

        Raises:
            MemoryError: If adding ``pcm_data`` would push the buffered byte
                count above the size limit. The chunk is not stored.
        """
        async with self._lock:
            incoming = len(pcm_data)
            if self._total_size + incoming > self._max_size:
                logger.error(
                    f"Audio buffer size limit exceeded for workflow {self._workflow_run_id}. "
                    f"Current: {self._total_size}, Attempted to add: {len(pcm_data)}"
                )
                raise MemoryError("Audio buffer size limit exceeded")
            self._chunks.append(pcm_data)
            self._total_size += incoming
            logger.trace(
                f"Appended {len(pcm_data)} bytes to audio buffer. Total size: {self._total_size}"
            )

    async def write_to_temp_file(self) -> str:
        """Write audio data to a temporary WAV file and return the path.

        The file is created with ``delete=False`` so it outlives this call;
        removing it is left to whoever consumes the returned path.
        """
        async with self._lock:
            # NamedTemporaryFile is used only to reserve a unique path. Close
            # its handle right away so the descriptor is not leaked and the
            # file can be reopened by name below on every platform.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                temp_path = temp_file.name

            logger.debug(
                f"Writing audio buffer to temp file {temp_path} for workflow {self._workflow_run_id}"
            )

            # Write WAV header and PCM data
            with wave.open(temp_path, "wb") as wf:
                wf.setnchannels(self._num_channels)
                wf.setsampwidth(2)  # 16-bit audio
                wf.setframerate(self._sample_rate)

                # Write chunks in the order they were appended
                for chunk in self._chunks:
                    wf.writeframes(chunk)

            logger.info(
                f"Successfully wrote {self._total_size} bytes of audio to {temp_path}"
            )
            return temp_path

    @property
    def is_empty(self) -> bool:
        """Check if the buffer is empty."""
        return not self._chunks

    @property
    def size(self) -> int:
        """Get the total size of buffered data in bytes."""
        return self._total_size
|
||||
|
||||
|
||||
class InMemoryTranscriptBuffer:
    """Buffer transcript data in memory during a call, then write to temp file on disconnect.

    Each appended entry is stored verbatim; an entry may hold several
    newline-terminated transcript lines.
    """

    # Compiled regex to identify user speech lines, e.g.
    # [2025-06-29T12:34:56.789+00:00] user: hello
    # MULTILINE lets ``^`` anchor at every line start, so a user line is found
    # even when it is not the first line of a multi-line buffered entry.
    _USER_SPEECH_RE: re.Pattern[str] = re.compile(
        r"^\[\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}\+\d{2}:\d{2}\] user: .+",
        re.MULTILINE,
    )

    def __init__(self, workflow_run_id: int):
        """Create an empty buffer; ``workflow_run_id`` is used in log messages."""
        self._workflow_run_id = workflow_run_id
        self._lines: List[str] = []
        self._lock = asyncio.Lock()

    async def append(self, transcript: str):
        """Append transcript text to the buffer."""
        async with self._lock:
            self._lines.append(transcript)
            logger.trace(
                f"Appended transcript line to buffer for workflow {self._workflow_run_id}"
            )

    async def write_to_temp_file(self) -> str:
        """Write transcript to a temporary text file and return the path.

        The file is created with ``delete=False`` so it outlives this call.
        It is written as UTF-8 so the result does not depend on the platform's
        default encoding.
        """
        async with self._lock:
            content = "".join(self._lines)

            # The context manager closes the handle even if the write fails.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".txt", delete=False, encoding="utf-8"
            ) as temp_file:
                temp_path = temp_file.name
                logger.debug(
                    f"Writing transcript buffer to temp file {temp_path} for workflow {self._workflow_run_id}"
                )
                temp_file.write(content)

            logger.info(
                f"Successfully wrote {len(content)} chars of transcript to {temp_path}"
            )
            return temp_path

    @property
    def is_empty(self) -> bool:
        """Check if the buffer is empty."""
        return not self._lines

    def contains_user_speech(self) -> bool:
        """Return True if any buffered transcript line matches the user speech pattern."""
        return any(self._USER_SPEECH_RE.search(entry) for entry in self._lines)
|
||||
69
api/services/pipecat/engine_pre_aggregator_processor.py
Normal file
69
api/services/pipecat/engine_pre_aggregator_processor.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Engine Pre-Aggregator Processor
|
||||
|
||||
This processor sits before the user context aggregator in the pipeline and handles
|
||||
engine-specific callbacks for frames that need to be processed before aggregation.
|
||||
This ensures the engine can update context before the aggregator generates LLM frames.
|
||||
"""
|
||||
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.pipecat.exceptions import VoicemailDetectedException
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class EnginePreAggregatorProcessor(FrameProcessor):
    """
    Frame processor that runs engine callbacks ahead of user context aggregation.

    It is meant to sit before the user context aggregator so the engine can
    update LLM context before aggregation occurs.
    """

    def __init__(
        self,
        user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
        user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
        **kwargs,
    ):
        """Store the optional callbacks; extra kwargs go to ``FrameProcessor``."""
        super().__init__(**kwargs)
        self._user_started_speaking_callback = user_started_speaking_callback
        self._user_stopped_speaking_callback = user_stopped_speaking_callback

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Run the matching engine callback, then forward the frame.

        A UserStoppedSpeakingFrame is dropped (not forwarded) when its
        callback raises VoicemailDetectedException.
        """
        await super().process_frame(frame, direction)

        should_forward = True

        # Handle frames that need engine processing before aggregation
        if isinstance(frame, UserStartedSpeakingFrame):
            await self._handle_user_started_speaking()
        elif isinstance(frame, UserStoppedSpeakingFrame):
            try:
                await self._handle_user_stopped_speaking()
            except VoicemailDetectedException:
                # Voicemail was detected: hold back the UserStoppedSpeakingFrame
                # so the user context aggregator does not issue an LLM call.
                logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame")
                should_forward = False

        if should_forward:
            await self.push_frame(frame, direction)

    async def _handle_user_started_speaking(self):
        """Invoke the user-started-speaking callback, if one was supplied."""
        callback = self._user_started_speaking_callback
        if not callback:
            return
        await callback()

    async def _handle_user_stopped_speaking(self):
        """Invoke the user-stopped-speaking callback, if one was supplied."""
        callback = self._user_stopped_speaking_callback
        if not callback:
            return
        await callback()
|
||||
249
api/services/pipecat/event_handlers.py
Normal file
249
api/services/pipecat/event_handlers.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.pipecat.audio_transcript_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryTranscriptBuffer,
|
||||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
|
||||
def register_transport_event_handlers(
    transport,
    workflow_run_id,
    audio_buffer,
    task: PipelineTask,
    engine: PipecatEngine,
    usage_metrics_aggregator: PipelineMetricsAggregator,
    audio_synchronizer=None,
    audio_config=None,
):
    """Register connect/disconnect handlers on ``transport``.

    Two in-memory buffers (audio and transcript) are created here and
    returned so that other handlers can fill them; the disconnect handler
    writes them to temp files and enqueues their upload.

    Args:
        transport: Object exposing ``event_handler(name)`` as a decorator.
        workflow_run_id: Id of the workflow run this call belongs to.
        audio_buffer: Recorder started on connect and stopped on disconnect.
        task: Pipeline task, cancelled on disconnect when the engine did not
            set a call disposition itself.
        engine: Workflow engine; initialized on connect, queried for
            disposition/context and cleaned up on disconnect.
        usage_metrics_aggregator: Source of call duration and usage metrics.
        audio_synchronizer: Optional recorder started/stopped with the call.
        audio_config: Optional config providing ``pipeline_sample_rate``;
            16000 Hz is used when omitted.

    Returns:
        Tuple ``(in_memory_audio_buffer, in_memory_transcript_buffer)``.
    """

    # Initialize in-memory buffers with proper audio configuration
    sample_rate = audio_config.pipeline_sample_rate if audio_config else 16000
    num_channels = 1  # Pipeline audio is always mono

    logger.debug(
        f"Initializing audio buffer for workflow {workflow_run_id} "
        f"with sample_rate={sample_rate}Hz, channels={num_channels}"
    )

    in_memory_audio_buffer = InMemoryAudioBuffer(
        workflow_run_id=workflow_run_id,
        sample_rate=sample_rate,
        num_channels=num_channels,
    )
    in_memory_transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id)

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, participant):
        """Start recording and initialize the engine once a client connects."""
        logger.debug("In on_client_connected callback handler - initializing workflow")
        await audio_buffer.start_recording()
        if audio_synchronizer:
            await audio_synchronizer.start_recording()
        await engine.initialize()

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(
        transport: BaseTransport,
        participant,
        transport_disconnect_reason: Optional[str] = None,
    ):
        """Finalize the workflow run: disposition, context, metrics, uploads."""
        logger.debug(
            f"In on_client_disconnected callback handler, disconnect_reason: {transport_disconnect_reason}"
        )

        workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)

        # First priority: Check if engine has a disconnect reason (local disconnect)
        engine_call_disposition = engine.get_call_disposition()
        gathered_context = engine.get_gathered_context()

        # Also consider context already stored on the workflow run; stored
        # values win on key clashes. The run may be missing (it is guarded for
        # further down as well) or may carry no context, so fall back to {}.
        persisted_context = (
            workflow_run.gathered_context if workflow_run else None
        ) or {}
        gathered_context = {**gathered_context, **persisted_context}

        if engine_call_disposition:
            # Engine has set a disconnect reason - this takes priority
            call_disposition = engine_call_disposition
            logger.debug(f"Engine disposition detected, code: {call_disposition}")
        elif transport_disconnect_reason:
            # TODO: Make this more generic using some DSL or equivalent. This is currently
            # configured to work for Kapil's bot
            call_duration = usage_metrics_aggregator.get_call_duration()
            if transport_disconnect_reason == EndTaskReason.USER_HANGUP.value:
                if call_duration < 10:
                    call_disposition = "HU"
                else:
                    call_disposition = "NIBP"
            else:
                # Transport provided a disconnect reason (remote hangup)
                call_disposition = transport_disconnect_reason
            logger.debug(
                f"Remote disconnect detected, reason: {call_disposition} duration: {call_duration}"
            )
        else:
            # No reason provided by either side - record the disposition as unknown
            call_disposition = EndTaskReason.UNKNOWN.value
            logger.debug("No disposition found from either engine or transport")

        # Cancel task only when no engine disconnect reason (remote disconnect)
        if not engine_call_disposition:
            await task.cancel()

        organization_id = await get_organization_id_from_workflow_run(workflow_run_id)
        mapped_call_disposition = await apply_disposition_mapping(
            call_disposition, organization_id
        )

        gathered_context.update({"mapped_call_disposition": mapped_call_disposition})

        if in_memory_transcript_buffer:
            call_tags = gathered_context.get("call_tags", [])

            try:
                has_user_speech = in_memory_transcript_buffer.contains_user_speech()
            except Exception as exc:
                # Best effort: tagging must never block run finalization, but
                # leave a trace instead of swallowing the failure silently.
                logger.warning(f"Failed to check transcript for user speech: {exc}")
                has_user_speech = False

            if has_user_speech and "user_speech" not in call_tags:
                call_tags.append("user_speech")

            # Append the values of any gathered_context keys that start with
            # 'tag_' to call_tags.
            # NOTE(review): the duplicate check tests the key while the value
            # is what gets appended - confirm which of the two is intended.
            for key in gathered_context:
                if key.startswith("tag_") and key not in call_tags:
                    call_tags.append(gathered_context[key])

            gathered_context["call_tags"] = call_tags

        # Clean up engine resources (including voicemail detector)
        await engine.cleanup()

        await audio_buffer.stop_recording()
        if audio_synchronizer:
            await audio_synchronizer.stop_recording()

        # ------------------------------------------------------------------
        # Close Smart-Turn WebSocket if the transport's analyzer supports it
        # ------------------------------------------------------------------
        try:
            turn_analyzer = None

            # Most transports store their params (with turn_analyzer) directly.
            if hasattr(transport, "_params") and transport._params:
                turn_analyzer = getattr(transport._params, "turn_analyzer", None)

            # Fallback: some transports expose params through input() instance.
            if turn_analyzer is None and hasattr(transport, "input"):
                try:
                    input_transport = transport.input()
                    if input_transport and hasattr(input_transport, "_params"):
                        turn_analyzer = getattr(
                            input_transport._params, "turn_analyzer", None
                        )
                except Exception:
                    # Deliberate best effort: no analyzer found means nothing to close.
                    pass

            if turn_analyzer and hasattr(turn_analyzer, "close"):
                await turn_analyzer.close()
                logger.debug("Closed turn analyzer websocket")
        except Exception as exc:
            logger.warning(f"Failed to close Smart-Turn analyzer gracefully: {exc}")

        usage_info = usage_metrics_aggregator.get_all_usage_metrics_serialized()

        logger.debug(f"Usage metrics: {usage_info}")

        await db_client.update_workflow_run(
            run_id=workflow_run_id,
            usage_info=usage_info,
            gathered_context=gathered_context,
            is_completed=True,
        )

        # Release concurrent slot for campaign calls
        if workflow_run and workflow_run.campaign_id:
            await campaign_call_dispatcher.release_call_slot(workflow_run_id)

        # Write buffers to temp files and enqueue S3 upload
        try:
            # Only upload if buffers have content
            if not in_memory_audio_buffer.is_empty:
                audio_temp_path = await in_memory_audio_buffer.write_to_temp_file()
                await enqueue_job(
                    FunctionNames.UPLOAD_AUDIO_TO_S3, workflow_run_id, audio_temp_path
                )
            else:
                logger.debug("Audio buffer is empty, skipping upload")

            if not in_memory_transcript_buffer.is_empty:
                transcript_temp_path = (
                    await in_memory_transcript_buffer.write_to_temp_file()
                )
                await enqueue_job(
                    FunctionNames.UPLOAD_TRANSCRIPT_TO_S3,
                    workflow_run_id,
                    transcript_temp_path,
                )
            else:
                logger.debug("Transcript buffer is empty, skipping upload")

        except Exception as e:
            # loguru has no ``exc_info`` flag (extra kwargs are fed to
            # str.format); logger.exception attaches the traceback instead.
            logger.exception(f"Error preparing buffers for S3 upload: {e}")

        await enqueue_job(FunctionNames.CALCULATE_WORKFLOW_RUN_COST, workflow_run_id)
        await enqueue_job(
            FunctionNames.RUN_INTEGRATIONS_POST_WORKFLOW_RUN, workflow_run_id
        )

    # Return the buffers so they can be passed to other handlers
    return in_memory_audio_buffer, in_memory_transcript_buffer
|
||||
|
||||
|
||||
def register_audio_data_handler(
    audio_synchronizer, workflow_run_id, in_memory_buffer: InMemoryAudioBuffer
):
    """Attach an ``on_merged_audio`` handler that stores audio in ``in_memory_buffer``.

    Args:
        audio_synchronizer: Object exposing ``event_handler(name)`` as a decorator.
        workflow_run_id: Id of the workflow run, used in the log message.
        in_memory_buffer: Buffer that receives every non-empty PCM payload.
    """
    logger.info(f"Registering audio data handler for workflow run {workflow_run_id}")

    @audio_synchronizer.event_handler("on_merged_audio")
    async def on_merged_audio(_, pcm, sample_rate, num_channels):
        """Append the merged PCM payload to the in-memory buffer."""
        if pcm:
            try:
                await in_memory_buffer.append(pcm)
            except MemoryError as overflow:
                # The chunk is dropped; overflow to disk could be added here
                # if it is ever needed.
                logger.error(f"Memory buffer full: {overflow}")
|
||||
|
||||
|
||||
def register_transcript_handler(
    transcript, workflow_run_id, in_memory_buffer: InMemoryTranscriptBuffer
):
    """Attach an ``on_transcript_update`` handler that stores text in ``in_memory_buffer``.

    Each message of an update is rendered as ``[timestamp] role: content``
    (the timestamp prefix is omitted when the message has none), one message
    per line, and the whole update is appended as a single buffer entry.
    """

    @transcript.event_handler("on_transcript_update")
    async def on_transcript_update(processor, frame):
        """Render the frame's messages and append them to the buffer."""
        rendered = []
        for msg in frame.messages:
            prefix = f"[{msg.timestamp}] " if msg.timestamp else ""
            rendered.append(f"{prefix}{msg.role}: {msg.content}\n")

        # Use in-memory buffer
        await in_memory_buffer.append("".join(rendered))
|
||||
6
api/services/pipecat/exceptions.py
Normal file
6
api/services/pipecat/exceptions.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
class VoicemailDetectedException(Exception):
    """Raised when voicemail is detected."""
|
||||
147
api/services/pipecat/pipeline_builder.py
Normal file
147
api/services/pipecat/pipeline_builder.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import (
|
||||
ENABLE_TRACING,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.utils.context import turn_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine"):
    """Instantiate the core pipeline components from the audio configuration.

    Args:
        audio_config: Supplies ``pipeline_sample_rate`` and ``buffer_size_bytes``
            for the audio buffer and the synchronizer.
        engine: Provides the aggregation-correction callback handed to the
            transcript processor.

    Returns:
        Tuple ``(audio_buffer, audio_synchronizer, transcript, context)``.
    """
    logger.info(f"Creating pipeline components with audio config: {audio_config}")

    # Split audio buffer; its input/output halves are placed in the pipeline.
    audio_buffer = AudioBuffer(
        buffer_size=audio_config.buffer_size_bytes,
        sample_rate=audio_config.pipeline_sample_rate,
    )

    # Synchronizer for merged audio; it lives outside the pipeline.
    audio_synchronizer = AudioSynchronizer(
        buffer_size=audio_config.buffer_size_bytes,
        sample_rate=audio_config.pipeline_sample_rate,
    )

    correction_callback = engine.create_aggregation_correction_callback()
    transcript = TranscriptProcessor(
        assistant_correct_aggregation_callback=correction_callback
    )

    context = OpenAILLMContext()

    return audio_buffer, audio_synchronizer, transcript, context
|
||||
|
||||
|
||||
def build_pipeline(
    transport,
    stt,
    transcript,
    audio_buffer,
    audio_synchronizer,
    llm,
    tts,
    user_context_aggregator,
    assistant_context_aggregator,
    pipeline_engine_callback_processor,
    stt_mute_filter,
    pipeline_metrics_aggregator,
    user_idle_disconnect,
    engine_pre_aggregator_processor=None,
):
    """Assemble the processors into a ``Pipeline`` in their fixed order.

    The audio buffer's input and output halves are first registered with the
    synchronizer. When ``engine_pre_aggregator_processor`` is given it is
    placed immediately before the user context aggregator.
    """
    # Register processors with synchronizer for merged audio
    logger.info("Registering audio buffer processors with synchronizer")
    audio_synchronizer.register_processors(audio_buffer.input(), audio_buffer.output())

    # Everything that runs before the user context aggregator.
    upstream = [
        transport.input(),  # Transport user input
        audio_buffer.input(),  # Record input audio (only processes InputAudioRawFrame)
        stt_mute_filter,
        stt,  # STT can now have audio_passthrough=False
        user_idle_disconnect,
        transcript.user(),
    ]
    if engine_pre_aggregator_processor:
        upstream.append(engine_pre_aggregator_processor)

    # The user context aggregator and everything after it.
    downstream = [
        user_context_aggregator,
        llm,  # LLM
        pipeline_engine_callback_processor,
        tts,  # TTS
        transport.output(),  # Transport bot output
        audio_buffer.output(),  # Record output audio (only processes OutputAudioRawFrame)
        transcript.assistant(),
        assistant_context_aggregator,  # Assistant spoken responses
        pipeline_metrics_aggregator,
    ]

    return Pipeline(upstream + downstream)
|
||||
|
||||
|
||||
def create_pipeline_task(pipeline, workflow_run_id, audio_config: AudioConfig = None):
    """Wrap ``pipeline`` in a ``PipelineTask``.

    Args:
        pipeline: The pipeline to run.
        workflow_run_id: Used (stringified) as the task's conversation id.
        audio_config: When given, its transport in/out sample rates are
            copied onto the pipeline params.

    Returns:
        The created task. If the ``ENABLE_TURN_LOGGING`` environment variable
        is ``"true"`` (case-insensitive) and the task exposes a turn tracking
        observer, an ``on_turn_started`` handler is attached that records the
        current turn number.
    """
    pipeline_params = PipelineParams(
        allow_interruptions=True,
        enable_metrics=True,
        enable_usage_metrics=True,
        send_initial_empty_metrics=False,
        enable_heartbeats=True,
    )

    # If audio_config is provided, set the audio sample rates
    if audio_config:
        pipeline_params.audio_in_sample_rate = audio_config.transport_in_sample_rate
        pipeline_params.audio_out_sample_rate = audio_config.transport_out_sample_rate
        logger.debug(
            f"Setting pipeline audio params - in: {audio_config.transport_in_sample_rate}Hz, "
            f"out: {audio_config.transport_out_sample_rate}Hz"
        )

    task = PipelineTask(
        pipeline,
        params=pipeline_params,
        enable_tracing=ENABLE_TRACING,
        conversation_id=f"{workflow_run_id}",
    )

    # Turn logging is opt-in via the environment.
    if os.getenv("ENABLE_TURN_LOGGING", "false").lower() != "true":
        return task

    turn_observer = task.turn_tracking_observer
    if turn_observer is None:
        return task

    # Import turn context manager only if needed
    from api.services.pipecat.turn_context import get_turn_context_manager

    async def _on_turn_started(observer, turn_number: int):
        """Set the current turn number into the context variable."""
        # Set in both contextvar and turn context manager
        turn_var.set(turn_number)
        get_turn_context_manager().set_turn(turn_number)

    # Register the handler with the observer
    turn_observer.add_event_handler("on_turn_started", _on_turn_started)

    return task
|
||||
84
api/services/pipecat/pipeline_engine_callbacks_processor.py
Normal file
84
api/services/pipecat/pipeline_engine_callbacks_processor.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import time
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMGeneratedTextFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class PipelineEngineCallbacksProcessor(FrameProcessor):
    """
    Custom PipelineEngineCallbacksProcessor that accepts callbacks for various
    use cases, like ending tasks when max call duration is exceeded, or informing
    the engine that the bot is done speaking.

    Every frame is forwarded unchanged after the matching callback has run.
    """

    def __init__(
        self,
        max_call_duration_seconds: int = 300,
        max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None,
        llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None,
        generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None,
        llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Store the callbacks.

        Args:
            max_call_duration_seconds: Elapsed time after the StartFrame beyond
                which ``max_duration_end_task_callback`` is invoked.
            max_duration_end_task_callback: Awaited once when the max duration
                is exceeded (checked on each HeartbeatFrame).
            llm_generated_text_callback: Awaited on each LLMGeneratedTextFrame.
            generation_started_callback: Awaited on each LLMFullResponseStartFrame.
            llm_text_frame_callback: Awaited with ``frame.text`` for each
                LLMTextFrame or TTSSpeakFrame.
        """
        super().__init__()
        self._start_time = None
        self._max_call_duration_seconds = max_call_duration_seconds
        self._max_duration_end_task_callback = max_duration_end_task_callback
        self._llm_generated_text_callback = llm_generated_text_callback
        self._generation_started_callback = generation_started_callback
        self._llm_text_frame_callback = llm_text_frame_callback
        self._end_task_frame_pushed = False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Dispatch ``frame`` to the matching callback, then push it onward."""
        await super().process_frame(frame, direction)

        if isinstance(frame, StartFrame):
            await self._start(frame)
        elif isinstance(frame, HeartbeatFrame):
            await self._check_call_duration()
        elif isinstance(frame, LLMGeneratedTextFrame):
            await self._generated_text_frame(frame)
        elif isinstance(frame, LLMFullResponseStartFrame):
            await self._generation_started()
        elif (
            isinstance(frame, (LLMTextFrame, TTSSpeakFrame))
            and self._llm_text_frame_callback
        ):
            # Include TTSSpeakFrame here since for static nodes, we send TTSSpeakFrame
            # which can act as reference while fixing the aggregated transcript
            await self._llm_text_frame_callback(frame.text)
        await self.push_frame(frame, direction)

    async def _start(self, _: StartFrame):
        """Record when the call started.

        A monotonic clock is used because only elapsed time matters here and
        wall-clock adjustments must not shorten or extend the allowed duration.
        """
        self._start_time = time.monotonic()

    async def _check_call_duration(self):
        """Invoke the max-duration callback once the call has run too long."""
        if self._start_time is None:
            return

        elapsed = time.monotonic() - self._start_time
        if elapsed <= self._max_call_duration_seconds:
            return

        if self._end_task_frame_pushed:
            logger.debug(
                "Max call duration exceeded. Skipping EndTaskFrame since already sent"
            )
            return

        if self._max_duration_end_task_callback:
            await self._max_duration_end_task_callback()
            self._end_task_frame_pushed = True

    async def _generated_text_frame(self, _: LLMGeneratedTextFrame):
        """Handle LLMGeneratedTextFrame."""
        if self._llm_generated_text_callback is not None:
            await self._llm_generated_text_callback()

    async def _generation_started(self):
        """Handle LLMFullResponseStartFrame."""
        if self._generation_started_callback:
            await self._generation_started_callback()
|
||||
162
api/services/pipecat/pipeline_metrics_aggregator.py
Normal file
162
api/services/pipecat/pipeline_metrics_aggregator.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import (
|
||||
LLMTokenUsage,
|
||||
LLMUsageMetricsData,
|
||||
STTUsageMetricsData,
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class PipelineMetricsAggregator(FrameProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Structure: {f"{processor}|||{model}": aggregated_metrics}
|
||||
# For LLM: aggregated_metrics is LLMTokenUsage
|
||||
# For TTS: aggregated_metrics is int (total characters)
|
||||
# For STT: aggregated_metrics is float (total seconds)
|
||||
|
||||
self._start_time: Optional[float] = None
|
||||
self._stop_time: Optional[float] = None
|
||||
self._llm_usage_metrics: Dict[str, LLMTokenUsage] = {}
|
||||
self._tts_usage_metrics: Dict[str, int] = defaultdict(int)
|
||||
self._stt_usage_metrics: Dict[str, float] = defaultdict(float)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._stop(frame)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
for data in frame.data:
|
||||
if isinstance(data, LLMUsageMetricsData):
|
||||
await self._handle_llm_usage_metrics(data)
|
||||
elif isinstance(data, TTSUsageMetricsData):
|
||||
await self._handle_tts_usage_metrics(data)
|
||||
elif isinstance(data, STTUsageMetricsData):
|
||||
await self._handle_stt_usage_metrics(data)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, _: StartFrame):
|
||||
"""Start tracking call duration."""
|
||||
self._start_time = time.time()
|
||||
self._stop_time = None
|
||||
|
||||
async def _stop(self, _: EndFrame):
|
||||
"""Stop tracking call duration."""
|
||||
if self._start_time is not None and self._stop_time is None:
|
||||
self._stop_time = time.time()
|
||||
|
||||
async def _cancel(self, _: CancelFrame):
|
||||
"""Handle call cancellation - also stop tracking duration."""
|
||||
if self._start_time is not None and self._stop_time is None:
|
||||
self._stop_time = time.time()
|
||||
|
||||
async def _handle_llm_usage_metrics(self, data: LLMUsageMetricsData):
|
||||
key = f"{data.processor}|||{data.model}"
|
||||
new_usage = data.value
|
||||
|
||||
if key in self._llm_usage_metrics:
|
||||
# Aggregate with existing metrics
|
||||
existing = self._llm_usage_metrics[key]
|
||||
aggregated = LLMTokenUsage(
|
||||
prompt_tokens=existing.prompt_tokens + new_usage.prompt_tokens,
|
||||
completion_tokens=existing.completion_tokens
|
||||
+ new_usage.completion_tokens,
|
||||
total_tokens=existing.total_tokens + new_usage.total_tokens,
|
||||
cache_read_input_tokens=(existing.cache_read_input_tokens or 0)
|
||||
+ (new_usage.cache_read_input_tokens or 0),
|
||||
cache_creation_input_tokens=(existing.cache_creation_input_tokens or 0)
|
||||
+ (new_usage.cache_creation_input_tokens or 0),
|
||||
)
|
||||
self._llm_usage_metrics[key] = aggregated
|
||||
else:
|
||||
# First occurrence for this processor+model combination
|
||||
self._llm_usage_metrics[key] = LLMTokenUsage(
|
||||
prompt_tokens=new_usage.prompt_tokens,
|
||||
completion_tokens=new_usage.completion_tokens,
|
||||
total_tokens=new_usage.total_tokens,
|
||||
cache_read_input_tokens=new_usage.cache_read_input_tokens,
|
||||
cache_creation_input_tokens=new_usage.cache_creation_input_tokens,
|
||||
)
|
||||
|
||||
logger.debug(f"LLM usage metrics: {self._llm_usage_metrics}")
|
||||
|
||||
async def _handle_tts_usage_metrics(self, data: TTSUsageMetricsData):
|
||||
key = f"{data.processor}|||{data.model}"
|
||||
self._tts_usage_metrics[key] += data.value
|
||||
# logger.debug(f"TTS usage metrics: {self._tts_usage_metrics}")
|
||||
|
||||
async def _handle_stt_usage_metrics(self, data: STTUsageMetricsData):
|
||||
key = f"{data.processor}|||{data.model}"
|
||||
self._stt_usage_metrics[key] += data.value
|
||||
logger.debug(f"STT usage metrics: {self._stt_usage_metrics}")
|
||||
|
||||
def get_llm_usage_metrics(self) -> Dict[str, LLMTokenUsage]:
|
||||
"""Get the aggregated LLM usage metrics grouped by processor|||model."""
|
||||
return self._llm_usage_metrics
|
||||
|
||||
def get_tts_usage_metrics(self) -> Dict[str, int]:
    """Return the TTS usage totals collected so far, keyed by "processor|||model".

    The aggregator's own mapping is returned, not a copy.
    """
    tts_totals = self._tts_usage_metrics
    return tts_totals
|
||||
|
||||
def get_stt_usage_metrics(self) -> Dict[str, float]:
    """Return the STT usage totals collected so far, keyed by "processor|||model".

    The aggregator's own mapping is returned, not a copy.
    """
    stt_totals = self._stt_usage_metrics
    return stt_totals
|
||||
|
||||
def get_call_duration(self) -> int:
    """Return the call duration in whole seconds.

    Returns:
        0 when no start time has been recorded. Otherwise the rounded number
        of seconds between the start time and the stop time, or between the
        start time and now if no stop time has been recorded yet.
    """
    if self._start_time is None:
        # Same integer type as the normal path (previously a float 0.0).
        return 0

    # No stop time recorded yet: measure against the current clock.
    end_time = time.time() if self._stop_time is None else self._stop_time

    # Callers receive a rounded integer number of seconds.
    return int(round(end_time - self._start_time))
|
||||
|
||||
def get_all_usage_metrics_serialized(self) -> Dict[str, object]:
    """Get all aggregated usage metrics as plain dicts and scalars.

    Returns:
        A dict with four entries:
        - "llm": per-"processor|||model" dicts holding the token counters
        - "tts": a plain-dict copy of the TTS totals
        - "stt": a plain-dict copy of the STT totals
        - "call_duration_seconds": the value of get_call_duration()
    """
    # Token counter attributes copied verbatim from each LLM usage record.
    token_fields = (
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "cache_read_input_tokens",
        "cache_creation_input_tokens",
    )
    serialized_llm = {
        key: {field: getattr(usage, field) for field in token_fields}
        for key, usage in self._llm_usage_metrics.items()
    }

    return {
        "llm": serialized_llm,
        # dict(...) yields plain-dict copies of the TTS/STT totals.
        "tts": dict(self._tts_usage_metrics),
        "stt": dict(self._stt_usage_metrics),
        "call_duration_seconds": self.get_call_duration(),
    }
|
||||
|
||||
def reset_metrics(self):
    """Empty every usage mapping and forget the recorded start/stop times."""
    # The mappings are cleared in place, so existing references stay valid.
    for totals in (
        self._llm_usage_metrics,
        self._tts_usage_metrics,
        self._stt_usage_metrics,
    ):
        totals.clear()

    self._start_time = self._stop_time = None
|
||||
388
api/services/pipecat/run_pipeline.py
Normal file
388
api/services/pipecat/run_pipeline.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, WebSocket
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
|
||||
from api.services.pipecat.engine_pre_aggregator_processor import (
|
||||
EnginePreAggregatorProcessor,
|
||||
)
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_transcript_handler,
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
from api.services.pipecat.pipeline_builder import (
|
||||
build_pipeline,
|
||||
create_pipeline_components,
|
||||
create_pipeline_task,
|
||||
)
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import setup_pipeline_tracing
|
||||
from api.services.pipecat.transport_setup import (
|
||||
create_stasis_transport,
|
||||
create_twilio_transport,
|
||||
create_webrtc_transport,
|
||||
)
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.filters.stt_mute_filter import (
|
||||
STTMuteConfig,
|
||||
STTMuteFilter,
|
||||
STTMuteStrategy,
|
||||
)
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
from pipecat.utils.tracing.context_registry import ContextProviderRegistry
|
||||
|
||||
# Configure pipeline tracing once, when this module is first imported.
# setup_pipeline_tracing() decides for itself whether tracing is enabled.
setup_pipeline_tracing()
|
||||
|
||||
|
||||
async def run_pipeline_twilio(
    websocket_client: WebSocket,
    stream_sid: str,
    call_sid: str,
    workflow_id: int,
    workflow_run_id: int,
    user_id: int,
) -> None:
    """Run the pipeline for a Twilio call.

    Records the Twilio call SID on the workflow run, reads the optional VAD
    and ambient-noise settings from the workflow, builds the Twilio transport
    and hands over to _run_pipeline.
    """
    logger.debug(
        f"Running pipeline for Twilio connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
    )
    set_current_run_id(workflow_run_id)

    # Keep the Twilio call SID in cost_info for later cost calculation.
    await db_client.update_workflow_run(
        workflow_run_id, cost_info={"twilio_call_sid": call_sid}
    )

    # Optional per-workflow pipeline settings; None when not configured.
    workflow = await db_client.get_workflow(workflow_id, user_id)
    configs = workflow.workflow_configurations if workflow else None
    vad_config = None
    ambient_noise_config = None
    if configs:
        vad_config = (
            configs["vad_configuration"] if "vad_configuration" in configs else None
        )
        ambient_noise_config = (
            configs["ambient_noise_configuration"]
            if "ambient_noise_configuration" in configs
            else None
        )

    # Audio configuration for the Twilio run mode.
    audio_config = create_audio_config(WorkflowRunMode.TWILIO.value)

    twilio_transport = create_twilio_transport(
        websocket_client=websocket_client,
        stream_sid=stream_sid,
        call_sid=call_sid,
        workflow_run_id=workflow_run_id,
        audio_config=audio_config,
        vad_config=vad_config,
        ambient_noise_config=ambient_noise_config,
    )
    await _run_pipeline(
        twilio_transport,
        workflow_id,
        workflow_run_id,
        user_id,
        audio_config=audio_config,
    )
|
||||
|
||||
|
||||
async def run_pipeline_smallwebrtc(
    webrtc_connection: SmallWebRTCConnection,
    workflow_id: int,
    workflow_run_id: int,
    user_id: int,
    call_context_vars: Optional[dict] = None,
) -> None:
    """Run the pipeline for a WebRTC connection.

    Args:
        webrtc_connection: The WebRTC connection to attach the transport to
        workflow_id: The ID of the workflow
        workflow_run_id: The ID of the workflow run
        user_id: The ID of the user
        call_context_vars: Optional extra context variables for this call;
            omitted or None means no extra variables
    """
    logger.debug(
        f"Running pipeline for WebRTC connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
    )
    set_current_run_id(workflow_run_id)

    # Get workflow to extract all pipeline configurations
    workflow = await db_client.get_workflow(workflow_id, user_id)
    vad_config = None
    ambient_noise_config = None
    if workflow and workflow.workflow_configurations:
        if "vad_configuration" in workflow.workflow_configurations:
            vad_config = workflow.workflow_configurations["vad_configuration"]
        if "ambient_noise_configuration" in workflow.workflow_configurations:
            ambient_noise_config = workflow.workflow_configurations[
                "ambient_noise_configuration"
            ]

    # Create audio configuration for WebRTC
    audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)

    transport = create_webrtc_transport(
        webrtc_connection,
        workflow_run_id,
        audio_config,
        vad_config,
        ambient_noise_config,
    )
    await _run_pipeline(
        transport,
        workflow_id,
        workflow_run_id,
        user_id,
        # A fresh dict per call replaces the former shared mutable default.
        call_context_vars=call_context_vars or {},
        audio_config=audio_config,
    )
|
||||
|
||||
|
||||
async def run_pipeline_ari_stasis(
    stasis_connection: StasisRTPConnection,
    workflow_id: int,
    workflow_run_id: int,
    user_id: int,
    call_context_vars: dict,
) -> None:
    """Run the pipeline for an ARI (Stasis) connection.

    Reads the optional VAD and ambient-noise settings from the workflow,
    builds the Stasis transport and hands over to _run_pipeline, passing the
    connection along as well.
    """
    logger.debug(
        f"Running pipeline for ARI connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
    )
    set_current_run_id(workflow_run_id)

    # Optional per-workflow pipeline settings; None when not configured.
    workflow = await db_client.get_workflow(workflow_id, user_id)
    configs = workflow.workflow_configurations if workflow else None
    vad_config = None
    ambient_noise_config = None
    if configs:
        vad_config = (
            configs["vad_configuration"] if "vad_configuration" in configs else None
        )
        ambient_noise_config = (
            configs["ambient_noise_configuration"]
            if "ambient_noise_configuration" in configs
            else None
        )

    # Audio configuration for the Stasis run mode.
    audio_config = create_audio_config(WorkflowRunMode.STASIS.value)

    stasis_transport = create_stasis_transport(
        stasis_connection=stasis_connection,
        workflow_run_id=workflow_run_id,
        audio_config=audio_config,
        vad_config=vad_config,
        ambient_noise_config=ambient_noise_config,
    )
    await _run_pipeline(
        stasis_transport,
        workflow_id,
        workflow_run_id,
        user_id,
        call_context_vars=call_context_vars,
        audio_config=audio_config,
        # The engine needs the connection itself for immediate transfers.
        stasis_connection=stasis_connection,
    )
|
||||
|
||||
|
||||
async def _run_pipeline(
    transport,
    workflow_id: int,
    workflow_run_id: int,
    user_id: int,
    call_context_vars: Optional[dict] = None,
    audio_config: Optional[AudioConfig] = None,
    stasis_connection: Optional[StasisRTPConnection] = None,
) -> None:
    """
    Run the pipeline with the given transport and configuration

    Args:
        transport: The transport to use for the pipeline
        workflow_id: The ID of the workflow
        workflow_run_id: The ID of the workflow run
        user_id: The ID of the user
        call_context_vars: Optional extra context variables; when non-empty
            they are merged over the run's initial context and persisted
        audio_config: Audio configuration passed to the TTS service, the
            pipeline components, the pipeline task and the event handlers
        stasis_connection: Optional Stasis connection, handed to the engine
            for immediate transfers

    Raises:
        HTTPException: 404 if the workflow run or the workflow is not found,
            400 if the workflow run is already completed. The service
            factories raise 400 for an unknown provider.
    """
    workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)

    # NOTE(review): assumes the lookup yields None for an unknown run, as the
    # workflow lookup below is treated — confirm against db_client.
    if workflow_run is None:
        raise HTTPException(status_code=404, detail="Workflow run not found")

    # If the workflow run is already completed, we don't need to run it again
    if workflow_run.is_completed:
        raise HTTPException(status_code=400, detail="Workflow run already completed")

    merged_call_context_vars = workflow_run.initial_context
    # If there is some extra call_context_vars, update them
    if call_context_vars:
        # Tolerate a run without an initial context instead of failing on
        # unpacking None.
        merged_call_context_vars = {
            **(merged_call_context_vars or {}),
            **call_context_vars,
        }
        await db_client.update_workflow_run(
            workflow_run_id, initial_context=merged_call_context_vars
        )

    # Get user configuration
    user_config = await db_client.get_user_configurations(user_id)

    # Create services based on user configuration
    stt = create_stt_service(user_config)
    tts = create_tts_service(user_config, audio_config)
    llm = create_llm_service(user_config)

    # Get workflow first so we can create engine before pipeline components
    workflow = await db_client.get_workflow(workflow_id, user_id)
    if not workflow:
        raise HTTPException(status_code=404, detail="Workflow not found")

    # Extract configurations from workflow configurations
    max_call_duration_seconds = 300  # Default 5 minutes
    max_user_idle_timeout = 10.0  # Default 10 seconds

    if workflow.workflow_configurations:
        # Use workflow-specific max call duration if provided
        if "max_call_duration" in workflow.workflow_configurations:
            max_call_duration_seconds = workflow.workflow_configurations[
                "max_call_duration"
            ]

        # Use workflow-specific max user idle timeout if provided
        if "max_user_idle_timeout" in workflow.workflow_configurations:
            max_user_idle_timeout = workflow.workflow_configurations[
                "max_user_idle_timeout"
            ]

    workflow_graph = WorkflowGraph(
        ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
    )

    engine = PipecatEngine(
        llm=llm,
        tts=tts,
        workflow=workflow_graph,
        call_context_vars=merged_call_context_vars,
        workflow_run_id=workflow_run_id,
    )

    # Create pipeline components with audio configuration and engine
    audio_buffer, audio_synchronizer, transcript, context = create_pipeline_components(
        audio_config, engine
    )

    # Set the context and audio_buffer after creation
    engine.set_context(context)
    engine.set_audio_buffer(audio_buffer)

    # Set Stasis connection for immediate transfers (if available)
    if stasis_connection:
        engine.set_stasis_connection(stasis_connection)

    assistant_params = LLMAssistantAggregatorParams(
        expect_stripped_words=True,
        correct_aggregation_callback=engine.create_aggregation_correction_callback(),
    )
    context_aggregator = llm.create_context_aggregator(
        context, assistant_params=assistant_params
    )

    # Create engine pre-aggregator processor for speaking events
    engine_pre_aggregator_processor = EnginePreAggregatorProcessor(
        user_started_speaking_callback=engine.create_user_started_speaking_callback(),
        user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(),
    )

    # Create the engine callbacks processor with the engine's callbacks
    pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
        max_call_duration_seconds=max_call_duration_seconds,
        max_duration_end_task_callback=engine.create_max_duration_callback(),
        llm_generated_text_callback=engine.create_llm_generated_text_callback(),
        generation_started_callback=engine.create_generation_started_callback(),
        llm_text_frame_callback=engine.handle_llm_text_frame,
        # Note: speaking event callbacks are now handled by pre-aggregator processor
    )

    pipeline_metrics_aggregator = PipelineMetricsAggregator()

    # Create STT mute filter using the selected strategies and the engine's callback
    stt_mute_filter = STTMuteFilter(
        config=STTMuteConfig(
            strategies={
                STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
                STTMuteStrategy.CUSTOM,
            },
            should_mute_callback=engine.create_should_mute_callback(),
        )
    )

    # Use engine's user idle callback with configured timeout
    user_idle_disconnect = UserIdleProcessor(
        callback=engine.create_user_idle_callback(), timeout=max_user_idle_timeout
    )

    user_context_aggregator = context_aggregator.user()
    assistant_context_aggregator = context_aggregator.assistant()

    @assistant_context_aggregator.event_handler("on_push_aggregation")
    async def on_assistant_aggregator_push_context(_aggregator):
        logger.debug("Assistant aggregator push context – flushing pending transitions")
        await engine.flush_pending_transitions(source="context_push")

    # Build the pipeline with the STT mute filter and context controller
    pipeline = build_pipeline(
        transport,
        stt,
        transcript,
        audio_buffer,
        audio_synchronizer,
        llm,
        tts,
        user_context_aggregator,
        assistant_context_aggregator,
        pipeline_engine_callback_processor,
        stt_mute_filter,
        pipeline_metrics_aggregator,
        user_idle_disconnect,
        engine_pre_aggregator_processor=engine_pre_aggregator_processor,
    )

    # Create pipeline task with audio configuration
    task = create_pipeline_task(pipeline, workflow_run_id, audio_config)

    # Now set the task on the engine
    engine.set_task(task)

    # Register event handlers
    in_memory_audio_buffer, in_memory_transcript_buffer = (
        register_transport_event_handlers(
            transport,
            workflow_run_id,
            audio_buffer,
            task,
            engine=engine,
            usage_metrics_aggregator=pipeline_metrics_aggregator,
            audio_synchronizer=audio_synchronizer,
            audio_config=audio_config,
        )
    )
    register_audio_data_handler(
        audio_synchronizer, workflow_run_id, in_memory_audio_buffer
    )
    register_transcript_handler(
        transcript, workflow_run_id, in_memory_transcript_buffer
    )

    try:
        # Run the pipeline
        runner = PipelineRunner()
        await runner.run(task)
        logger.info(f"Pipeline runner completed for run {workflow_run_id}")
    finally:
        # Always drop this run's context providers, even if the run failed.
        ContextProviderRegistry.remove_providers(str(workflow_run_id))
        logger.debug(f"Cleaned up context providers for workflow run {workflow_run_id}")
|
||||
150
api/services/pipecat/service_factory.py
Normal file
150
api/services/pipecat/service_factory.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from pipecat.services.azure.llm import AzureLLMService
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.groq.llm import GroqLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
||||
|
||||
def create_stt_service(user_config):
    """Build the speech-to-text service selected by user_config.stt.provider.

    Every service is created with audio_passthrough=False because audio is
    buffered separately.

    Raises:
        HTTPException: 400 when the configured provider is not supported.
    """
    provider = user_config.stt.provider

    if provider == ServiceProviders.DEEPGRAM.value:
        return DeepgramSTTService(
            api_key=user_config.stt.api_key,
            audio_passthrough=False,
        )

    if provider == ServiceProviders.OPENAI.value:
        return OpenAISTTService(
            api_key=user_config.stt.api_key,
            model=user_config.stt.model.value,
            audio_passthrough=False,
        )

    if provider == ServiceProviders.CARTESIA.value:
        return CartesiaSTTService(
            api_key=user_config.stt.api_key,
            audio_passthrough=False,
        )

    if provider == ServiceProviders.DOGRAH.value:
        # The Dograh STT service takes a WebSocket URL, so the HTTP(S)
        # scheme of MPS_API_URL is swapped for ws/wss.
        ws_base_url = MPS_API_URL.replace("http://", "ws://").replace(
            "https://", "wss://"
        )
        return DograhSTTService(
            base_url=ws_base_url,
            api_key=user_config.stt.api_key,
            model=user_config.stt.model.value,
            audio_passthrough=False,
        )

    raise HTTPException(
        status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}"
    )
|
||||
|
||||
|
||||
def create_tts_service(user_config, audio_config: "AudioConfig"):
    """Create and return appropriate TTS service based on user configuration

    Args:
        user_config: User configuration containing TTS settings
        audio_config: Audio configuration of the pipeline. Accepted for
            interface compatibility; this function does not read it.

    Raises:
        HTTPException: 400 when the provider is not supported, or when an
            ElevenLabs voice is not in the "<name> - <voice_id>" form.
    """
    if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
        return DeepgramTTSService(
            api_key=user_config.tts.api_key,
            voice=user_config.tts.voice.value,
            sample_rate=24000,
        )
    elif user_config.tts.provider == ServiceProviders.OPENAI.value:
        return OpenAITTSService(
            api_key=user_config.tts.api_key, model=user_config.tts.model.value
        )
    elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
        # The voice is stored as "<name> - <voice_id>"; only the id is passed on.
        voice_parts = user_config.tts.voice.split(" - ")
        if len(voice_parts) < 2:
            # Report a malformed voice as a client error, not an IndexError.
            raise HTTPException(
                status_code=400,
                detail=f"Invalid ElevenLabs voice {user_config.tts.voice}",
            )
        voice_id = voice_parts[1]
        return ElevenLabsTTSService(
            reconnect_on_error=False,
            api_key=user_config.tts.api_key,
            voice_id=voice_id,
            model=user_config.tts.model.value,
            params=ElevenLabsTTSService.InputParams(
                stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
            ),
        )
    elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
        # Convert HTTP URL to WebSocket URL for TTS
        base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
        return DograhTTSService(
            base_url=base_url,
            api_key=user_config.tts.api_key,
            model=user_config.tts.model.value,
            voice=user_config.tts.voice.value,
            sample_rate=24000,
        )
    else:
        raise HTTPException(
            status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}"
        )
|
||||
|
||||
|
||||
def create_llm_service(user_config):
    """Create and return appropriate LLM service based on user configuration

    Raises:
        HTTPException: 400 when the configured provider is not supported.
    """
    if user_config.llm.provider == ServiceProviders.OPENAI.value:
        # Models with "gpt-5" in the name get reasoning/verbosity settings;
        # every other OpenAI model gets a fixed low temperature.
        if "gpt-5" in user_config.llm.model.value:
            params = OpenAILLMService.InputParams(
                reasoning_effort="minimal", verbosity="low"
            )
        else:
            params = OpenAILLMService.InputParams(temperature=0.1)
        return OpenAILLMService(
            api_key=user_config.llm.api_key,
            model=user_config.llm.model.value,
            params=params,
        )
    elif user_config.llm.provider == ServiceProviders.GROQ.value:
        # Never write the API key to stdout or logs; report the model only.
        print(
            f"Creating Groq LLM service with model: {user_config.llm.model.value}"
        )
        return GroqLLMService(
            api_key=user_config.llm.api_key,
            model=user_config.llm.model.value,
            params=OpenAILLMService.InputParams(temperature=0.1),
        )
    elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
        # Use the correct InputParams class for Google to avoid propagating OpenAI-specific
        # NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
        return GoogleLLMService(
            api_key=user_config.llm.api_key,
            model=user_config.llm.model.value,
            params=GoogleLLMService.InputParams(temperature=0.1),
        )
    elif user_config.llm.provider == ServiceProviders.AZURE.value:
        return AzureLLMService(
            api_key=user_config.llm.api_key,
            endpoint=user_config.llm.endpoint,
            model=user_config.llm.model.value,  # Azure uses deployment name as model
            params=AzureLLMService.InputParams(temperature=0.1),
        )
    elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
        return DograhLLMService(
            base_url=f"{MPS_API_URL}/api/v1/llm",
            api_key=user_config.llm.api_key,
            model=user_config.llm.model.value,
        )
    else:
        raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
44
api/services/pipecat/tracing_config.py
Normal file
44
api/services/pipecat/tracing_config.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import base64
|
||||
import os
|
||||
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
|
||||
from api.constants import ENABLE_TRACING
|
||||
from pipecat.utils.tracing.setup import setup_tracing
|
||||
|
||||
|
||||
def is_tracing_enabled():
    """Report whether tracing is switched on via the ENABLE_TRACING constant.

    Returns the constant as-is, so tracing only runs when it is explicitly
    turned on and the default setup needs no external tracing backend.
    """
    tracing_flag = ENABLE_TRACING
    return tracing_flag
|
||||
|
||||
|
||||
def setup_pipeline_tracing():
    """Configure OTLP tracing to Langfuse when tracing is enabled.

    Does nothing beyond printing a message when tracing is disabled or when
    any of LANGFUSE_HOST, LANGFUSE_PUBLIC_KEY or LANGFUSE_SECRET_KEY is unset.
    Otherwise sets the OTLP endpoint and auth header environment variables
    and registers an OTLP span exporter.
    """
    if not is_tracing_enabled():
        print("Tracing disabled (ENABLE_TRACING=false)")
        return

    env = os.environ
    langfuse_host = env.get("LANGFUSE_HOST")
    public_key = env.get("LANGFUSE_PUBLIC_KEY")
    secret_key = env.get("LANGFUSE_SECRET_KEY")

    # All three Langfuse settings are required before anything is configured.
    if not (langfuse_host and public_key and secret_key):
        print(
            "Warning: ENABLE_TRACING is true but Langfuse credentials are not configured. Tracing disabled."
        )
        return

    # HTTP Basic credentials: base64 of "<public key>:<secret key>".
    raw_credentials = f"{public_key}:{secret_key}".encode()
    basic_auth = base64.b64encode(raw_credentials).decode()

    # Set before the exporter is created below, which takes no explicit
    # endpoint or header arguments.
    env["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{langfuse_host}/api/public/otel"
    env["OTEL_EXPORTER_OTLP_HEADERS"] = f"Authorization=Basic {basic_auth}"

    span_exporter = OTLPSpanExporter()
    setup_tracing(service_name="dograh-pipeline", exporter=span_exporter)
    print("Langfuse tracing enabled")
|
||||
299
api/services/pipecat/transport_setup.py
Normal file
299
api/services/pipecat/transport_setup.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
import os
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from api.constants import APP_ROOT_DIR, ENABLE_RNNOISE, ENABLE_SMART_TURN
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.smart_turn.websocket_smart_turn import (
|
||||
WebSocketSmartTurnAnalyzer,
|
||||
)
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
from api.services.telephony.stasis_rtp_serializer import StasisRTPFrameSerializer
|
||||
from api.services.telephony.stasis_rtp_transport import (
|
||||
StasisRTPTransport,
|
||||
StasisRTPTransportParams,
|
||||
)
|
||||
from pipecat.audio.filters.rnnoise_filter import RNNoiseFilter
|
||||
from pipecat.audio.mixers.silence_audio_mixer import SilenceAudioMixer
|
||||
from pipecat.audio.mixers.soundfile_mixer import SoundfileMixer
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer, VADParams
|
||||
from pipecat.serializers.twilio import TwilioFrameSerializer
|
||||
from pipecat.transports import InternalTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import (
|
||||
FastAPIWebsocketParams,
|
||||
FastAPIWebsocketTransport,
|
||||
)
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
# Normalized path of the bundled RNNoise shared library under APP_ROOT_DIR.
# The transport factories pass it to RNNoiseFilter when ENABLE_RNNOISE is set.
librnnoise_path = os.path.normpath(
    str(APP_ROOT_DIR / "native" / "rnnoise" / "librnnoise.so")
)
|
||||
|
||||
|
||||
def create_turn_analyzer(workflow_run_id: int, audio_config: AudioConfig):
    """Create a turn analyzer backed by the Smart Turn WebSocket service.

    Args:
        workflow_run_id: ID of the workflow run, passed as the analyzer's
            service context
        audio_config: Audio configuration containing pipeline sample rate

    Returns:
        A WebSocketSmartTurnAnalyzer, or None when ENABLE_SMART_TURN is off.
    """
    if not ENABLE_SMART_TURN:
        return None

    endpoint = os.getenv("SMART_TURN_WS_SERVICE_ENDPOINT", "ws://localhost:8010/ws")

    # Authentication is optional: the header is sent only when a key is set.
    api_key = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
    auth_headers = None
    if api_key:
        auth_headers = {"X-API-Key": api_key}

    pipeline_rate = audio_config.pipeline_sample_rate

    turn_params = SmartTurnParams(
        # Send turn complete if silent for stop_secs seconds.
        stop_secs=1.5,
        # Speech sent from before the point where VAD detected speech.
        pre_speech_ms=0,
        # Max duration of speech sent to the end-of-turn analyzer.
        max_duration_secs=5,
        # Do not clear the buffer unless the last run predicted end of turn.
        # Otherwise speaking -> quiet -> end-of-turn trigger -> clear(),
        # followed by speak -> quiet, could send only a very small speech
        # segment to the end-of-turn model, which is not good.
        use_only_last_vad_segment=False,
    )

    return WebSocketSmartTurnAnalyzer(
        url=endpoint,
        headers=auth_headers,
        sample_rate=pipeline_rate,
        params=turn_params,
        service_context=workflow_run_id,
    )
|
||||
|
||||
|
||||
def create_twilio_transport(
    websocket_client: WebSocket,
    stream_sid: str,
    call_sid: str,
    workflow_run_id: int,
    audio_config: AudioConfig,
    vad_config: dict | None = None,
    ambient_noise_config: dict | None = None,
):
    """Build the FastAPI WebSocket transport used for Twilio calls.

    Twilio credentials are read from the TWILIO_ACCOUNT_SID and
    TWILIO_AUTH_TOKEN environment variables, so a missing one raises KeyError.
    """
    turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)

    twilio_serializer = TwilioFrameSerializer(
        stream_sid=stream_sid,
        call_sid=call_sid,
        account_sid=os.environ["TWILIO_ACCOUNT_SID"],
        auth_token=os.environ["TWILIO_AUTH_TOKEN"],
    )

    in_rate = audio_config.transport_in_sample_rate
    out_rate = audio_config.transport_out_sample_rate

    # VAD: workflow-provided thresholds if configured, library defaults otherwise.
    # Sample rate will be set by transport.
    if vad_config:
        vad_analyzer = SileroVADAnalyzer(
            params=VADParams(
                confidence=vad_config.get("confidence", 0.7),
                start_secs=vad_config.get("start_seconds", 0.4),
                stop_secs=vad_config.get("stop_seconds", 0.8),
                min_volume=vad_config.get("minimum_volume", 0.6),
            )
        )
    else:
        vad_analyzer = SileroVADAnalyzer()

    # Output mixer: office ambience when enabled, silence otherwise.
    if ambient_noise_config and ambient_noise_config.get("enabled", False):
        ambience_file = (
            APP_ROOT_DIR
            / "assets"
            / f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
        )
        out_mixer = SoundfileMixer(
            sound_files={"office": ambience_file},
            default_sound="office",
            volume=ambient_noise_config.get("volume", 0.3),
        )
    else:
        out_mixer = SilenceAudioMixer()

    # Input noise suppression is opt-in.
    noise_filter = None
    if ENABLE_RNNOISE:
        noise_filter = RNNoiseFilter(library_path=librnnoise_path)

    transport_params = FastAPIWebsocketParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_in_sample_rate=in_rate,
        audio_out_sample_rate=out_rate,
        vad_analyzer=vad_analyzer,
        audio_out_mixer=out_mixer,
        turn_analyzer=turn_analyzer,
        serializer=twilio_serializer,
        audio_in_filter=noise_filter,
    )
    return FastAPIWebsocketTransport(
        websocket=websocket_client,
        params=transport_params,
    )
|
||||
|
||||
|
||||
def create_webrtc_transport(
    webrtc_connection: SmallWebRTCConnection,
    workflow_run_id: int,
    audio_config: AudioConfig,
    vad_config: dict | None = None,
    ambient_noise_config: dict | None = None,
):
    """Build the SmallWebRTC transport used for browser (WebRTC) calls."""
    turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)

    in_rate = audio_config.transport_in_sample_rate
    out_rate = audio_config.transport_out_sample_rate

    # VAD: workflow-provided thresholds if configured, library defaults otherwise.
    # Sample rate will be set by transport.
    if vad_config:
        vad_analyzer = SileroVADAnalyzer(
            params=VADParams(
                confidence=vad_config.get("confidence", 0.7),
                start_secs=vad_config.get("start_seconds", 0.4),
                stop_secs=vad_config.get("stop_seconds", 0.8),
                min_volume=vad_config.get("minimum_volume", 0.6),
            )
        )
    else:
        vad_analyzer = SileroVADAnalyzer()

    # Output mixer: office ambience when enabled, silence otherwise.
    if ambient_noise_config and ambient_noise_config.get("enabled", False):
        ambience_file = (
            APP_ROOT_DIR
            / "assets"
            / f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
        )
        out_mixer = SoundfileMixer(
            sound_files={"office": ambience_file},
            default_sound="office",
            volume=ambient_noise_config.get("volume", 0.3),
        )
    else:
        out_mixer = SilenceAudioMixer()

    # Input noise suppression is opt-in.
    noise_filter = None
    if ENABLE_RNNOISE:
        noise_filter = RNNoiseFilter(library_path=librnnoise_path)

    transport_params = TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_in_sample_rate=in_rate,
        audio_out_sample_rate=out_rate,
        vad_analyzer=vad_analyzer,
        audio_out_mixer=out_mixer,
        turn_analyzer=turn_analyzer,
        audio_in_filter=noise_filter,
    )
    return SmallWebRTCTransport(
        webrtc_connection=webrtc_connection,
        params=transport_params,
    )
|
||||
|
||||
|
||||
def create_stasis_transport(
    stasis_connection: StasisRTPConnection,
    workflow_run_id: int,
    audio_config: AudioConfig,
    vad_config: dict | None = None,
    ambient_noise_config: dict | None = None,
):
    """Build a ``StasisRTPTransport`` for an ARI (Stasis) connection.

    Args:
        stasis_connection: Connection passed positionally to the transport.
        workflow_run_id: Forwarded to ``create_turn_analyzer``.
        audio_config: Supplies the transport input/output sample rates; the
            input rate is also used for the RTP frame serializer.
        vad_config: Optional VAD overrides. Keys read are ``confidence``
            (default 0.7), ``start_seconds`` (0.4), ``stop_seconds`` (0.8)
            and ``minimum_volume`` (0.6). When empty or ``None`` a
            ``SileroVADAnalyzer`` is created without explicit params.
        ambient_noise_config: Optional dict. When its ``enabled`` entry is
            truthy, the office-ambience WAV matching the output sample rate
            is mixed into outgoing audio at ``volume`` (default 0.3);
            otherwise a ``SilenceAudioMixer`` is used.

    Returns:
        The configured ``StasisRTPTransport``.
    """
    turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)

    serializer_params = StasisRTPFrameSerializer.InputParams(
        sample_rate=audio_config.transport_in_sample_rate
    )
    serializer = StasisRTPFrameSerializer(serializer_params)

    out_rate = audio_config.transport_out_sample_rate
    in_rate = audio_config.transport_in_sample_rate

    # Sample rate will be set by transport
    if vad_config:
        vad_analyzer = SileroVADAnalyzer(
            params=VADParams(
                confidence=vad_config.get("confidence", 0.7),
                start_secs=vad_config.get("start_seconds", 0.4),
                stop_secs=vad_config.get("stop_seconds", 0.8),
                min_volume=vad_config.get("minimum_volume", 0.6),
            )
        )
    else:
        vad_analyzer = SileroVADAnalyzer()

    if ambient_noise_config and ambient_noise_config.get("enabled", False):
        ambience_path = (
            APP_ROOT_DIR / "assets" / f"office-ambience-{out_rate}-mono.wav"
        )
        mixer = SoundfileMixer(
            sound_files={"office": ambience_path},
            default_sound="office",
            volume=ambient_noise_config.get("volume", 0.3),
        )
    else:
        mixer = SilenceAudioMixer()

    # Noise suppression is opt-in through the module-level ENABLE_RNNOISE flag.
    noise_filter = None
    if ENABLE_RNNOISE:
        noise_filter = RNNoiseFilter(library_path=librnnoise_path)

    transport_params = StasisRTPTransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_out_sample_rate=out_rate,
        audio_in_sample_rate=in_rate,
        audio_out_10ms_chunks=2,  # Send 20ms packets
        vad_analyzer=vad_analyzer,
        audio_out_mixer=mixer,
        turn_analyzer=turn_analyzer,
        serializer=serializer,
        audio_in_filter=noise_filter,
    )
    return StasisRTPTransport(stasis_connection, params=transport_params)
|
||||
|
||||
|
||||
def create_internal_transport(
    workflow_run_id: int,
    audio_config: AudioConfig,
    latency_seconds: float = 0.0,
    vad_config: dict | None = None,
    ambient_noise_config: dict | None = None,
):
    """Create an internal transport for agent-to-agent connections (LoopTalk).

    Args:
        workflow_run_id: ID of the workflow run for turn analyzer context
        audio_config: Audio configuration for the transport (input/output
            sample rates; both directions are configured as mono)
        latency_seconds: Network latency to simulate
        vad_config: Optional VAD overrides. Keys read are ``confidence``
            (default 0.7), ``start_seconds`` (0.4), ``stop_seconds`` (0.8)
            and ``minimum_volume`` (0.6). When empty or ``None`` a
            ``SileroVADAnalyzer`` is created without explicit params.
        ambient_noise_config: Optional dict. When its ``enabled`` entry is
            truthy, the office-ambience WAV matching the output sample rate
            is mixed into outgoing audio at ``volume`` (default 0.3);
            otherwise a ``SilenceAudioMixer`` is used.

    Returns:
        InternalTransport instance configured with turn analyzer
    """
    turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)

    out_rate = audio_config.transport_out_sample_rate
    in_rate = audio_config.transport_in_sample_rate

    if vad_config:
        vad_analyzer = SileroVADAnalyzer(
            params=VADParams(
                confidence=vad_config.get("confidence", 0.7),
                start_secs=vad_config.get("start_seconds", 0.4),
                stop_secs=vad_config.get("stop_seconds", 0.8),
                min_volume=vad_config.get("minimum_volume", 0.6),
            )
        )
    else:
        vad_analyzer = SileroVADAnalyzer()

    if ambient_noise_config and ambient_noise_config.get("enabled", False):
        ambience_path = (
            APP_ROOT_DIR / "assets" / f"office-ambience-{out_rate}-mono.wav"
        )
        mixer = SoundfileMixer(
            sound_files={"office": ambience_path},
            default_sound="office",
            volume=ambient_noise_config.get("volume", 0.3),
        )
    else:
        mixer = SilenceAudioMixer()

    # Noise suppression is opt-in through the module-level ENABLE_RNNOISE flag.
    noise_filter = None
    if ENABLE_RNNOISE:
        noise_filter = RNNoiseFilter(library_path=librnnoise_path)

    transport_params = TransportParams(
        audio_out_enabled=True,
        audio_out_sample_rate=out_rate,
        audio_out_channels=1,
        audio_in_enabled=True,
        audio_in_sample_rate=in_rate,
        audio_in_channels=1,
        vad_analyzer=vad_analyzer,
        audio_out_mixer=mixer,
        turn_analyzer=turn_analyzer,
        audio_in_filter=noise_filter,
    )

    # Create and return the internal transport with latency
    return InternalTransport(
        params=transport_params,
        latency_seconds=latency_seconds,
    )
|
||||
76
api/services/pipecat/turn_context.py
Normal file
76
api/services/pipecat/turn_context.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Turn context management for logging across async boundaries.
|
||||
|
||||
This module provides a mechanism to track turn numbers across different
|
||||
async contexts, working around the limitation that contextvars don't
|
||||
propagate through asyncio.create_task() calls.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pipecat.utils.context import turn_var
|
||||
|
||||
|
||||
class TurnContextManager:
    """Manages turn context across async task boundaries.

    This class provides a workaround for the issue where contextvars
    don't propagate through asyncio.create_task() calls in the pipecat
    library's event system. The turn number is kept in three places: the
    ``turn_var`` contextvar, a per-task mapping, and a last-known value
    used as the final fallback.
    """

    def __init__(self):
        # Map from task to the turn number last set from within that task.
        self._task_turns: Dict[asyncio.Task, int] = {}
        # Store the pipeline task reference
        self._pipeline_task: Optional[asyncio.Task] = None
        # Most recent turn number set from any context.
        self._current_turn: int = 0

    def set_pipeline_task(self, task: asyncio.Task):
        """Set the main pipeline task reference."""
        self._pipeline_task = task

    def set_turn(self, turn_number: int):
        """Record ``turn_number`` for the current context and task.

        The value is written to ``turn_var``, remembered as the latest turn,
        and, when called from inside an asyncio task, stored against that
        task. The per-task entry is removed automatically once the task
        finishes, so the mapping cannot grow without bound when callers do
        not invoke ``cleanup_task`` themselves.
        """
        self._current_turn = turn_number
        # Set in contextvar for direct access
        turn_var.set(turn_number)

        # Also store for the current task
        try:
            current_task = asyncio.current_task()
        except RuntimeError:
            # No running event loop: there is no task to associate with.
            return
        if current_task is None:
            return

        if current_task not in self._task_turns:
            # First time this task is seen: drop its entry when it completes.
            # The callback receives the finished task as its only argument,
            # which matches cleanup_task's signature.
            current_task.add_done_callback(self.cleanup_task)
        self._task_turns[current_task] = turn_number

    def get_turn(self) -> int:
        """Get the turn number, trying multiple sources.

        Order of preference: a positive ``turn_var`` value, the value stored
        for the current asyncio task, then the last turn set anywhere.
        """
        # First try contextvar
        # NOTE(review): assumes turn_var has a default (0 meaning "unset");
        # without one, get() raises LookupError. Confirm in pipecat.
        turn = turn_var.get()
        if turn > 0:
            return turn

        # Try current task mapping
        try:
            current_task = asyncio.current_task()
            if current_task and current_task in self._task_turns:
                return self._task_turns[current_task]
        except RuntimeError:
            # No running event loop; fall through to the stored value.
            pass

        # Fall back to stored current turn
        return self._current_turn

    def cleanup_task(self, task: asyncio.Task):
        """Clean up turn mapping for completed tasks.

        Safe to call more than once, or for a task that was never recorded.
        """
        self._task_turns.pop(task, None)
|
||||
|
||||
|
||||
# Global instance
|
||||
_turn_context_manager = TurnContextManager()
|
||||
|
||||
|
||||
def get_turn_context_manager() -> TurnContextManager:
    """Return the module-level ``TurnContextManager`` shared by all callers."""
    manager = _turn_context_manager
    return manager
|
||||
76
api/services/pricing/README.md
Normal file
76
api/services/pricing/README.md
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
# Pricing Module
|
||||
|
||||
This module contains pricing models and registries for different AI services used in workflow cost calculations.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
pricing/
|
||||
├── __init__.py # Main module exports
|
||||
├── models.py # Base pricing model classes
|
||||
├── llm.py # LLM pricing configurations
|
||||
├── tts.py # TTS pricing configurations
|
||||
├── stt.py # STT pricing configurations
|
||||
├── registry.py # Combined pricing registry
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Pricing Models
|
||||
|
||||
### TokenPricingModel
|
||||
Used for LLM services that charge based on tokens:
|
||||
- `prompt_token_price`: Cost per prompt token
|
||||
- `completion_token_price`: Cost per completion token
|
||||
- `cache_read_discount`: Fraction of the prompt-token price refunded for cache-read tokens (default `0.5`, i.e. 50% off)
|
||||
- `cache_creation_multiplier`: Multiplier applied to the prompt-token price for cache-creation tokens (default `1.25`, i.e. a 25% premium)
|
||||
|
||||
### CharacterPricingModel
|
||||
Used for TTS services that charge based on character count:
|
||||
- `character_price`: Cost per character
|
||||
|
||||
### TimePricingModel
|
||||
Used for STT services that charge based on time:
|
||||
- `second_price`: Cost per second
|
||||
|
||||
## Adding New Pricing
|
||||
|
||||
### Adding a New LLM Model
|
||||
Edit `llm.py` and add the model to the appropriate provider:
|
||||
|
||||
```python
|
||||
ServiceProviders.OPENAI: {
|
||||
"new-model": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000,
|
||||
completion_token_price=Decimal("8.00") / 1000000,
|
||||
),
|
||||
# ... existing models
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a New Provider
|
||||
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
|
||||
2. The registry will automatically include them
|
||||
|
||||
### Adding a New Service Type
|
||||
1. Create a new pricing file (e.g., `image.py`)
|
||||
2. Define the pricing models
|
||||
3. Import and add to `registry.py`
|
||||
|
||||
## Usage
|
||||
|
||||
The pricing registry is automatically imported and used by the cost calculator:
|
||||
|
||||
```python
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
# The cost calculator uses the pricing registry automatically
|
||||
result = cost_calculator.calculate_total_cost(usage_info)
|
||||
```
|
||||
|
||||
## Maintenance
|
||||
|
||||
- Update pricing when providers change their rates
|
||||
- All prices should use `Decimal` for precision
|
||||
- Include comments with current pricing from provider documentation
|
||||
- Test changes with existing test suite
|
||||
9
api/services/pricing/__init__.py
Normal file
9
api/services/pricing/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Pricing module for workflow cost calculation.
|
||||
|
||||
This module contains pricing models and registries for different AI services.
|
||||
"""
|
||||
|
||||
from .registry import PRICING_REGISTRY
|
||||
|
||||
__all__ = ["PRICING_REGISTRY"]
|
||||
228
api/services/pricing/cost_calculator.py
Normal file
228
api/services/pricing/cost_calculator.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""
|
||||
Cost Calculator for Workflow Runs
|
||||
|
||||
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
|
||||
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
|
||||
|
||||
Features:
|
||||
- Token-based pricing for LLM services with cache optimization support
|
||||
- Character-based pricing for TTS services
|
||||
- Time-based pricing for STT services
|
||||
- Configurable pricing models that can be updated
|
||||
- Support for multiple providers and models
|
||||
- Automatic provider inference from model names
|
||||
- JSON serialization support for database storage
|
||||
|
||||
Usage:
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"processor_name|||gpt-4o": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 500,
|
||||
"total_tokens": 1500,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"processor_name|||aura-2-helena-en": 2000 # character count
|
||||
}
|
||||
}
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
print(f"Total cost: ${cost_breakdown['total']:.6f}")
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.pricing.models import (
|
||||
PricingModel,
|
||||
)
|
||||
|
||||
|
||||
class CostCalculator:
    """Price workflow-run usage against a pricing registry.

    The registry is a nested mapping
    ``service_type -> provider -> model -> PricingModel`` where
    ``service_type`` is ``"llm"``, ``"tts"`` or ``"stt"``. The key
    ``"default"`` may appear at the provider and/or model level to supply
    fallback prices.
    """

    def __init__(self, pricing_registry: Optional[Dict] = None):
        """Create a calculator.

        Args:
            pricing_registry: Registry to price against. ``None`` selects the
                shared ``PRICING_REGISTRY``.
        """
        # Compare with None rather than truthiness: an explicitly supplied
        # (possibly empty) registry must be used as-is, otherwise a later
        # update_pricing() would silently mutate the shared global registry.
        self.pricing_registry = (
            PRICING_REGISTRY if pricing_registry is None else pricing_registry
        )

    def get_pricing_model(
        self, service_type: str, provider: str, model: str
    ) -> Optional[PricingModel]:
        """Get pricing model for a specific service, provider, and model.

        Lookup order: the provider's entry for ``model``, the provider's
        ``"default"`` entry, then the same two lookups under the
        ``"default"`` provider. Returns ``None`` when nothing matches.
        """
        try:
            service_pricing = self.pricing_registry.get(service_type, {})

            # Try to get pricing for the specific provider
            provider_pricing = service_pricing.get(provider, {})
            pricing_model = provider_pricing.get(model) or provider_pricing.get(
                "default"
            )
            if pricing_model:
                return pricing_model

            # If not found, try the "default" provider for this service type
            default_provider_pricing = service_pricing.get("default", {})
            return default_provider_pricing.get(model) or default_provider_pricing.get(
                "default"
            )
        except (KeyError, AttributeError):
            # A malformed registry level (not a mapping) is treated as "no price".
            return None

    def calculate_llm_cost(
        self, provider: str, model: str, usage: Dict[str, int]
    ) -> Decimal:
        """Calculate cost for LLM usage; zero when no price is registered."""
        pricing_model = self.get_pricing_model("llm", provider, model)
        if not pricing_model:
            return Decimal("0")
        return pricing_model.calculate_cost(usage)

    def calculate_tts_cost(
        self, provider: str, model: str, character_count: int
    ) -> Decimal:
        """Calculate cost for TTS usage; zero when no price is registered."""
        pricing_model = self.get_pricing_model("tts", provider, model)
        if not pricing_model:
            return Decimal("0")
        return pricing_model.calculate_cost(character_count)

    def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
        """Calculate cost for STT usage; zero when no price is registered."""
        pricing_model = self.get_pricing_model("stt", provider, model)
        if not pricing_model:
            return Decimal("0")
        return pricing_model.calculate_cost(seconds)

    def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
        """Price every usage entry and return per-service and total costs.

        Args:
            usage_info: Mapping with optional ``"llm"``, ``"tts"`` and
                ``"stt"`` sections, each keyed by ``"processor|||model"``.
                LLM values are token-usage dicts, TTS values are character
                counts and STT values are seconds.

        Returns:
            Dict with float values under ``llm_cost``, ``tts_cost``,
            ``stt_cost`` and ``total``.
        """
        llm_cost_total = Decimal("0")
        tts_cost_total = Decimal("0")
        stt_cost_total = Decimal("0")

        # Calculate LLM costs
        for key, usage in (usage_info.get("llm") or {}).items():
            _, model = self._parse_key(key)
            provider = self._infer_provider_from_model(model, "llm")
            llm_cost_total += self.calculate_llm_cost(provider, model, usage)

        # Calculate TTS costs
        for key, character_count in (usage_info.get("tts") or {}).items():
            processor, model = self._parse_key(key)
            # Handle the case where model is "None" - infer from processor
            if model.lower() in ["none", "null", ""]:
                provider = self._infer_provider_from_processor(processor, "tts")
                model = "default"  # Use default model for the provider
            else:
                provider = self._infer_provider_from_model(model, "tts")
            tts_cost_total += self.calculate_tts_cost(provider, model, character_count)

        # Calculate STT costs from explicit stt usage
        for key, seconds in (usage_info.get("stt") or {}).items():
            _, model = self._parse_key(key)
            provider = self._infer_provider_from_model(model, "stt")
            stt_cost_total += self.calculate_stt_cost(provider, model, seconds)

        total_cost = llm_cost_total + tts_cost_total + stt_cost_total

        return {
            "llm_cost": float(llm_cost_total),
            "tts_cost": float(tts_cost_total),
            "stt_cost": float(stt_cost_total),
            "total": float(total_cost),
        }

    def _parse_key(self, key) -> Tuple[str, str]:
        """Parse key which is in format 'processor|||model'"""
        if isinstance(key, str) and "|||" in key:
            processor, model = key.split("|||", 1)
            return processor, model
        # Fallback for backwards compatibility or malformed keys
        return str(key), "unknown"

    def _first_registered_provider(self, service_type: str) -> str:
        """Return the first provider registered for ``service_type``, else "unknown"."""
        service_providers = self.pricing_registry.get(service_type, {})
        if service_providers:
            return next(iter(service_providers))
        return "unknown"

    def _provider_listing_model(self, model: str, service_type: str):
        """Return the first real provider whose price table lists ``model``, else None."""
        service_providers = self.pricing_registry.get(service_type, {})
        for provider, models in service_providers.items():
            # The "default" provider is a price fallback, not a real provider.
            if provider == "default":
                continue
            if model in models:
                return provider
        return None

    @staticmethod
    def _match_keywords(text: str, keyword_table):
        """Return the provider of the first table row with a keyword found in ``text``."""
        for provider, keywords in keyword_table:
            if any(keyword in text for keyword in keywords):
                return provider
        return None

    def _infer_provider_from_model(self, model: str, service_type: str) -> str:
        """Infer provider from model name.

        Resolution order: known name keywords, then any provider whose
        registry table lists the model exactly, then the first provider
        registered for the service type.
        """
        if not model:
            return "unknown"

        keyword_table = (
            (ServiceProviders.OPENAI, ("gpt", "whisper", "openai")),
            (ServiceProviders.GROQ, ("groq",)),
            (ServiceProviders.ELEVENLABS, ("eleven",)),
            # "aura" covers Deepgram TTS voices such as "aura-2-helena-en",
            # which otherwise fell through to the first TTS provider.
            (
                ServiceProviders.DEEPGRAM,
                ("deepgram", "nova", "aura", "phonecall", "general"),
            ),
        )
        provider = self._match_keywords(model.lower(), keyword_table)
        if provider is not None:
            return provider

        # Models whose names carry no provider hint (for example the Groq
        # entries in the LLM table) are resolved from the registry itself.
        provider = self._provider_listing_model(model, service_type)
        if provider is not None:
            return provider

        # Default to first available provider for the service type
        return self._first_registered_provider(service_type)

    def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
        """Infer provider from processor name."""
        if not processor:
            return "unknown"

        keyword_table = (
            (ServiceProviders.OPENAI, ("openai", "gpt")),
            (ServiceProviders.GROQ, ("groq",)),
            # Kept in step with _infer_provider_from_model.
            (ServiceProviders.ELEVENLABS, ("eleven",)),
            (ServiceProviders.DEEPGRAM, ("deepgram",)),
        )
        provider = self._match_keywords(processor.lower(), keyword_table)
        if provider is not None:
            return provider

        # Default to first available provider for the service type
        return self._first_registered_provider(service_type)

    def update_pricing(
        self, service_type: str, provider: str, model: str, pricing_model: PricingModel
    ):
        """Update pricing for a specific service/provider/model combination.

        NOTE: when the calculator was built without its own registry this
        mutates the shared ``PRICING_REGISTRY`` in place.
        """
        provider_table = self.pricing_registry.setdefault(service_type, {}).setdefault(
            provider, {}
        )
        provider_table[model] = pricing_model
|
||||
|
||||
|
||||
# Global cost calculator instance
|
||||
cost_calculator = CostCalculator()
|
||||
143
api/services/pricing/llm.py
Normal file
143
api/services/pricing/llm.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""
|
||||
LLM pricing models for different providers.
|
||||
|
||||
Each price is stored per single token: provider rates quoted per 1K tokens are divided by 1,000 and rates quoted per 1M tokens are divided by 1,000,000.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TokenPricingModel
|
||||
|
||||
# LLM pricing registry: provider -> model name -> per-token prices.
# Every value is a price per single token, derived from the provider's quoted
# per-1K or per-1M rate shown in the trailing comment.
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
    ServiceProviders.OPENAI: {
        "gpt-3.5-turbo": TokenPricingModel(
            prompt_token_price=Decimal("0.0015") / 1000,  # $0.0015 per 1K tokens
            completion_token_price=Decimal("0.002") / 1000,  # $0.002 per 1K tokens
        ),
        "gpt-4": TokenPricingModel(
            prompt_token_price=Decimal("0.03") / 1000,  # $0.03 per 1K tokens
            completion_token_price=Decimal("0.06") / 1000,  # $0.06 per 1K tokens
        ),
        "gpt-4.1": TokenPricingModel(
            prompt_token_price=Decimal("2.00") / 1_000_000,  # $2.00 per 1M tokens
            completion_token_price=Decimal("8.00") / 1_000_000,  # $8.00 per 1M tokens
        ),
        "gpt-4.1-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.40") / 1_000_000,  # $0.40 per 1M tokens
            completion_token_price=Decimal("1.60") / 1_000_000,  # $1.60 per 1M tokens
        ),
        "gpt-4.1-nano": TokenPricingModel(
            prompt_token_price=Decimal("0.10") / 1_000_000,  # $0.10 per 1M tokens
            completion_token_price=Decimal("0.40") / 1_000_000,  # $0.40 per 1M tokens
        ),
        "gpt-4.5-preview": TokenPricingModel(
            prompt_token_price=Decimal("75.00") / 1_000_000,  # $75.00 per 1M tokens
            completion_token_price=Decimal("150.00")
            / 1_000_000,  # $150.00 per 1M tokens
        ),
        "gpt-4o": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00")
            / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-audio-preview": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00")
            / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-realtime-preview": TokenPricingModel(
            prompt_token_price=Decimal("5.00") / 1_000_000,  # $5.00 per 1M tokens
            completion_token_price=Decimal("20.00")
            / 1_000_000,  # $20.00 per 1M tokens
        ),
        "gpt-4o-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1_000_000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
        ),
        "gpt-4o-mini-audio-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1_000_000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
        ),
        "gpt-4o-mini-realtime-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
            completion_token_price=Decimal("2.40") / 1_000_000,  # $2.40 per 1M tokens
        ),
        "gpt-4o-search-preview": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00")
            / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-mini-search-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1_000_000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
        ),
        "o1": TokenPricingModel(
            prompt_token_price=Decimal("15.00") / 1_000_000,  # $15.00 per 1M tokens
            completion_token_price=Decimal("60.00")
            / 1_000_000,  # $60.00 per 1M tokens
        ),
        "o1-pro": TokenPricingModel(
            prompt_token_price=Decimal("150.00") / 1_000_000,  # $150.00 per 1M tokens
            completion_token_price=Decimal("600.00")
            / 1_000_000,  # $600.00 per 1M tokens
        ),
        "o1-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1_000_000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1_000_000,  # $4.40 per 1M tokens
        ),
        "o3": TokenPricingModel(
            prompt_token_price=Decimal("10.00") / 1_000_000,  # $10.00 per 1M tokens
            completion_token_price=Decimal("40.00")
            / 1_000_000,  # $40.00 per 1M tokens
        ),
        "o3-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1_000_000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1_000_000,  # $4.40 per 1M tokens
        ),
        "o4-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1_000_000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1_000_000,  # $4.40 per 1M tokens
        ),
        "computer-use-preview": TokenPricingModel(
            prompt_token_price=Decimal("3.00") / 1_000_000,  # $3.00 per 1M tokens
            completion_token_price=Decimal("12.00")
            / 1_000_000,  # $12.00 per 1M tokens
        ),
        "gpt-image-1": TokenPricingModel(
            prompt_token_price=Decimal("5.00") / 1_000_000,  # $5.00 per 1M tokens
            completion_token_price=Decimal("0") / 1_000_000,  # No output pricing shown
        ),
        "codex-mini-latest": TokenPricingModel(
            prompt_token_price=Decimal("1.50") / 1_000_000,  # $1.50 per 1M tokens
            completion_token_price=Decimal("6.00") / 1_000_000,  # $6.00 per 1M tokens
        ),
        # Transcription models
        "gpt-4o-transcribe": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00")
            / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-mini-transcribe": TokenPricingModel(
            prompt_token_price=Decimal("1.25") / 1_000_000,  # $1.25 per 1M tokens
            completion_token_price=Decimal("5.00") / 1_000_000,  # $5.00 per 1M tokens
        ),
        # TTS models with token-based pricing
        "gpt-4o-mini-tts": TokenPricingModel(
            prompt_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
            completion_token_price=Decimal("0")
            / 1_000_000,  # No completion tokens for TTS
        ),
    },
    ServiceProviders.GROQ: {
        "llama-3.3-70b-versatile": TokenPricingModel(
            prompt_token_price=Decimal("0.00059") / 1000,  # $0.00059 per 1K tokens
            completion_token_price=Decimal("0.00079") / 1000,  # $0.00079 per 1K tokens
        ),
        "deepseek-r1-distill-llama-70b": TokenPricingModel(
            prompt_token_price=Decimal("0.00059") / 1000,  # Assuming similar pricing
            completion_token_price=Decimal("0.00079") / 1000,
        ),
    },
    ServiceProviders.AZURE: {
        # NOTE(review): both rates are the $0.40 / $1.60 per-1M prices scaled
        # by 1.1 (the "data zone" rate mentioned in the original comment). The
        # completion price was previously 8.80, i.e. 1.1 x the gpt-4.1
        # (non-mini) rate of $8.00 -- confirm against the current Azure price
        # sheet.
        "gpt-4.1-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.44") / 1_000_000,  # $0.44 per 1M tokens
            completion_token_price=Decimal("1.76") / 1_000_000,  # $1.76 per 1M tokens
        )
    },
}
|
||||
89
api/services/pricing/models.py
Normal file
89
api/services/pricing/models.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Base pricing models for different service types.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class CostType(Enum):
    """Billing unit used by a service category."""

    # LLM usage is metered in tokens.
    LLM_TOKENS = "llm_tokens"
    # TTS usage is metered in characters.
    TTS_CHARACTERS = "tts_characters"
    # STT usage is metered in seconds.
    STT_SECONDS = "stt_seconds"
|
||||
|
||||
|
||||
class PricingModel:
    """Common interface for all pricing models.

    Subclasses override ``calculate_cost`` to turn a usage measurement into a
    ``Decimal`` cost.
    """

    def calculate_cost(self, usage: Any) -> Decimal:
        """Return the cost for ``usage``; must be implemented by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class TokenPricingModel(PricingModel):
    """Pricing model for token-based services (LLM)."""

    def __init__(
        self,
        prompt_token_price: Decimal,
        completion_token_price: Decimal,
        cache_read_discount: Decimal = Decimal("0.5"),  # 50% discount for cache reads
        cache_creation_multiplier: Decimal = Decimal(
            "1.25"
        ),  # 25% premium for cache creation
    ):
        """Store per-token prices and cache adjustment factors.

        Args:
            prompt_token_price: Price of one prompt token.
            completion_token_price: Price of one completion token.
            cache_read_discount: Fraction of the prompt-token price deducted
                for each cache-read token.
            cache_creation_multiplier: Multiplier on the prompt-token price
                for cache-creation tokens; only the part above 1 is added.
        """
        self.prompt_token_price = prompt_token_price
        self.completion_token_price = completion_token_price
        self.cache_read_discount = cache_read_discount
        self.cache_creation_multiplier = cache_creation_multiplier

    def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
        """Calculate cost for LLM token usage.

        Args:
            usage: Mapping that may contain ``prompt_tokens``,
                ``completion_tokens``, ``cache_read_input_tokens`` and
                ``cache_creation_input_tokens``. Missing or ``None`` values
                count as zero.

        Returns:
            The total cost, never negative.
        """
        # `or 0` (rather than a .get default) also covers keys that are
        # present with a None value, matching the cache fields below;
        # Decimal(None) would otherwise raise TypeError.
        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        cache_read_tokens = usage.get("cache_read_input_tokens") or 0
        cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0

        # Base cost
        prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
        completion_cost = Decimal(completion_tokens) * self.completion_token_price

        # Cache adjustments
        # NOTE(review): subtracting a saving presumes cache-read tokens are
        # already counted inside prompt_tokens -- confirm with the producers
        # of the usage dict.
        cache_read_savings = (
            Decimal(cache_read_tokens)
            * self.prompt_token_price
            * self.cache_read_discount
        )
        cache_creation_premium = (
            Decimal(cache_creation_tokens)
            * self.prompt_token_price
            * (self.cache_creation_multiplier - 1)
        )

        total_cost = (
            prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
        )
        return max(total_cost, Decimal("0"))  # Ensure non-negative
|
||||
|
||||
|
||||
class CharacterPricingModel(PricingModel):
    """Per-character pricing, used for TTS services."""

    def __init__(self, character_price: Decimal):
        """Store the price charged for a single character."""
        self.character_price = character_price

    def calculate_cost(self, character_count: int) -> Decimal:
        """Return ``character_count`` multiplied by the per-character price."""
        characters = Decimal(character_count)
        return characters * self.character_price
|
||||
|
||||
|
||||
class TimePricingModel(PricingModel):
    """Per-second pricing, used for STT services."""

    def __init__(self, second_price: Decimal):
        """Store the price charged for one second."""
        self.second_price = second_price

    def calculate_cost(self, seconds: float) -> Decimal:
        """Return ``seconds`` multiplied by the per-second price."""
        # Convert through str so the Decimal carries the float's printed
        # value rather than its full binary expansion.
        duration = Decimal(str(seconds))
        return duration * self.second_price
|
||||
16
api/services/pricing/registry.py
Normal file
16
api/services/pricing/registry.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""
|
||||
Main pricing registry that combines all service type pricing models.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
||||
# Single lookup table covering every service type, keyed by "llm", "tts" and
# "stt"; each value is that service's provider -> model pricing table.
PRICING_REGISTRY: Dict = dict(
    llm=LLM_PRICING,
    tts=TTS_PRICING,
    stt=STT_PRICING,
)
|
||||
26
api/services/pricing/stt.py
Normal file
26
api/services/pricing/stt.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""
|
||||
STT (Speech-to-Text) pricing models for different providers.
|
||||
|
||||
Prices are per second for STT services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TimePricingModel
|
||||
|
||||
def _per_minute(usd_per_minute: str) -> TimePricingModel:
    """Build a per-second pricing model from a per-minute USD rate."""
    return TimePricingModel(Decimal(usd_per_minute) / 60)


# STT pricing registry: provider -> model -> per-second price. "default"
# entries are the fallbacks used when a model or provider is not listed.
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
    ServiceProviders.DEEPGRAM: {
        "nova-3-general": _per_minute("0.0077"),
        "nova-2": _per_minute("0.0058"),
        "default": _per_minute("0.0077"),
    },
    ServiceProviders.OPENAI: {
        "gpt-4o-transcribe": _per_minute("0.015"),
        "default": _per_minute("0.015"),
    },
    "default": {"default": _per_minute("0.0077")},
}
|
||||
30
api/services/pricing/tts.py
Normal file
30
api/services/pricing/tts.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
"""
|
||||
TTS (Text-to-Speech) pricing models for different providers.
|
||||
|
||||
Prices are per character for TTS services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import CharacterPricingModel
|
||||
|
||||
def _per_thousand_chars(usd_per_thousand: str) -> CharacterPricingModel:
    """Build a per-character pricing model from a per-1,000-characters USD rate."""
    return CharacterPricingModel(Decimal(usd_per_thousand) / 1_000)


# TTS pricing registry: provider -> model -> per-character price. "default"
# entries are the fallbacks used when a model or provider is not listed.
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
    ServiceProviders.OPENAI: {
        # NOTE(review): the divisor is 10,000,000 (originally spelled
        # 1_00_00_000), i.e. $0.06 per 1M characters. Confirm this is intended
        # rather than 1_000_000.
        "gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 10_000_000),
        "default": CharacterPricingModel(Decimal("0.6") / 10_000_000),
    },
    ServiceProviders.DEEPGRAM: {
        "aura-2": _per_thousand_chars("0.030"),
        "aura-1": _per_thousand_chars("0.015"),
        "default": _per_thousand_chars("0.030"),
    },
    ServiceProviders.ELEVENLABS: {
        # 6400 usd per 250*1e6 characters
        "default": _per_thousand_chars("0.0256")
    },
    "default": {"default": _per_thousand_chars("0.030")},
}
|
||||
3
api/services/reports/__init__.py
Normal file
3
api/services/reports/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .daily_report import DailyReportService
|
||||
|
||||
__all__ = ["DailyReportService"]
|
||||
237
api/services/reports/daily_report.py
Normal file
237
api/services/reports/daily_report.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
from datetime import datetime, time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
|
||||
class DailyReportService:
    """Build per-day workflow-run reports for a single organization.

    A "day" is interpreted in the caller-supplied IANA timezone and converted
    to a UTC window before ``db_client`` is queried.
    """

    # (label, range_start, range_end) in seconds. ``range_end`` is exclusive;
    # the last bucket is open-ended and therefore has no upper bound.
    _DURATION_BUCKETS = (
        ("0-10", 0, 10),
        ("10-30", 10, 30),
        ("30-60", 30, 60),
        ("60-120", 60, 120),
        ("120-180", 120, 180),
        (">180", 180, None),
    )

    # Number of dispositions listed individually before the remainder is
    # folded into a single "Other" row.
    _TOP_DISPOSITIONS = 5

    @staticmethod
    def _day_bounds_utc(date: str, timezone: str) -> "tuple[datetime, datetime]":
        """Return the UTC start and end instants of ``date`` in ``timezone``.

        Args:
            date: Date in YYYY-MM-DD format.
            timezone: IANA timezone string (e.g., "America/New_York").

        Raises:
            zoneinfo.ZoneInfoNotFoundError: If ``timezone`` is unknown.
            ValueError: If ``date`` does not match YYYY-MM-DD.
        """
        tz = ZoneInfo(timezone)
        day = datetime.strptime(date, "%Y-%m-%d")
        utc = ZoneInfo("UTC")

        # The window spans 00:00:00 through 23:59:59.999999 local time.
        start_utc = datetime.combine(day, time.min, tzinfo=tz).astimezone(utc)
        end_utc = datetime.combine(day, time.max, tzinfo=tz).astimezone(utc)
        return start_utc, end_utc

    @staticmethod
    def _percentage(count: int, total: int) -> float:
        """Return ``count`` as a percentage of ``total`` rounded to 2 places (0 if empty)."""
        return round((count / total * 100) if total > 0 else 0, 2)

    @classmethod
    def _duration_bucket(cls, duration: float) -> str:
        """Return the label of the duration bucket that ``duration`` falls into."""
        for label, _range_start, range_end in cls._DURATION_BUCKETS:
            if range_end is None or duration < range_end:
                return label
        # Not reached: the last bucket is open-ended.
        return cls._DURATION_BUCKETS[-1][0]

    @staticmethod
    def _format_run(run: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten one workflow-run row into the shape used for CSV export.

        Missing (``None``) ``gathered_context``, ``initial_context`` or
        ``usage_info`` values are treated as empty mappings so that a run
        without collected data still yields a row instead of raising.
        """
        gathered = run["gathered_context"] or {}
        initial = run["initial_context"] or {}
        usage = run["usage_info"] or {}

        # Prefer the number gathered during the call, then the initial context.
        phone_number = gathered.get("customer_phone_number", "") or initial.get(
            "phone_number", ""
        )

        try:
            duration_seconds = float(usage.get("call_duration_seconds", "0"))
        except (ValueError, TypeError):
            duration_seconds = 0

        return {
            "phone_number": phone_number,
            "disposition": gathered.get("mapped_call_disposition", ""),
            "duration_seconds": duration_seconds,
            "workflow_id": run["workflow_id"],
            "run_id": run["id"],
            "workflow_name": run["workflow_name"],
            "created_at": run["created_at"].isoformat(),
        }

    async def get_daily_report(
        self,
        organization_id: int,
        date: str,
        timezone: str,
        workflow_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Get daily report for a specific date and timezone.

        Args:
            organization_id: The organization ID to filter by
            date: Date in YYYY-MM-DD format
            timezone: IANA timezone string (e.g., "America/New_York")
            workflow_id: Optional workflow ID to filter by (None means all workflows)

        Returns:
            A dict with the echoed ``date``/``timezone``/``workflow_id``, a
            ``metrics`` dict (``total_runs``, ``xfer_count``), a
            ``disposition_distribution`` list (top 5 plus an optional "Other"
            row) and a ``call_duration_distribution`` list (one row per bucket).
        """
        start_utc, end_utc = self._day_bounds_utc(date, timezone)

        runs = await db_client.get_workflow_runs_for_daily_report(
            organization_id=organization_id,
            start_utc=start_utc,
            end_utc=end_utc,
            workflow_id=workflow_id,
        )
        total_runs = len(runs)

        # Tally dispositions; runs without gathered context are not tallied
        # but still count towards ``total_runs`` (the percentage denominator).
        disposition_counts: Dict[Any, int] = {}
        for run in runs:
            context = run["gathered_context"]
            if not context:
                continue
            disposition = context.get("mapped_call_disposition", "UNKNOWN")
            disposition_counts[disposition] = (
                disposition_counts.get(disposition, 0) + 1
            )
        xfer_count = disposition_counts.get("XFER", 0)

        # Most frequent first; ties keep first-seen order (stable sort).
        ranked = sorted(
            disposition_counts.items(), key=lambda item: item[1], reverse=True
        )
        disposition_distribution = [
            {
                "disposition": disposition,
                "count": count,
                "percentage": self._percentage(count, total_runs),
            }
            for disposition, count in ranked[: self._TOP_DISPOSITIONS]
        ]
        other_count = sum(count for _, count in ranked[self._TOP_DISPOSITIONS :])
        if other_count > 0:
            disposition_distribution.append(
                {
                    "disposition": "Other",
                    "count": other_count,
                    "percentage": self._percentage(other_count, total_runs),
                }
            )

        # Histogram of call durations; runs with a missing, empty or
        # unparseable duration are left out of the histogram entirely.
        bucket_counts = {label: 0 for label, _, _ in self._DURATION_BUCKETS}
        for run in runs:
            usage = run["usage_info"]
            if not usage:
                continue
            raw_duration = usage.get("call_duration_seconds")
            if not raw_duration:
                continue
            try:
                duration = float(raw_duration)
            except (ValueError, TypeError):
                continue
            bucket_counts[self._duration_bucket(duration)] += 1

        total_calls_with_duration = sum(bucket_counts.values())
        call_duration_distribution = [
            {
                "bucket": label,
                "range_start": range_start,
                "range_end": range_end,
                "count": bucket_counts[label],
                "percentage": self._percentage(
                    bucket_counts[label], total_calls_with_duration
                ),
            }
            for label, range_start, range_end in self._DURATION_BUCKETS
        ]

        return {
            "date": date,
            "timezone": timezone,
            "workflow_id": workflow_id,
            "metrics": {"total_runs": total_runs, "xfer_count": xfer_count},
            "disposition_distribution": disposition_distribution,
            "call_duration_distribution": call_duration_distribution,
        }

    async def get_workflows_for_organization(
        self, organization_id: int
    ) -> List[Dict[str, Any]]:
        """
        Get all workflows for an organization.

        Returns:
            One ``{"id": ..., "name": ...}`` dict per workflow.
        """
        workflows = await db_client.get_workflows_for_organization(organization_id)

        return [{"id": workflow.id, "name": workflow.name} for workflow in workflows]

    async def get_daily_runs_detail(
        self,
        organization_id: int,
        date: str,
        timezone: str,
        workflow_id: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Get detailed workflow runs for CSV export.

        Args:
            organization_id: The organization ID to filter by
            date: Date in YYYY-MM-DD format
            timezone: IANA timezone string (e.g., "America/New_York")
            workflow_id: Optional workflow ID to filter by

        Returns:
            One flat dict per run with ``phone_number``, ``disposition``,
            ``duration_seconds``, ``workflow_id``, ``run_id``,
            ``workflow_name`` and ``created_at`` (ISO 8601 string).
        """
        start_utc, end_utc = self._day_bounds_utc(date, timezone)

        runs = await db_client.get_workflow_runs_for_daily_report(
            organization_id=organization_id,
            start_utc=start_utc,
            end_utc=end_utc,
            workflow_id=workflow_id,
        )

        return [self._format_run(run) for run in runs]
|
||||
3
api/services/smart_turn/__init__.py
Normal file
3
api/services/smart_turn/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .websocket_smart_turn import WebSocketSmartTurnAnalyzer
|
||||
|
||||
__all__ = ["WebSocketSmartTurnAnalyzer"]
|
||||
478
api/services/smart_turn/app.py
Normal file
478
api/services/smart_turn/app.py
Normal file
|
|
@ -0,0 +1,478 @@
|
|||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from fastapi import (
|
||||
BackgroundTasks,
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
Request,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
WebSocketException,
|
||||
status,
|
||||
)
|
||||
from fastapi.websockets import WebSocketState
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
|
||||
from scipy.io import wavfile
|
||||
|
||||
# DEBUG unless the LOG_LEVEL environment variable is set to something other
# than "debug" (case-insensitive); any other value selects INFO.
LOG_LEVEL = (
    logging.INFO
    if os.environ.get("LOG_LEVEL", "DEBUG").lower() != "debug"
    else logging.DEBUG
)

# Dedicated "smart_turn" logger that writes timestamped records to stdout.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)

logger = logging.getLogger("smart_turn")
logger.setLevel(LOG_LEVEL)
logger.addHandler(handler)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ----------------------------------------------------------------------------
|
||||
# Model identifier/path handed to the analyzer; overridable through the
# LOCAL_SMART_TURN_MODEL_PATH environment variable.
MODEL_PATH = os.environ.get(
    "LOCAL_SMART_TURN_MODEL_PATH", "pipecat-ai/smart-turn-v2"
)
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Analyzer Pool
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AnalyzerWrapper:
    """Pair a LocalSmartTurnAnalyzer with an asyncio lock.

    Request handlers acquire ``lock`` around inference so that only one
    request uses the analyzer at a time.
    """

    def __init__(self, analyzer: LocalSmartTurnAnalyzerV2):
        self.lock = asyncio.Lock()
        self.analyzer = analyzer


# Process-wide analyzer instance; stays None until the lifespan hook sets it.
_analyzer_wrapper: _AnalyzerWrapper | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: create the shared analyzer on startup.

    The analyzer is built only when the module-level ``_analyzer_wrapper`` is
    still unset. Nothing is torn down on shutdown.
    """
    global _analyzer_wrapper

    # --- startup ---
    if _analyzer_wrapper is None:
        logger.debug("Initializing LocalSmartTurnAnalyzer")
        smart_turn_analyzer = LocalSmartTurnAnalyzerV2(
            smart_turn_model_path=MODEL_PATH
        )
        _analyzer_wrapper = _AnalyzerWrapper(smart_turn_analyzer)
        logger.debug("LocalSmartTurnAnalyzer initialized")

    # --- application serves requests while suspended here ---
    yield

    # --- shutdown: nothing to clean up at present ---


app = FastAPI(
    lifespan=lifespan,
    title="Smart Turn API",
    description="A FastAPI application exposing LocalSmartTurnAnalyzer via HTTP",
)
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# API Endpoints
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def save_wav_file(
    audio_array: np.ndarray,
    prediction: int,
    probability: float,
    service_id: str | None = None,
    sample_rate: int = 16000,
) -> None:
    """Save audio data as a WAV file in the background.

    The blocking ``wavfile.write`` call runs in a worker thread so the event
    loop is not blocked. Being a coroutine function, it can be scheduled with
    ``asyncio.create_task`` (WebSocket endpoint) or handed to
    ``BackgroundTasks`` (HTTP endpoint).

    Saving is best-effort: any failure is logged and swallowed.

    Args:
        audio_array: The audio data as a numpy array; it is scaled by 32767
            and clipped to the int16 range before writing.
        prediction: The prediction result (0 or 1)
        probability: The probability of the prediction
        service_id: Optional service identifier. It is supplied by the client,
            so it is sanitized before being embedded in the file name.
        sample_rate: The sample rate of the audio (default: 16000 Hz)
    """

    def _blocking_save() -> None:
        try:
            # Timestamp with millisecond precision (microseconds truncated).
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]

            # ``service_id`` originates from a request header. Replace every
            # character that is not alphanumeric, "-", "_" or "." so it cannot
            # introduce path separators and escape the output directory.
            if service_id:
                safe_service_id = "".join(
                    ch if ch.isalnum() or ch in "-_." else "_" for ch in service_id
                )
                service_prefix = f"{safe_service_id}_"
            else:
                service_prefix = ""

            root_dir = (
                Path(__file__).resolve().parents[3]
            )  # dograh/api/services/smart_turn/app.py
            filename = (
                root_dir
                / f"smart_turn_pipeline/{service_prefix}{timestamp}_{prediction}_{probability}.wav"
            )

            # Convert float32 [-1, 1] back to int16 PCM for WAV file
            audio_int16 = np.clip(audio_array * 32767, -32768, 32767).astype(np.int16)

            wavfile.write(filename, sample_rate, audio_int16)

            length_seconds = len(audio_array) / sample_rate
            log_message = f"Saved audio to {filename} (length: {length_seconds:.2f}s, prediction: {prediction}"
            if service_id:
                log_message += f", service_id: {service_id}"
            log_message += ")"

            logger.info(log_message)

        except Exception as exc:  # pragma: no cover – best-effort logging only
            log_message = f"Failed to save WAV file: {exc}"
            if service_id:
                log_message += f" (service_id: {service_id})"
            logger.error(log_message)

    # Offload the blocking I/O to a thread to avoid blocking the event loop
    await asyncio.to_thread(_blocking_save)
|
||||
|
||||
|
||||
@app.post("/raw", status_code=status.HTTP_200_OK)
async def handle_raw(request: Request, background_tasks: BackgroundTasks):
    """Run end-of-turn inference on a NumPy-serialized audio array.

    The request body is expected to be a float32 array written via
    ``np.save``; the JSON response is compatible with
    ``HttpSmartTurnAnalyzer``.

    Request headers:
        X-API-Key: Required only when ``SMART_TURN_HTTP_SERVICE_KEY`` is set.
        X-Service-Context: Optional identifier echoed back as ``service_id``.
        X-Sample-Rate: Optional positive integer (default 16000); only used
            when the audio is saved to disk.

    Returns:
        The analyzer's result dict with ``metrics.inference_time`` and
        ``metrics.total_time`` set, plus ``service_id`` when provided.

    Raises:
        HTTPException: 401 for a wrong or missing API key, 400 for an empty
            body, an invalid ``X-Sample-Rate`` or an undecodable payload, and
            503 when the analyzer has not been initialized.
    """
    # ------------------------------------------------------------------
    # Secret key validation (enforced only when a secret is configured)
    # ------------------------------------------------------------------
    expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
    if expected_secret:
        provided_secret = request.headers.get("X-API-Key")
        if provided_secret != expected_secret:
            logger.warning(
                "Unauthorized access attempt to /raw endpoint with invalid or missing secret key"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Unauthorized",
            )

    # Start total-time measurement before any heavy work.
    request_start_time = time.perf_counter()
    logger.debug("Received /raw request")

    body = await request.body()
    if not body:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Empty request body"
        )

    # Extract service context and sample rate from headers
    service_id = request.headers.get("X-Service-Context")
    # Suffix appended to log/error messages when a service id is known.
    service_suffix = f" (service_id: {service_id})" if service_id else ""

    sample_rate_str = request.headers.get("X-Sample-Rate")
    try:
        sample_rate = int(sample_rate_str) if sample_rate_str else 16000
    except ValueError:
        sample_rate = 0
    if sample_rate <= 0:
        # A malformed header is a client error, not an internal server error.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid X-Sample-Rate header: {sample_rate_str!r}",
        )

    # Deserialize NumPy array (pickle explicitly disabled: the body is untrusted)
    try:
        audio_array = np.load(io.BytesIO(body), allow_pickle=False)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid NumPy payload: {exc}" + service_suffix,
        ) from exc

    wrapper = _analyzer_wrapper
    if wrapper is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Analyzer not initialized",
        )

    # Run inference guarded by the wrapper lock so the model isn't used concurrently
    logger.debug("Going to acquire lock for model inference" + service_suffix)

    async with wrapper.lock:
        logger.debug("Acquired lock for model inference" + service_suffix)

        # Measure inference-only latency
        inference_start_time = time.perf_counter()
        result = await wrapper.analyzer._predict_endpoint(audio_array)
        inference_time = time.perf_counter() - inference_start_time

        # Total processing time (from request receipt to response preparation)
        total_time = time.perf_counter() - request_start_time

        logger.debug(
            f"Inference done result: {result.get('prediction')} "
            f"probability: {result.get('probability')} time taken: {inference_time:.2f}s total: {total_time:.2f}s"
            + service_suffix
        )

    # Ensure metrics section exists so client code can parse it consistently,
    # then overwrite / set the timing metrics explicitly.
    metrics = result.get("metrics", {})
    metrics["inference_time"] = inference_time
    metrics["total_time"] = total_time
    result["metrics"] = metrics

    logger.debug(f"Result for service_id: {service_id} is: {result}")

    # Add service_id to result for potential client use
    if service_id:
        result["service_id"] = service_id

    # Persist audio in background so it doesn't block the response.
    background_tasks.add_task(
        save_wav_file,
        audio_array,
        result.get("prediction", 0),
        result.get("probability", 0),
        service_id,
        sample_rate,
    )
    return result
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Liveness probe: always answers with a fixed status message."""
    status_payload = {"message": "Smart Turn API is running"}
    return status_payload
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# WebSocket endpoint
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Handle streaming Smart Turn requests over WebSocket.

    Each incoming binary message must be a NumPy-serialized float32 array (as
    produced by ``np.save``). A JSON-formatted prediction (identical to the
    ``/raw`` HTTP endpoint) is sent back as a text message. Non-binary
    messages are ignored.

    Handshake headers:
        X-API-Key: Required only when ``SMART_TURN_HTTP_SERVICE_KEY`` is set;
            a mismatch closes the socket with code 4401.
        X-Service-Context: Optional identifier echoed back as ``service_id``.
        X-Sample-Rate: Optional positive integer (default 16000); an invalid
            value closes the socket with code 1008.

    Recoverable per-message problems (oversized or undecodable payload) are
    reported as ``{"error": ...}`` text messages and the connection stays open.
    """
    # Optional secret key check, done during the handshake (before accept)
    expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
    if expected_secret:
        provided_secret = ws.headers.get("X-API-Key")
        if provided_secret != expected_secret:
            await ws.close(code=4401, reason="Unauthorized")
            return

    service_id = ws.headers.get("X-Service-Context")
    sample_rate_str = ws.headers.get("X-Sample-Rate")
    try:
        sample_rate = int(sample_rate_str) if sample_rate_str else 16000
    except ValueError:
        sample_rate = 0
    if sample_rate <= 0:
        # Reject a malformed header up front instead of failing mid-session.
        logger.warning(
            f"Rejecting WebSocket with invalid X-Sample-Rate header: {sample_rate_str!r}"
        )
        await ws.close(code=1008, reason="Invalid X-Sample-Rate header")
        return

    # Accept the websocket connection and log it
    await ws.accept()
    logger.debug(
        f"WebSocket connection accepted from service_id: {service_id}, sample_rate: {sample_rate}"
    )

    # ------------------------------------------------------------------
    # Tunables – consider moving to env vars for ops control
    # ------------------------------------------------------------------
    connection_timeout = 120.0  # Seconds of inactivity before timing out
    max_binary_size = int(
        os.getenv("SMART_TURN_MAX_PAYLOAD", 10 * 1024 * 1024)  # 10MB max message size
    )

    # Track background save tasks so we can cancel them on disconnect
    background_tasks = set()

    async def _send_error(message: str) -> None:
        """Send ``{"error": message}`` to the client; failures are only logged."""
        if ws.application_state != WebSocketState.CONNECTED:
            return
        try:
            # json.dumps escapes quotes/backslashes in the message, so the
            # reply is always valid JSON.
            await ws.send_text(json.dumps({"error": message}))
        except Exception as e:
            logger.error(f"Failed to send error message: {e}")

    try:
        logger.debug("Entering WebSocket message loop")
        while True:
            logger.debug("Waiting for WebSocket message…")

            # wait_for cancels the pending receive itself when it times out.
            try:
                msg = await asyncio.wait_for(ws.receive(), timeout=connection_timeout)
            except asyncio.TimeoutError:
                logger.warning(
                    f"WebSocket connection timeout for service_id: {service_id}"
                )
                try:
                    await ws.close(code=1001, reason="Connection timeout")
                except Exception as e:
                    logger.debug(f"Error closing WebSocket after timeout: {e}")
                break
            except WebSocketDisconnect as e:
                logger.debug(f"WebSocket client disconnected: {e}")
                break
            except Exception as e:
                logger.error(f"Error in WebSocket loop: {e}")
                break

            # Validate message structure
            if not isinstance(msg, dict):
                logger.error(f"Unexpected message type: {type(msg)}")
                break

            # Handle disconnect message explicitly
            if msg.get("type") == "websocket.disconnect":
                logger.debug("Client sent disconnect frame")
                break

            # Only binary frames carry audio; anything else is skipped.
            data = msg.get("bytes")
            if data is None:
                continue
            logger.debug("Received WebSocket audio payload (%d bytes)", len(data))

            request_start_time = time.perf_counter()

            # --------------------------------------------------------------
            # Basic validation & secure deserialisation
            # --------------------------------------------------------------
            if len(data) > max_binary_size:
                logger.warning("Received payload exceeding maximum allowed size")
                await _send_error("Payload too large")
                continue

            # Deserialize NumPy array (pickle disabled for security)
            try:
                audio_array = np.load(io.BytesIO(data), allow_pickle=False)
            except Exception as exc:
                error_msg = f"Invalid NumPy payload: {exc}"
                if service_id:
                    error_msg += f" (service_id: {service_id})"
                await _send_error(error_msg)
                continue

            wrapper = _analyzer_wrapper
            if wrapper is None:
                logger.error("Analyzer not initialized; closing connection")
                if ws.application_state == WebSocketState.CONNECTED:
                    await ws.close(code=1011, reason="Analyzer not ready")
                break

            # Serialize access to the shared model.
            async with wrapper.lock:
                inference_start_time = time.perf_counter()
                result = await wrapper.analyzer._predict_endpoint(audio_array)
                inference_time = time.perf_counter() - inference_start_time

            # Timing metrics
            total_time = time.perf_counter() - request_start_time
            metrics = result.get("metrics", {})
            metrics["inference_time"] = inference_time
            metrics["total_time"] = total_time
            result["metrics"] = metrics

            logger.debug(f"Result for service_id: {service_id} is: {result}")

            if service_id:
                result["service_id"] = service_id

            # Send result; any failure here ends the session.
            try:
                if ws.application_state == WebSocketState.CONNECTED:
                    await ws.send_text(json.dumps(result))
                else:
                    logger.warning(
                        f"Cannot send result - WebSocket not connected for service_id: {service_id}"
                    )
                    break
            except WebSocketDisconnect:
                logger.debug(
                    f"Client disconnected while sending result for service_id: {service_id}"
                )
                break
            except Exception as e:
                logger.error(f"Failed to send result: {e}")
                break

            # Save audio in the background so that it doesn't block streaming
            task = asyncio.create_task(
                save_wav_file(
                    audio_array,
                    result.get("prediction", 0),
                    result.get("probability", 0),
                    service_id,
                    sample_rate,
                )
            )
            # Track task and remove when done
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

    except WebSocketException as exc:
        logger.error(f"WebSocket error: {exc}")
    finally:
        # Cancel any remaining background tasks
        for task in background_tasks:
            if not task.done():
                task.cancel()
        # Wait for all background tasks to complete or be cancelled
        if background_tasks:
            await asyncio.gather(*background_tasks, return_exceptions=True)

        # Attempt a graceful close if it's not already closed
        if ws.application_state == WebSocketState.CONNECTED:
            try:
                await ws.close()
            except Exception as exc:
                # Socket is probably already closed; log and ignore
                logger.debug(f"WebSocket already closed: {exc}")
|
||||
314
api/services/smart_turn/websocket_smart_turn.py
Normal file
314
api/services/smart_turn/websocket_smart_turn.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
"""Smart-Turn analyzer that talks to a FastAPI WebSocket endpoint.
|
||||
|
||||
This analyzer keeps a persistent WebSocket connection alive so that the TCP/TLS
|
||||
handshake and HTTP upgrade happen only once per call session. Each speech
|
||||
segment is sent as a single binary message containing the NumPy-serialized
|
||||
float32 array, and a JSON reply is expected in return.
|
||||
|
||||
Rewritten to use the websockets library for simplified connection management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import (
|
||||
BaseSmartTurn,
|
||||
SmartTurnTimeoutException,
|
||||
)
|
||||
|
||||
|
||||
class WebSocketSmartTurnAnalyzer(BaseSmartTurn):
|
||||
"""End-of-turn analyzer that sends audio via a persistent WebSocket."""
|
||||
|
||||
def __init__(
    self,
    *,
    url: str,
    headers: Optional[Dict[str, str]] = None,
    service_context: Optional[Any] = None,
    **kwargs,
) -> None:
    """Store connection settings and start the connection manager if possible.

    Args:
        url: WebSocket URL of the Smart-Turn service; trailing slashes are
            stripped.
        headers: Extra headers sent with the WebSocket handshake.
        service_context: Optional value sent (stringified) in the
            ``X-Service-Context`` header.
        **kwargs: Forwarded unchanged to the base class constructor.
    """
    super().__init__(**kwargs)
    self._url = url.rstrip("/")
    self._headers = headers or {}
    self._service_context = service_context

    # WebSocket connection
    self._ws: Optional[websockets.WebSocketClientProtocol] = None
    self._ws_lock = asyncio.Lock()

    # Connection management
    self._connection_task: Optional[asyncio.Task] = None
    self._reconnect_delay = 1.0
    self._max_reconnect_delay = 30.0
    self._closing = False
    self._connection_closed_event = asyncio.Event()

    # Connection health monitoring
    self._last_successful_request = 0.0
    self._connection_attempts = 0

    # Start the connection manager right away when constructed inside a
    # running event loop. get_running_loop() raises RuntimeError otherwise
    # (and, unlike get_event_loop(), never creates a loop as a side effect),
    # in which case the connection is opened lazily on first use.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        logger.debug(
            "No running loop at object creation time. Connection will be opened lazily on first use."
        )
    else:
        self._connection_task = loop.create_task(self._connection_manager())
|
||||
|
||||
def _serialize_array(self, audio_array: np.ndarray) -> bytes:
|
||||
"""Serialize numpy array to bytes."""
|
||||
buffer = io.BytesIO()
|
||||
np.save(buffer, audio_array)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def _connection_manager(self) -> None:
    """Manages WebSocket connection lifecycle with automatic reconnection.

    Loops until ``self._closing`` is set: establish a connection, wait until
    ``self._connection_closed_event`` is set, drop the socket, then back off
    (exponentially, capped, with jitter) before reconnecting.

    Cancelling the task ends it promptly: the socket is still cleaned up, but
    the cancellation is not swallowed and no backoff sleep is performed.
    """
    while not self._closing:
        try:
            # Establish connection
            await self._establish_connection()

            # Reset reconnect delay on successful connection
            self._reconnect_delay = 1.0
            self._connection_attempts = 0

            # Wait for connection close event
            self._connection_closed_event.clear()
            await self._connection_closed_event.wait()

            logger.debug("WebSocket connection closed")

        except Exception as e:
            logger.error(f"Connection manager error: {e}")

        finally:
            # Detach the socket first so no caller is handed a socket that is
            # being closed, then close it on a best-effort basis. Only
            # ``Exception`` is caught so a CancelledError still propagates.
            ws, self._ws = self._ws, None
            if ws:
                try:
                    await ws.close()
                except Exception as close_exc:
                    logger.debug(f"Ignoring error while closing WebSocket: {close_exc}")

        # The backoff lives outside ``finally`` on purpose: code in ``finally``
        # also runs while a cancellation is propagating, and sleeping there
        # would delay the task's cancellation by the whole backoff.
        if not self._closing:
            # Exponential backoff for reconnection
            self._connection_attempts += 1
            delay = min(
                self._reconnect_delay
                * (2 ** min(self._connection_attempts - 1, 5)),
                self._max_reconnect_delay,
            )
            # Add jitter to avoid thundering herd
            delay += random.uniform(0, 0.5)
            logger.info(
                f"Reconnecting in {delay:.1f} seconds (attempt {self._connection_attempts})"
            )
            await asyncio.sleep(delay)
|
||||
|
||||
async def _establish_connection(self) -> None:
    """Establish a new WebSocket connection with retry logic.

    Makes up to three connection attempts with short, growing pauses between
    them and stores the connected socket in ``self._ws``.

    Raises:
        asyncio.CancelledError: Propagated immediately if the task is cancelled.
        Exception: The error of the last attempt when all attempts fail.
    """
    logger.debug("Establishing new WebSocket connection to Smart-Turn service...")

    # Prepare headers
    extra_headers = dict(self._headers)
    if self._service_context is not None:
        extra_headers["X-Service-Context"] = str(self._service_context)

    # self._sample_rate is only set later through set_sample_rate(), but this
    # analyzer opens its WebSocket during __init__(), before that call can
    # happen. So fall back to _init_sample_rate (set in the constructor)
    # whenever _sample_rate is not set yet.
    _sample_rate = self._sample_rate or self._init_sample_rate

    # NOTE(review): presumably _init_sample_rate can be None when the caller
    # passed no sample rate -- guard against None before comparing; confirm
    # against the base class.
    if _sample_rate and _sample_rate > 0:
        extra_headers["X-Sample-Rate"] = str(_sample_rate)

    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            # Small extra pause that grows with the attempt number
            if attempt > 0:
                jitter = 0.1 * attempt
                await asyncio.sleep(jitter)

            # Connect with websockets library
            self._ws = await websockets.connect(
                self._url,
                extra_headers=extra_headers,
                ping_interval=5.0,  # let websockets send pings every 5s
                ping_timeout=3.0,  # fail fast if no pong in 3s
                close_timeout=10,
                max_size=10 * 1024 * 1024,  # 10MB max message size
            )

            logger.info("WebSocket connection established successfully")
            return

        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.warning(
                f"Failed to establish WebSocket (attempt {attempt + 1}/{max_attempts}): {exc}"
            )
            if attempt == max_attempts - 1:
                raise
            await asyncio.sleep(0.5 * (attempt + 1))
|
||||
|
||||
async def _ensure_ws(self) -> websockets.WebSocketClientProtocol:
|
||||
"""Return a connected WebSocket, waiting for connection if necessary."""
|
||||
async with self._ws_lock:
|
||||
# If connection manager isn't running, start it
|
||||
if not self._connection_task or self._connection_task.done():
|
||||
self._connection_task = asyncio.create_task(self._connection_manager())
|
||||
|
||||
# Wait for connection with timeout
|
||||
start_time = time.time()
|
||||
max_wait_time = 10.0
|
||||
|
||||
while not self._closing:
|
||||
if self._ws:
|
||||
return self._ws
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > max_wait_time:
|
||||
raise Exception(
|
||||
f"Timeout waiting for WebSocket connection after {max_wait_time}s"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if self._closing:
|
||||
raise Exception("Analyzer is closing")
|
||||
|
||||
raise Exception("Failed to establish WebSocket connection")
|
||||
|
||||
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Send audio and await JSON response via WebSocket."""
|
||||
data_bytes = self._serialize_array(audio_array)
|
||||
|
||||
try:
|
||||
# Ensure we have a connection
|
||||
ws = await self._ensure_ws()
|
||||
|
||||
# Send data
|
||||
try:
|
||||
await ws.send(data_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send data: {e}")
|
||||
self._connection_closed_event.set()
|
||||
return {
|
||||
"prediction": 0,
|
||||
"probability": 0.0,
|
||||
"metrics": {"inference_time": 0.0, "total_time": 0.0},
|
||||
}
|
||||
|
||||
# Wait for response
|
||||
start_time = time.time()
|
||||
while True:
|
||||
remaining_timeout = self._params.stop_secs - (time.time() - start_time)
|
||||
if remaining_timeout <= 0:
|
||||
raise SmartTurnTimeoutException(
|
||||
f"Request exceeded {self._params.stop_secs} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# Receive message with timeout
|
||||
message = await asyncio.wait_for(
|
||||
ws.recv(), timeout=min(remaining_timeout, 0.5)
|
||||
)
|
||||
|
||||
# Handle text messages (JSON responses)
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
result = json.loads(message)
|
||||
|
||||
# Skip ping/pong messages
|
||||
if result.get("type") in ["ping", "pong"]:
|
||||
continue
|
||||
|
||||
# Validate prediction response
|
||||
if "prediction" not in result:
|
||||
if "type" in result:
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
"Invalid response format from Smart-Turn service"
|
||||
)
|
||||
return {
|
||||
"prediction": 0,
|
||||
"probability": 0.0,
|
||||
"metrics": {
|
||||
"inference_time": 0.0,
|
||||
"total_time": 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
self._last_successful_request = time.time()
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.error(
|
||||
f"Smart turn service returned invalid JSON: {exc}"
|
||||
)
|
||||
raise
|
||||
else:
|
||||
logger.error(f"Unexpected message type: {type(message)}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning("WebSocket connection closed during prediction")
|
||||
self._connection_closed_event.set()
|
||||
return {
|
||||
"prediction": 0,
|
||||
"probability": 0.0,
|
||||
"metrics": {"inference_time": 0.0, "total_time": 0.0},
|
||||
}
|
||||
|
||||
except SmartTurnTimeoutException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Smart turn prediction failed over WebSocket: {exc}")
|
||||
self._connection_closed_event.set()
|
||||
return {
|
||||
"prediction": 0,
|
||||
"probability": 0.0,
|
||||
"metrics": {"inference_time": 0.0, "total_time": 0.0},
|
||||
}
|
||||
|
||||
async def close(self):
    """Asynchronously shut down the analyzer's WebSocket connection.

    Marks the analyzer as closing, sets the connection-closed event, cancels
    the connection manager task if it is still running, and closes the
    WebSocket. ``self._ws`` is always reset to ``None`` afterwards.
    """
    self._closing = True
    self._connection_closed_event.set()

    async with self._ws_lock:
        # Cancel the connection manager and wait for it to unwind.
        if self._connection_task and not self._connection_task.done():
            self._connection_task.cancel()
            try:
                await self._connection_task
            except asyncio.CancelledError:
                pass

        # Close WebSocket
        if self._ws:
            try:
                await self._ws.close()
            except Exception as exc:
                # Best effort: a failing close must not break shutdown, but
                # cancellation and interpreter-exit signals (BaseException)
                # are no longer swallowed as they were by a bare ``except``.
                logger.debug(f"Ignoring error while closing WebSocket: {exc}")
            finally:
                self._ws = None
|
||||
112
api/services/storage.py
Normal file
112
api/services/storage.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
from loguru import logger
|
||||
|
||||
from api.constants import (
|
||||
ENABLE_AWS_S3,
|
||||
MINIO_ACCESS_KEY,
|
||||
MINIO_BUCKET,
|
||||
MINIO_ENDPOINT,
|
||||
MINIO_PUBLIC_ENDPOINT,
|
||||
MINIO_SECRET_KEY,
|
||||
MINIO_SECURE,
|
||||
S3_BUCKET,
|
||||
S3_REGION,
|
||||
)
|
||||
from api.enums import StorageBackend
|
||||
|
||||
from .filesystem import BaseFileSystem, MinioFileSystem, S3FileSystem
|
||||
|
||||
|
||||
def get_storage_for_backend(backend: str) -> BaseFileSystem:
    """Build the filesystem implementation for a storage backend code.

    The ``backend`` argument is compared against ``StorageBackend`` values:
    - Code 2 (MINIO): MinIO via ``MinioFileSystem`` (local/OSS deployments)
    - Code 1 (S3): AWS S3 via ``S3FileSystem`` (cloud deployments)

    For MinIO the public endpoint is ``MINIO_PUBLIC_ENDPOINT`` when set.
    Otherwise it is derived from ``MINIO_ENDPOINT``: a ``minio:`` or
    ``host.docker.internal:`` prefix is rewritten to ``localhost:``, and any
    other endpoint is used unchanged.

    Raises:
        ValueError: If the backend is unknown, or S3 is selected while
            ``S3_BUCKET`` is empty.
    """
    # Code 2: MinIO implementation (local/OSS deployments)
    if backend == StorageBackend.MINIO.value:
        endpoint = MINIO_ENDPOINT
        bucket = MINIO_BUCKET

        public_endpoint = MINIO_PUBLIC_ENDPOINT
        if not public_endpoint:
            # Default: the endpoint is already reachable from outside
            # (e.g. localhost:9000 or domain:9000).
            public_endpoint = endpoint
            if endpoint.startswith("minio:"):
                # Docker-internal service name: assume local development.
                public_endpoint = endpoint.replace("minio:", "localhost:")
                logger.info(
                    f"Auto-detected local development: using {public_endpoint} for public access"
                )
            elif endpoint.startswith("host.docker.internal:"):
                # Docker Desktop special DNS name: browsers need localhost.
                public_endpoint = endpoint.replace(
                    "host.docker.internal:", "localhost:"
                )
                logger.info(
                    f"Auto-detected Docker Desktop: using {public_endpoint} for public access"
                )

        logger.info(
            f"Initializing {backend} storage at {endpoint} (public: {public_endpoint}) with bucket '{bucket}'"
        )
        return MinioFileSystem(
            endpoint=endpoint,
            access_key=MINIO_ACCESS_KEY,
            secret_key=MINIO_SECRET_KEY,
            bucket_name=bucket,
            secure=MINIO_SECURE,
            public_endpoint=public_endpoint,
        )

    # Code 1: AWS S3 implementation (cloud deployments)
    if backend == StorageBackend.S3.value:
        if not S3_BUCKET:
            raise ValueError(
                "S3_BUCKET environment variable is required when using S3 storage"
            )
        bucket, region = S3_BUCKET, S3_REGION
        logger.info(
            f"Initializing {backend} storage with bucket '{bucket}' in region '{region}'"
        )
        return S3FileSystem(bucket, region)

    # Future backend implementations can be added here:
    # if backend == StorageBackend.GCS:  # Code 3
    #     return GoogleCloudFileSystem(...)
    # if backend == StorageBackend.AZURE:  # Code 4
    #     return AzureBlobFileSystem(...)

    raise ValueError(f"Unknown storage backend: {backend}")
|
||||
|
||||
|
||||
def get_current_storage_backend() -> StorageBackend:
    """Return whatever ``StorageBackend.get_current_backend()`` reports."""
    current = StorageBackend.get_current_backend()
    return current
|
||||
|
||||
|
||||
# Create a single storage instance at module load time.
# Importing this module therefore resolves the backend, logs the choice and
# constructs the filesystem object as a side effect; a ValueError raised by
# get_storage_for_backend() propagates out of the import.
_backend = StorageBackend.get_current_backend()
logger.info(
    f"Initializing storage backend: {_backend.name} (value: {_backend.value}, ENABLE_AWS_S3={ENABLE_AWS_S3})"
)
# Shared module-level instance handed out by get_storage().
storage_fs = get_storage_for_backend(_backend.value)
|
||||
|
||||
|
||||
# For backward compatibility, keep get_storage() function
|
||||
def get_storage() -> BaseFileSystem:
    """Return the shared storage instance created when this module loaded.

    Kept for backward compatibility.

    Deprecated: Use 'from api.services.storage import storage_fs' instead.
    """
    shared_instance = storage_fs
    return shared_instance
|
||||
0
api/services/telephony/__init__.py
Normal file
0
api/services/telephony/__init__.py
Normal file
765
api/services/telephony/ari_client.py
Normal file
765
api/services/telephony/ari_client.py
Normal file
|
|
@ -0,0 +1,765 @@
|
|||
"""
|
||||
Dynamic ARI client that generates methods from Swagger/OpenAPI specification.
|
||||
Pure asyncio implementation without anyio dependencies.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SwaggerMethod:
    """Callable wrapper around one operation of a Swagger API description.

    Keyword arguments passed to the call are sorted into path, query and body
    parameters according to the operation's ``parameters`` list, and the
    request is forwarded to the matching ``api_*`` coroutine of the client.
    """

    def __init__(
        self, client: "AsyncARIClient", path: str, method: str, operation: dict
    ):
        self.client = client
        self.path = path
        self.http_method = method.upper()
        self.operation = operation
        self.operation_id = operation.get("operationId", "")
        self.parameters = operation.get("parameters", [])
        self.description = operation.get("description", "")

    def _names_of(self, kind: str) -> list:
        """Return the names of declared parameters of ``kind``, in spec order.

        The location is read from ``paramType`` and falls back to ``in``
        when ``paramType`` is absent.
        """
        return [
            spec["name"]
            for spec in self.parameters
            if spec.get("paramType", spec.get("in")) == kind
        ]

    def _build_path(self, **kwargs) -> str:
        """Substitute supplied path parameters such as ``{channelId}``.

        Placeholders without a matching keyword argument are left untouched.
        """
        resolved = self.path
        for key in self._names_of("path"):
            if key not in kwargs:
                continue
            resolved = resolved.replace(f"{{{key}}}", str(kwargs[key]))
        return resolved

    def _build_params(self, **kwargs) -> dict:
        """Collect the supplied keyword arguments declared as query parameters."""
        return {key: kwargs[key] for key in self._names_of("query") if key in kwargs}

    def _build_body(self, **kwargs) -> dict:
        """Return the request body taken from the first supplied body parameter.

        A dict value is used as the whole body; any other value is wrapped as
        ``{name: value}``. An empty dict means no body parameter was supplied.
        """
        for key in self._names_of("body"):
            if key in kwargs:
                value = kwargs[key]
                if isinstance(value, dict):
                    return value
                return {key: value}
        return {}

    async def __call__(self, **kwargs):
        """Execute the API method and return the client's response.

        Raises:
            ValueError: If the HTTP method is not GET, POST, PUT or DELETE.
        """
        path = self._build_path(**kwargs)
        params = self._build_params(**kwargs)

        # Prefer a body parameter declared in the spec.
        payload = self._build_body(**kwargs)

        if not payload:
            # Backward compatibility: every keyword argument that is not a
            # declared path, query or body parameter becomes part of the body.
            declared = (
                set(self._names_of("path"))
                | set(self._names_of("query"))
                | set(self._names_of("body"))
            )
            payload = {
                key: value for key, value in kwargs.items() if key not in declared
            }

        # Debug logging for externalMedia
        if "externalMedia" in path:
            logger.debug(
                f"externalMedia call - method: {self.http_method}, path: {path}, params: {params}"
            )

        verb = self.http_method
        if verb == "GET":
            return await self.client.api_get(path, **params)
        if verb == "DELETE":
            return await self.client.api_delete(path, **params)
        if verb in ("POST", "PUT"):
            send = self.client.api_post if verb == "POST" else self.client.api_put
            # An empty body is sent as "no JSON" rather than "{}".
            return await send(path, json_data=payload or None, **params)
        raise ValueError(f"Unsupported HTTP method: {self.http_method}")
|
||||
|
||||
|
||||
class ResourceAPI:
|
||||
"""Represents a resource API (like channels, bridges, etc.)."""
|
||||
|
||||
def __init__(self, client: "AsyncARIClient", resource_name: str):
|
||||
self.client = client
|
||||
self.resource_name = resource_name
|
||||
self._methods = {}
|
||||
|
||||
def add_method(self, method_name: str, swagger_method: SwaggerMethod):
|
||||
"""Add a method to this resource."""
|
||||
self._methods[method_name] = swagger_method
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Dynamically return methods."""
|
||||
if name in self._methods:
|
||||
return self._methods[name]
|
||||
raise AttributeError(f"'{self.resource_name}' has no method '{name}'")
|
||||
|
||||
|
||||
@dataclass
class Channel:
    """Channel model built from an ARI channel payload.

    The convenience coroutines forward to the ``channels`` resource of the
    associated client and raise ``RuntimeError`` when no client is attached.
    """

    id: str
    name: str = ""
    state: str = ""
    caller: Dict[str, str] = field(default_factory=dict)
    connected: Dict[str, str] = field(default_factory=dict)
    accountcode: str = ""
    dialplan: Dict[str, str] = field(default_factory=dict)
    creationtime: str = ""
    language: str = "en"

    # Store reference to client for method calls (kept out of repr()).
    _client: Optional["AsyncARIClient"] = field(default=None, repr=False)

    @classmethod
    def from_dict(cls, data: dict, client=None) -> "Channel":
        """Create a Channel from an API response dict.

        Missing keys fall back to the field defaults (``language`` to "en").
        """
        return cls(
            id=data.get("id", ""),
            name=data.get("name", ""),
            state=data.get("state", ""),
            caller=data.get("caller", {}),
            connected=data.get("connected", {}),
            accountcode=data.get("accountcode", ""),
            dialplan=data.get("dialplan", {}),
            creationtime=data.get("creationtime", ""),
            language=data.get("language", "en"),
            _client=client,
        )

    def _require_client(self) -> "AsyncARIClient":
        """Return the attached client or raise if the channel is detached."""
        if not self._client:
            raise RuntimeError("Channel not associated with a client")
        return self._client

    async def continueInDialplan(
        self,
        context: Optional[str] = None,
        extension: Optional[str] = None,
        priority: Optional[int] = None,
        label: Optional[str] = None,
    ):
        """Continue the channel in the dialplan.

        Empty/``None`` values for ``context``, ``extension`` and ``label`` are
        omitted; ``priority`` is sent whenever it is not ``None`` (so 0 is
        kept).

        Raises:
            RuntimeError: If the channel has no client.
        """
        client = self._require_client()

        query = {}
        if context:
            query["context"] = context
        if extension:
            query["extension"] = extension
        if priority is not None:
            query["priority"] = priority
        if label:
            query["label"] = label

        # The ARI API method is named 'continueInDialplan'
        channels_api = client.channels
        if hasattr(channels_api, "continueInDialplan"):
            await channels_api.continueInDialplan(channelId=self.id, **query)
        else:
            # Fallback to a direct API call. The channel id is already part
            # of the path, so it is not repeated as a query parameter.
            await client.api_post(f"/channels/{self.id}/continue", **query)

    async def hangup(self, reason: str = "normal"):
        """Hang up the channel, passing ``reason`` to the API."""
        client = self._require_client()
        await client.channels.hangup(channelId=self.id, reason=reason)

    async def answer(self):
        """Answer the channel."""
        client = self._require_client()
        await client.channels.answer(channelId=self.id)

    async def getChannelVar(self, variable: str):
        """Return the API response for reading channel variable ``variable``."""
        client = self._require_client()
        return await client.channels.getChannelVar(
            channelId=self.id, variable=variable
        )
|
||||
|
||||
|
||||
@dataclass
class Bridge:
    """Bridge model built from an ARI bridge payload.

    The coroutines delegate to the ``bridges`` resource of the attached
    client and raise ``RuntimeError`` when the bridge has no client.
    """

    id: str
    technology: str = ""
    bridge_type: str = ""
    bridge_class: str = ""
    creator: str = ""
    name: str = ""
    channels: List[str] = field(default_factory=list)

    _client: Optional["AsyncARIClient"] = field(default=None, repr=False)

    @classmethod
    def from_dict(cls, data: dict, client=None) -> "Bridge":
        """Build a Bridge from an API response, defaulting missing keys."""
        text_fields = ("id", "technology", "bridge_type", "bridge_class", "creator", "name")
        values = {key: data.get(key, "") for key in text_fields}
        values["channels"] = data.get("channels", [])
        return cls(_client=client, **values)

    def _bridges_api(self):
        """Return the client's ``bridges`` resource, or raise when detached."""
        if self._client:
            return self._client.bridges
        raise RuntimeError("Bridge not associated with a client")

    async def addChannel(self, channel: str):
        """Add ``channel`` to this bridge."""
        bridges = self._bridges_api()
        await bridges.addChannel(bridgeId=self.id, channel=channel)

    async def removeChannel(self, channel: str):
        """Remove ``channel`` from this bridge."""
        bridges = self._bridges_api()
        await bridges.removeChannel(bridgeId=self.id, channel=channel)

    async def destroy(self):
        """Destroy this bridge."""
        bridges = self._bridges_api()
        await bridges.destroy(bridgeId=self.id)
|
||||
|
||||
|
||||
class AsyncARIClient:
|
||||
"""ARI client that dynamically generates methods from Swagger spec."""
|
||||
|
||||
def __init__(self, base_url: str, username: str, password: str, app: str):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.app = app
|
||||
|
||||
# REST API URL
|
||||
self.api_url = self.base_url.replace("ws://", "http://").replace(
|
||||
"wss://", "https://"
|
||||
)
|
||||
|
||||
# WebSocket URL
|
||||
self.ws_url = (
|
||||
f"{self.base_url}/ari/events?app={app}&api_key={username}:{password}"
|
||||
)
|
||||
|
||||
# Session and WebSocket
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._websocket: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self._running = False
|
||||
|
||||
# Event handling
|
||||
self._event_handlers: Dict[str, List[Callable]] = {}
|
||||
self._event_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
|
||||
|
||||
# Resource APIs (will be populated from Swagger)
|
||||
self.channels: Optional[ResourceAPI] = None
|
||||
self.bridges: Optional[ResourceAPI] = None
|
||||
self.endpoints: Optional[ResourceAPI] = None
|
||||
self.recordings: Optional[ResourceAPI] = None
|
||||
self.sounds: Optional[ResourceAPI] = None
|
||||
self.playbacks: Optional[ResourceAPI] = None
|
||||
self.asterisk: Optional[ResourceAPI] = None
|
||||
self.applications: Optional[ResourceAPI] = None
|
||||
self.deviceStates: Optional[ResourceAPI] = None
|
||||
self.mailboxes: Optional[ResourceAPI] = None
|
||||
|
||||
# Swagger spec cache
|
||||
self._swagger_spec: Optional[dict] = None
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to ARI and load Swagger spec."""
|
||||
# Create HTTP session
|
||||
auth = aiohttp.BasicAuth(self.username, self.password)
|
||||
self._session = aiohttp.ClientSession(auth=auth)
|
||||
|
||||
try:
|
||||
# Load Swagger spec and generate methods
|
||||
await self._load_swagger_spec()
|
||||
|
||||
# Connect WebSocket
|
||||
self._websocket = await self._session.ws_connect(
|
||||
self.ws_url, heartbeat=30, autoping=True
|
||||
)
|
||||
self._running = True
|
||||
logger.info(f"Connected to ARI at {self.ws_url}")
|
||||
|
||||
except Exception as e:
|
||||
await self._session.close()
|
||||
raise Exception(f"Failed to connect to ARI: {e}")
|
||||
|
||||
async def _load_swagger_spec(self):
|
||||
"""Load Swagger spec and generate dynamic methods."""
|
||||
spec_loaded = False
|
||||
try:
|
||||
# Get Swagger spec from ARI
|
||||
url = f"{self.api_url}/ari/api-docs/resources.json"
|
||||
async with self._session.get(url) as resp:
|
||||
resp.raise_for_status()
|
||||
resources = await resp.json()
|
||||
|
||||
# Store the spec
|
||||
self._swagger_spec = resources
|
||||
|
||||
# Create resource APIs
|
||||
for api_info in resources.get("apis", []):
|
||||
resource_path = api_info["path"]
|
||||
|
||||
# Fix the path - remove .{format} and add proper prefix
|
||||
resource_path = resource_path.replace(".{format}", ".json")
|
||||
|
||||
# Load detailed spec for this resource
|
||||
# The resource_path already contains /api-docs/, so we just need the base URL
|
||||
url = f"{self.api_url}/ari{resource_path}"
|
||||
try:
|
||||
async with self._session.get(url) as resp:
|
||||
resp.raise_for_status()
|
||||
spec = await resp.json()
|
||||
|
||||
self._process_swagger_spec(spec)
|
||||
spec_loaded = True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load spec for {resource_path}: {e}")
|
||||
|
||||
if spec_loaded:
|
||||
logger.info("Loaded Swagger spec and generated dynamic methods")
|
||||
else:
|
||||
raise Exception("No individual specs could be loaded")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load Swagger spec, using fallback methods: {e}")
|
||||
self._create_fallback_methods()
|
||||
|
||||
def _process_swagger_spec(self, spec: dict):
|
||||
"""Process a Swagger spec and create dynamic methods."""
|
||||
# basePath is available in spec but not currently used
|
||||
|
||||
for api in spec.get("apis", []):
|
||||
path = api["path"]
|
||||
|
||||
for operation in api.get("operations", []):
|
||||
self._create_method_from_operation(path, operation)
|
||||
|
||||
def _create_method_from_operation(self, path: str, operation: dict):
|
||||
"""Create a method from a Swagger operation."""
|
||||
# Swagger spec uses 'httpMethod' not 'method'
|
||||
method = operation.get("httpMethod", operation.get("method", "GET"))
|
||||
operation_id = operation.get("nickname", "")
|
||||
|
||||
if not operation_id:
|
||||
return
|
||||
|
||||
# Determine resource from path (e.g., /channels/{channelId} -> channels)
|
||||
path_parts = path.strip("/").split("/")
|
||||
if path_parts:
|
||||
resource_name = path_parts[0]
|
||||
|
||||
# Create resource API if it doesn't exist
|
||||
if not hasattr(self, resource_name) or getattr(self, resource_name) is None:
|
||||
setattr(self, resource_name, ResourceAPI(self, resource_name))
|
||||
|
||||
resource_api = getattr(self, resource_name)
|
||||
|
||||
# Extract method name from operation ID
|
||||
# e.g., "channels_continue" -> "continue_"
|
||||
# or "channels_get" -> "get"
|
||||
method_name = operation_id
|
||||
if method_name.startswith(resource_name + "_"):
|
||||
method_name = method_name[len(resource_name) + 1 :]
|
||||
|
||||
# Handle special cases
|
||||
if method_name == "continue":
|
||||
method_name = "continue_" # Avoid Python keyword
|
||||
|
||||
# Create and add the method
|
||||
swagger_method = SwaggerMethod(self, path, method, operation)
|
||||
resource_api.add_method(method_name, swagger_method)
|
||||
|
||||
def _create_fallback_methods(self):
|
||||
"""Create fallback methods if Swagger spec is not available."""
|
||||
# Create basic resource APIs
|
||||
self.channels = ResourceAPI(self, "channels")
|
||||
self.bridges = ResourceAPI(self, "bridges")
|
||||
|
||||
# Add essential channel methods
|
||||
self.channels.add_method(
|
||||
"get",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/channels/{channelId}",
|
||||
"GET",
|
||||
{
|
||||
"operationId": "get",
|
||||
"parameters": [{"name": "channelId", "in": "path"}],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.channels.add_method(
|
||||
"hangup",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/channels/{channelId}",
|
||||
"DELETE",
|
||||
{
|
||||
"operationId": "hangup",
|
||||
"parameters": [
|
||||
{"name": "channelId", "in": "path"},
|
||||
{"name": "reason", "in": "query"},
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.channels.add_method(
|
||||
"answer",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/channels/{channelId}/answer",
|
||||
"POST",
|
||||
{
|
||||
"operationId": "answer",
|
||||
"parameters": [{"name": "channelId", "in": "path"}],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.channels.add_method(
|
||||
"continueInDialplan",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/channels/{channelId}/continue",
|
||||
"POST",
|
||||
{
|
||||
"operationId": "continueInDialplan",
|
||||
"parameters": [
|
||||
{"name": "channelId", "in": "path"},
|
||||
{"name": "context", "in": "query"},
|
||||
{"name": "extension", "in": "query"},
|
||||
{"name": "priority", "in": "query"},
|
||||
{"name": "label", "in": "query"},
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.channels.add_method(
|
||||
"externalMedia",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/channels/externalMedia",
|
||||
"POST",
|
||||
{
|
||||
"operationId": "externalMedia",
|
||||
"parameters": [
|
||||
{"name": "channelId", "in": "query"}, # Add channelId parameter
|
||||
{"name": "app", "in": "query"},
|
||||
{"name": "external_host", "in": "query"},
|
||||
{"name": "format", "in": "query"},
|
||||
{"name": "encapsulation", "in": "query"},
|
||||
{"name": "transport", "in": "query"},
|
||||
{"name": "connection_type", "in": "query"},
|
||||
{"name": "direction", "in": "query"},
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.channels.add_method(
|
||||
"getChannelVar",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/channels/{channelId}/variable",
|
||||
"GET",
|
||||
{
|
||||
"operationId": "getChannelVar",
|
||||
"parameters": [
|
||||
{"name": "channelId", "in": "path"},
|
||||
{"name": "variable", "in": "query"},
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Add essential bridge methods
|
||||
self.bridges.add_method(
|
||||
"get",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/bridges/{bridgeId}",
|
||||
"GET",
|
||||
{
|
||||
"operationId": "get",
|
||||
"parameters": [{"name": "bridgeId", "in": "path"}],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.bridges.add_method(
|
||||
"create",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/bridges",
|
||||
"POST",
|
||||
{
|
||||
"operationId": "create",
|
||||
"parameters": [
|
||||
{"name": "type", "in": "query"},
|
||||
{"name": "name", "in": "query"},
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.bridges.add_method(
|
||||
"addChannel",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/bridges/{bridgeId}/addChannel",
|
||||
"POST",
|
||||
{
|
||||
"operationId": "addChannel",
|
||||
"parameters": [
|
||||
{"name": "bridgeId", "in": "path"},
|
||||
{"name": "channel", "in": "query"},
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.bridges.add_method(
|
||||
"removeChannel",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/bridges/{bridgeId}/removeChannel",
|
||||
"POST",
|
||||
{
|
||||
"operationId": "removeChannel",
|
||||
"parameters": [
|
||||
{"name": "bridgeId", "in": "path"},
|
||||
{"name": "channel", "in": "query"},
|
||||
],
|
||||
},
|
||||
),
|
||||
)
|
||||
self.bridges.add_method(
|
||||
"destroy",
|
||||
SwaggerMethod(
|
||||
self,
|
||||
"/bridges/{bridgeId}",
|
||||
"DELETE",
|
||||
{
|
||||
"operationId": "destroy",
|
||||
"parameters": [{"name": "bridgeId", "in": "path"}],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from ARI."""
|
||||
self._running = False
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
async def run(self):
|
||||
"""Main event loop."""
|
||||
if not self._websocket:
|
||||
raise RuntimeError("Not connected")
|
||||
|
||||
processor_task = asyncio.create_task(self._process_events())
|
||||
|
||||
try:
|
||||
async for msg in self._websocket:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
event = json.loads(msg.data)
|
||||
# Wrap channel/bridge objects
|
||||
if "channel" in event and isinstance(event["channel"], dict):
|
||||
event["channel"] = Channel.from_dict(event["channel"], self)
|
||||
if "bridge" in event and isinstance(event["bridge"], dict):
|
||||
event["bridge"] = Bridge.from_dict(event["bridge"], self)
|
||||
await self._event_queue.put(event)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON: {msg.data}")
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error: {self._websocket.exception()}")
|
||||
break
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
logger.info("WebSocket closed")
|
||||
break
|
||||
|
||||
finally:
|
||||
self._running = False
|
||||
processor_task.cancel()
|
||||
await asyncio.gather(processor_task, return_exceptions=True)
|
||||
|
||||
async def _process_events(self):
|
||||
"""Process events from queue."""
|
||||
while self._running:
|
||||
try:
|
||||
event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0)
|
||||
event_type = event.get("type")
|
||||
if event_type:
|
||||
await self._dispatch_event(event_type, event)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing event: {e}")
|
||||
|
||||
async def _dispatch_event(self, event_type: str, event: dict):
|
||||
"""Dispatch event to handlers."""
|
||||
handlers = self._event_handlers.get(event_type, [])
|
||||
if handlers:
|
||||
logger.debug(
|
||||
f"AsyncARIClient: Dispatching {event_type} to {len(handlers)} handlers"
|
||||
)
|
||||
for i, handler in enumerate(handlers):
|
||||
try:
|
||||
logger.debug(
|
||||
f" AsyncARIClient: Calling {event_type} handler {i + 1}/{len(handlers)}"
|
||||
)
|
||||
await handler(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Handler {i + 1} error for {event_type}: {e}")
|
||||
|
||||
def on_event(self, event_type: str, handler: Callable):
|
||||
"""Register event handler."""
|
||||
if event_type not in self._event_handlers:
|
||||
self._event_handlers[event_type] = []
|
||||
logger.debug(
|
||||
f"AsyncARIClient: Registering handler for {event_type}. Current count: {len(self._event_handlers.get(event_type, []))}"
|
||||
)
|
||||
self._event_handlers[event_type].append(handler)
|
||||
logger.debug(
|
||||
f"AsyncARIClient: After registration, {event_type} handler count: {len(self._event_handlers[event_type])}"
|
||||
)
|
||||
|
||||
# REST API methods
|
||||
async def api_get(self, path: str, **params) -> dict:
|
||||
"""GET request."""
|
||||
# Ensure path starts with /ari if not already
|
||||
if not path.startswith("/ari"):
|
||||
path = f"/ari{path}" if path.startswith("/") else f"/ari/{path}"
|
||||
url = urljoin(self.api_url, path.lstrip("/"))
|
||||
async with self._session.get(url, params=params) as resp:
|
||||
resp.raise_for_status()
|
||||
data = await resp.json()
|
||||
# Wrap known objects
|
||||
if isinstance(data, list):
|
||||
# Handle lists of channels/bridges
|
||||
if "/channels" in path:
|
||||
return [
|
||||
Channel.from_dict(item, self)
|
||||
if isinstance(item, dict)
|
||||
else item
|
||||
for item in data
|
||||
]
|
||||
elif "/bridges" in path:
|
||||
return [
|
||||
Bridge.from_dict(item, self) if isinstance(item, dict) else item
|
||||
for item in data
|
||||
]
|
||||
return data
|
||||
elif isinstance(data, dict):
|
||||
if "/channels/" in path and "id" in data:
|
||||
return Channel.from_dict(data, self)
|
||||
elif "/bridges/" in path and "id" in data:
|
||||
return Bridge.from_dict(data, self)
|
||||
return data
|
||||
|
||||
async def api_post(self, path: str, json_data: dict = None, **params) -> dict:
    """Send a POST request to the ARI REST API.

    ``path`` is prefixed with ``/ari`` unless it already starts with it;
    ``json_data`` is sent as the JSON body and ``params`` as the query string.
    Error statuses surface through ``resp.raise_for_status()``.

    Returns ``{}`` when the response reports no content length. Otherwise the
    decoded JSON is returned, wrapped in ``Channel`` when it has ``id`` and
    ``state`` keys or in ``Bridge`` when it has ``id`` and ``bridge_type``.
    """
    if not path.startswith("/ari"):
        separator = "" if path.startswith("/") else "/"
        path = f"/ari{separator}{path}"
    url = urljoin(self.api_url, path.lstrip("/"))

    async with self._session.post(url, json=json_data, params=params) as resp:
        resp.raise_for_status()

        body_size = resp.content_length
        if not body_size or body_size <= 0:
            # Missing or zero Content-Length: treat as an empty reply
            return {}

        payload = await resp.json()
        if "id" in payload:
            # Channel payloads are checked before bridge payloads
            if "state" in payload:
                return Channel.from_dict(payload, self)
            if "bridge_type" in payload:
                return Bridge.from_dict(payload, self)
        return payload
|
||||
|
||||
async def api_put(self, path: str, json_data: dict = None, **params) -> dict:
    """Send a PUT request to the ARI REST API.

    ``path`` is prefixed with ``/ari`` unless it already starts with it;
    ``json_data`` is sent as the JSON body and ``params`` as the query string.
    Error statuses surface through ``resp.raise_for_status()``.

    Returns the decoded JSON body, or ``{}`` when the response reports no
    content length.
    """
    if not path.startswith("/ari"):
        separator = "" if path.startswith("/") else "/"
        path = f"/ari{separator}{path}"
    url = urljoin(self.api_url, path.lstrip("/"))

    async with self._session.put(url, json=json_data, params=params) as resp:
        resp.raise_for_status()

        body_size = resp.content_length
        if not body_size or body_size <= 0:
            # Missing or zero Content-Length: treat as an empty reply
            return {}
        return await resp.json()
|
||||
|
||||
async def api_delete(self, path: str, **params) -> dict:
    """Send a DELETE request to the ARI REST API.

    ``path`` is prefixed with ``/ari`` unless it already starts with it;
    ``params`` are sent as the query string. Error statuses surface through
    ``resp.raise_for_status()``.

    Returns the decoded JSON body, or ``{}`` when the response reports no
    content length.
    """
    if not path.startswith("/ari"):
        separator = "" if path.startswith("/") else "/"
        path = f"/ari{separator}{path}"
    url = urljoin(self.api_url, path.lstrip("/"))

    async with self._session.delete(url, params=params) as resp:
        resp.raise_for_status()

        body_size = resp.content_length
        if not body_size or body_size <= 0:
            # Missing or zero Content-Length: treat as an empty reply
            return {}
        return await resp.json()
|
||||
446
api/services/telephony/ari_client_manager.py
Normal file
446
api/services/telephony/ari_client_manager.py
Normal file
|
|
@ -0,0 +1,446 @@
|
|||
"""
|
||||
ARI Client Manager using the new Async ARI Client.
|
||||
Drop-in replacement for the existing ari_client_manager.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.services.telephony.ari_client import AsyncARIClient, Channel
|
||||
from api.services.telephony.ari_client_singleton import ari_client_singleton
|
||||
|
||||
|
||||
class ARIClientManager:
    """Manages ARI client connection and event handling.

    This is a compatibility wrapper around AsyncARIClient: it subscribes to
    the client's ``StasisStart`` and ``StasisEnd`` events and fans them out to
    the start/end handlers registered on the manager.
    """

    # Column order of the pipe-delimited record returned by the endpoint
    # configured through the ARI_DATA_FETCHING_URI environment variable.
    _LEAD_FIELD_NAMES = (
        "status",
        "user",
        "vendor_lead_code",
        "source_id",
        "list_id",
        "gmt_offset_now",
        "phone_code",
        "phone_number",
        "title",
        "first_name",
        "middle_initial",
        "last_name",
        "address1",
        "address2",
        "address3",
        "city",
        "state",
        "province",
        "postal_code",
        "country_code",
        "gender",
        "date_of_birth",
        "alt_phone",
        "email",
        "security_phrase",
        "comments",
        "called_count",
        "last_local_call_time",
        "rank",
        "owner",
        "entry_list_id",
        "lead_id",
    )

    def __init__(
        self,
        ari_client: AsyncARIClient,
        app_endpoint: str,
        _conn_ctx=None,  # Not used with AsyncARIClient
    ):
        """Initialize the ARI client manager.

        Parameters
        ----------
        ari_client: AsyncARIClient
            The connected ARI client.
        app_endpoint: str
            The app endpoint for external media.
        _conn_ctx:
            Not used, kept for compatibility.
        """
        self._ari_client = ari_client
        self._app_endpoint = app_endpoint
        self._conn_ctx = _conn_ctx  # Not used but kept for compatibility
        self._start_handlers = []
        self._end_handlers = []
        self._running = False
        self._handlers_registered = False  # Track if handlers are registered

    def register_start_handler(
        self, handler: Callable[[Channel, dict], Awaitable[None]]
    ):
        """Register a handler for StasisStart events.

        The handler is awaited with ``(channel, call_context_vars)``.
        """
        logger.debug(
            f"Registering start handler. Current count: {len(self._start_handlers)}"
        )
        self._start_handlers.append(handler)
        logger.debug(f"After registration, handler count: {len(self._start_handlers)}")

    def register_end_handler(self, handler: Callable[[str], Awaitable[None]]):
        """Register a handler for StasisEnd events.

        The handler is awaited with the id of the channel that ended.
        """
        self._end_handlers.append(handler)

    async def update_client(self, new_client: AsyncARIClient, new_conn_ctx=None):
        """Update to a new client (for reconnection)."""
        logger.info("Updating ARI client for reconnection")
        self._ari_client = new_client
        self._conn_ctx = new_conn_ctx
        # Clear old event handlers from the client before re-registering
        # to prevent duplicate handler registrations
        if hasattr(new_client, "_event_handlers"):
            new_client._event_handlers.clear()
        # Re-register event handlers
        self._register_handlers()
        # Make sure a later run() does not register a second time
        self._handlers_registered = True

    def _register_handlers(self):
        """Register the StasisStart/StasisEnd callbacks with the client."""
        logger.debug(
            f"_register_handlers called. Start handlers count: {len(self._start_handlers)}, End handlers count: {len(self._end_handlers)}"
        )
        logger.debug("Registering StasisStart and StasisEnd with AsyncARIClient")
        self._ari_client.on_event("StasisStart", self._on_stasis_start)
        self._ari_client.on_event("StasisEnd", self._on_stasis_end)
        logger.debug("Event handlers registered with client")

    async def _on_stasis_start(self, event):
        """Handle StasisStart events."""
        channel = event.get("channel")
        if not channel:
            # Without a channel there is nothing to pass to the start
            # handlers; previously this path failed on ``channel.id``.
            logger.warning("StasisStart event without a channel - ignoring")
            return

        # Only handle PJSIP and SIP channels
        channel_name = getattr(channel, "name", None)
        if isinstance(channel_name, str) and not channel_name.startswith(
            ("PJSIP", "SIP")
        ):
            logger.debug(f"Ignoring StasisStart for non-SIP channel: {channel_name}")
            return

        logger.info(f"StasisStart event for channel: {channel.id}")

        call_context_vars = await self._load_call_context(channel)

        # Call all registered handlers with call_context_vars
        logger.debug(
            f"Calling {len(self._start_handlers)} start handlers for channel {channel.id}"
        )
        for i, handler in enumerate(self._start_handlers):
            try:
                logger.debug(
                    f"  Calling start handler {i + 1}/{len(self._start_handlers)}"
                )
                await handler(channel, call_context_vars)
            except Exception as e:
                # One failing handler must not stop the remaining ones
                logger.error(f"Error in StasisStart handler {i + 1}: {e}")

    async def _load_call_context(self, channel) -> dict:
        """Return the call context variables for ``channel`` (best effort).

        Reads the ``LOCAL_ARI_CALL_VARIABLES`` channel variable as JSON and,
        when it contains a ``phone`` entry and ``ARI_DATA_FETCHING_URI`` is
        set, merges in the record fetched for that phone number. Any failure
        is logged and whatever was collected so far is returned, so the start
        handlers always run.
        """
        call_context_vars = {}
        try:
            var_result = await channel.getChannelVar(
                variable="LOCAL_ARI_CALL_VARIABLES"
            )
            parsed = json.loads(var_result.get("value", "{}"))
            if isinstance(parsed, dict):
                # Handlers expect a dict; ignore any other JSON value
                call_context_vars = parsed

            # Try to get phone number and fetch additional data
            phone_number = call_context_vars.get("phone")
            ari_data_uri = os.getenv("ARI_DATA_FETCHING_URI")
            if phone_number and ari_data_uri:
                await self._merge_lead_details(
                    channel.id, phone_number, ari_data_uri, call_context_vars
                )

            logger.debug(
                f"channelID: {channel.id} call context variables: {call_context_vars}"
            )
        except Exception as e:
            # The lookup is optional: whatever the client raises for a missing
            # variable must not prevent the call from being handled.
            logger.debug(f"could not find variable LOCAL_ARI_CALL_VARIABLES: {e}")
        return call_context_vars

    async def _merge_lead_details(
        self,
        channel_id: str,
        phone_number: str,
        ari_data_uri: str,
        call_context_vars: dict,
    ) -> None:
        """Fetch the record for ``phone_number`` and merge it into the context.

        The URL is ``ari_data_uri`` with the phone number appended. The last
        non-empty line of the response is split on ``|`` and mapped
        positionally onto ``_LEAD_FIELD_NAMES``. Failures are logged, never
        raised.
        """
        start_time = time.time()
        try:
            fetch_url = f"{ari_data_uri}{phone_number}"
            async with httpx.AsyncClient() as client:
                response = await client.get(fetch_url, timeout=10.0)
                response.raise_for_status()

            # Parse the response - get the latest line if multiple lines
            response_text = response.text.strip()
            if not response_text:
                return
            latest_line = response_text.split("\n")[-1].strip()
            if not latest_line:
                return

            # Parse the pipe-delimited data
            fields = latest_line.split("|")
            for i, field_name in enumerate(self._LEAD_FIELD_NAMES):
                if i < len(fields):
                    call_context_vars[field_name] = fields[i]
                else:
                    # Record is shorter than expected; report each gap
                    logger.error(
                        f"channelID: {channel_id} IndexError while accessing fields {i}"
                    )

            elapsed_time = time.time() - start_time
            logger.info(
                f"channelID: {channel_id} Successfully fetched user details for phone: {phone_number} in {elapsed_time:.3f} seconds"
            )
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(
                f"channelID: {channel_id} Failed to fetch user details from ARI_DATA_FETCHING_URI after {elapsed_time:.3f} seconds: {e}"
            )

    async def _on_stasis_end(self, event):
        """Handle StasisEnd events.

        Unlike StasisStart, no PJSIP/SIP filter is applied here: every
        StasisEnd is forwarded to the end handlers.
        """
        # The channel may be a wrapped object, a plain dict, or missing
        channel = event.get("channel") or {}
        channel_id = channel.id if hasattr(channel, "id") else channel.get("id", "")

        logger.info(f"StasisEnd event for channel: {channel_id}")

        # Call all registered handlers
        for handler in self._end_handlers:
            try:
                await handler(channel_id)
            except Exception as e:
                logger.error(f"Error in StasisEnd handler: {e}")

    async def run(self):
        """Run the event loop.

        The actual WebSocket handling is done by AsyncARIClient.
        This just registers handlers and waits until cancelled or until
        ``_running`` is cleared.
        """
        logger.debug("Running ARIClientManager")
        self._running = True
        # Register handlers only once, on first run
        if not self._handlers_registered:
            self._register_handlers()
            self._handlers_registered = True

        try:
            # The AsyncARIClient.run() method handles WebSocket
            # We don't call it here as it's called by the supervisor
            while self._running:
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            logger.debug("ARIClientManager run cancelled")
            raise
        finally:
            self._running = False
|
||||
|
||||
|
||||
class _ARIClientManagerSupervisor:
    """Supervisor that maintains ARI connection with automatic reconnection.

    This replaces the asyncari-based supervisor with AsyncARIClient.
    """

    # Reconnection parameters
    _INITIAL_BACKOFF = 1  # Start with 1 second
    _MAX_BACKOFF = 60  # Max 60 seconds between retries

    def __init__(
        self,
        on_channel_start: Callable[[Channel, dict], Awaitable[None]],
        on_channel_end: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Store the callbacks that each (re)connection registers.

        ``on_channel_start`` becomes the manager's StasisStart handler and
        ``on_channel_end`` (optional) its StasisEnd handler.
        """
        self._on_channel_start = on_channel_start
        self._on_channel_end = on_channel_end
        self._shutting_down = False
        # Strong reference to the task running start(); the event loop only
        # keeps weak references, so an unreferenced task may be collected.
        self._background_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the supervisor and maintain connection."""
        await self._runner()

    async def stop(self):
        """Stop the supervisor.

        Sets the shutdown flag and, when the supervisor runs as a tracked
        background task, cancels that task and waits for it to finish so the
        reconnection loop does not keep running after ``stop()`` returns.
        """
        logger.info("Stopping ARI Client Manager Supervisor")
        self._shutting_down = True

        task = self._background_task
        if task is None or task.done() or task is asyncio.current_task():
            return
        task.cancel()
        # return_exceptions keeps the runner's own outcome from leaking here
        await asyncio.gather(task, return_exceptions=True)

    async def __aenter__(self):
        """Async context manager entry."""
        self._background_task = asyncio.create_task(self.start())
        return self

    async def __aexit__(self, *args):
        """Async context manager exit."""
        await self.stop()

    @staticmethod
    def _cancellation_requested() -> bool:
        """Return True when cancellation of the current task was requested.

        ``Task.cancelled()`` is only true once a task has finished, so it can
        never report a pending request from inside the task.
        ``Task.cancelling()`` (Python 3.11+) does; on older interpreters this
        returns False and only the shutdown flag is consulted.
        """
        task = asyncio.current_task()
        cancelling = getattr(task, "cancelling", None)
        return bool(cancelling is not None and cancelling())

    @staticmethod
    async def _cancel_tasks(tasks):
        """Cancel the unfinished ``tasks`` and wait for all of them."""
        for task in tasks:
            if not task.done():
                task.cancel()
        if tasks:
            # Also retrieves stored exceptions so none are left unobserved
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _runner(self):
        """Main reconnection loop using AsyncARIClient."""
        backoff = self._INITIAL_BACKOFF
        ari_client_manager: Optional[ARIClientManager] = None

        while not self._shutting_down:
            client = None
            tasks = []

            try:
                logger.debug("Going to connect with ARI")

                # Get configuration from environment
                base_url = os.getenv("ARI_STASIS_ENDPOINT")
                username = os.getenv("ARI_STASIS_USER")
                password = os.getenv("ARI_STASIS_USER_PASSWORD")
                app = os.getenv("ARI_STASIS_APP_NAME")

                if not base_url:
                    # Clearer than the AttributeError from None.replace()
                    raise RuntimeError("ARI_STASIS_ENDPOINT is not set")

                # Convert HTTP to WebSocket URL
                ws_url = base_url.replace("http://", "ws://").replace(
                    "https://", "wss://"
                )

                # Create and connect the AsyncARIClient
                client = AsyncARIClient(ws_url, username, password, app)
                await client.connect()

                # Update the singleton with the new client
                ari_client_singleton.set_client(client)

                if ari_client_manager is None:
                    # First connection - create new manager
                    logger.debug("Creating new ARIClientManager (first connection)")
                    ari_client_manager = ARIClientManager(
                        client,
                        os.getenv("ARI_STASIS_APP_ENDPOINT"),
                        _conn_ctx=None,  # Not needed with AsyncARIClient
                    )
                    logger.debug("Registering handlers with new manager")
                    ari_client_manager.register_start_handler(self._on_channel_start)
                    if self._on_channel_end:
                        ari_client_manager.register_end_handler(self._on_channel_end)
                else:
                    # Reconnection - update existing manager
                    logger.debug("Updating existing ARIClientManager (reconnection)")
                    # Don't re-register start and end handlers as they're already registered
                    await ari_client_manager.update_client(client, None)

                logger.info("Connected to ARI — supervisor entering event loop")

                # Reset backoff after successful connection
                backoff = self._INITIAL_BACKOFF

                # Create tasks for both the client and manager
                tasks.append(asyncio.create_task(client.run()))
                tasks.append(asyncio.create_task(ari_client_manager.run()))

                # Wait for either to complete (likely due to disconnection)
                done, _still_running = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for finished in done:
                    if not finished.cancelled() and finished.exception() is not None:
                        logger.warning(
                            f"ARI task ended with error: {finished.exception()!r}"
                        )

            except asyncio.CancelledError:
                # Check if we're shutting down
                if self._shutting_down or self._cancellation_requested():
                    logger.debug("ARI supervisor task cancelled — shutting down")
                    break

                # Otherwise it's a transient connection error
                logger.warning("ARI connection lost due to CancelledError — will retry")

            except Exception as exc:
                # Check if we're shutting down
                if self._shutting_down or self._cancellation_requested():
                    logger.warning("Exiting due to shutdown during exception handling")
                    break

                # Log and retry
                logger.warning(f"ARI connection failed or lost: {exc!r} - will retry")

            finally:
                # Stop whichever of the client/manager tasks is still running,
                # on every exit path, so none outlives this iteration
                await self._cancel_tasks(tasks)

                # Disconnect client if connected
                if client:
                    try:
                        await client.disconnect()
                    except Exception as e:
                        logger.warning(f"Error disconnecting client: {e}")
                    # Clear the singleton when disconnecting
                    ari_client_singleton.clear()

            # Check if we're shutting down before sleeping
            if self._shutting_down:
                logger.debug("Exiting reconnection loop due to shutdown")
                break

            # Exponential back-off with jitter before the next attempt
            jitter = random.uniform(0.1, backoff)
            logger.debug(f"Waiting {jitter:.1f} seconds before reconnecting...")
            await asyncio.sleep(jitter)

            logger.debug(f"Finished sleeping for {jitter} seconds")
            backoff = min(backoff * 2, self._MAX_BACKOFF)
            logger.debug(f"New backoff value: {backoff}, continuing loop...")
|
||||
|
||||
|
||||
async def setup_ari_client_supervisor(
    on_channel_start: Callable[[Channel, dict], Awaitable[None]],
    on_channel_end: Callable[[str], Awaitable[None]] | None = None,
) -> "_ARIClientManagerSupervisor | None":
    """Start a background supervisor that keeps the ARI connection alive.

    This is a drop-in replacement for the asyncari-based function.
    Uses AsyncARIClient instead of asyncari.

    If the *ENABLE_ARI_STASIS* environment variable is not set to ``"true"``
    (case-insensitive) the function returns ``None`` and no supervisor is
    launched.

    Otherwise the supervisor is started as a background task and returned;
    the task is kept on the supervisor as ``_background_task``.
    """

    if os.getenv("ENABLE_ARI_STASIS", "false").lower() != "true":
        logger.info("ARI Stasis integration disabled via environment variable")
        return None

    logger.info("Starting ARI Client Supervisor with AsyncARIClient")

    supervisor = _ARIClientManagerSupervisor(on_channel_start, on_channel_end)

    # Start the supervisor in the background. The event loop only holds weak
    # references to tasks, so keep a strong one on the returned supervisor to
    # prevent the task from being garbage-collected mid-flight.
    supervisor._background_task = asyncio.create_task(supervisor.start())

    return supervisor
|
||||
50
api/services/telephony/ari_client_singleton.py
Normal file
50
api/services/telephony/ari_client_singleton.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Singleton holder for the current ARI client instance.
|
||||
|
||||
This module provides a thread-safe singleton that holds the current
|
||||
ARI client instance, which can be updated during reconnections.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.telephony.ari_client import AsyncARIClient
|
||||
|
||||
|
||||
class ARIClientSingleton:
    """Process-wide holder for whichever ARI client is currently connected.

    Every construction returns the same object, so all importers observe the
    client most recently stored with ``set_client``.
    """

    _instance: Optional["ARIClientSingleton"] = None
    _client: Optional[AsyncARIClient] = None

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        shared = cls._instance
        if shared is not None:
            return shared
        shared = super().__new__(cls)
        cls._instance = shared
        return shared

    def set_client(self, client: AsyncARIClient) -> None:
        """Store ``client`` as the current ARI client.

        Args:
            client: The new ARI client instance.
        """
        self._client = client
        logger.info("ARI client singleton updated with new client instance")

    def get_client(self) -> Optional[AsyncARIClient]:
        """Return the stored ARI client.

        Returns:
            The current ARI client, or None if not set.
        """
        return self._client

    def clear(self) -> None:
        """Forget the stored client."""
        self._client = None
        logger.info("ARI client singleton cleared")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
ari_client_singleton = ARIClientSingleton()
|
||||
748
api/services/telephony/ari_manager.py
Normal file
748
api/services/telephony/ari_manager.py
Normal file
|
|
@ -0,0 +1,748 @@
|
|||
"""Standalone ARI Manager Service for distributed architecture.
|
||||
|
||||
This service maintains the single WebSocket connection to Asterisk ARI
|
||||
and distributes events to multiple FastAPI workers via Redis pub/sub.
|
||||
|
||||
ARIManager creates an instance of ARIClientSupervisor and registers the callbacks
|
||||
on_channel_start and on_channel_end. It is responsible for taking in caller_channel
|
||||
and setting up ARIManagerConnection, i.e. creating the bridge for externalMedia.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
|
||||
# --- Add logging setup before importing loguru ---
|
||||
from api.logging_config import setup_logging
|
||||
from api.services.telephony.stasis_event_protocol import (
|
||||
BaseWorkerToARIManagerCommand,
|
||||
DisconnectCommand,
|
||||
RedisChannels,
|
||||
RedisKeys,
|
||||
SocketClosedCommand,
|
||||
TransferCommand,
|
||||
parse_command,
|
||||
)
|
||||
|
||||
logging_queue_listener = setup_logging()
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
import redis.exceptions
|
||||
from loguru import logger
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.telephony.ari_client import Channel
|
||||
from api.services.telephony.ari_client_manager import (
|
||||
ARIClientManager,
|
||||
setup_ari_client_supervisor,
|
||||
)
|
||||
from api.services.telephony.ari_manager_connection import ARIManagerConnection
|
||||
|
||||
|
||||
class ARIManager:
|
||||
"""Manages ARI connection and distributes events to workers via Redis."""
|
||||
|
||||
def __init__(self, redis_client: aioredis.Redis):
    """Create the manager's bookkeeping around ``redis_client``.

    Nothing is connected or started here; all per-channel state starts
    empty.
    """
    self.redis = redis_client
    self.stasis_manager: Optional[ARIClientManager] = None
    self._running = False
    self._ari_client_supervisor = None

    # Per-channel bookkeeping, all keyed by channel ID
    self._tasks: Dict[str, asyncio.Task] = {}
    self._pubsubs: Dict[str, aioredis.client.PubSub] = {}  # pubsub connections
    self._active_channels: set[str] = set()  # channels managed by this instance

    # Candidate ports: even numbers only
    self._port_range = range(4000, 5000, 2)

    self._channel_connections: Dict[str, ARIManagerConnection] = {}
    self._channel_disposed: Dict[str, bool] = {}  # channel disposed state
    self._socket_closed: Dict[str, bool] = {}  # socket closed state

    # Worker routing
    self._active_workers: list[str] = []  # cached list of active workers
    self._worker_discovery_task: Optional[asyncio.Task] = None
    self._channel_to_worker: Dict[str, str] = {}  # channel -> worker
|
||||
|
||||
async def on_channel_start(self, caller_channel: Channel, call_context_vars: dict):
    """Set up a new caller channel reported by ARIClientManager.

    Allocates a port for the channel, builds an ``ARIManagerConnection`` on
    it, records the connection and its state flags, then hands over to
    ``_on_stasis_call``. If any of that raises, the error is logged and the
    channel's port is released.
    """
    channel_id = caller_channel.id
    try:
        # Port allocation is atomic, which prevents two calls sharing a port
        allocated_port = await self._get_and_allocate_port_atomic(channel_id)

        connection = ARIManagerConnection(
            caller_channel=caller_channel,
            host=os.getenv("ARI_STASIS_APP_ENDPOINT"),
            port=allocated_port,
        )

        # Remember the connection and reset both lifecycle flags
        self._channel_connections[channel_id] = connection
        self._channel_disposed[channel_id] = False
        self._socket_closed[channel_id] = False

        await self._on_stasis_call(connection, call_context_vars)
    except Exception as e:
        logger.exception(f"Error handling new channel {channel_id}: {e}")
        # Give the port back so it does not stay reserved
        await self._release_port_for_channel(channel_id)
|
||||
|
||||
async def on_channel_end(self, channel_id: str):
    """Handle channel end notification from ARIClientManager.

    For a tracked caller channel: publish a ``stasis_end`` event to the
    worker assigned to it, notify the connection, mark the channel disposed
    and trigger the cleanup check. StasisEnd for an ExternalMedia channel is
    logged and ignored; unknown channels are ignored silently.

    The publish is skipped when no worker was ever assigned to the channel,
    and a Redis failure while publishing is logged without preventing the
    local teardown that follows.
    """
    logger.info(f"channelID: {channel_id} Received channel end notification")

    connection = self._channel_connections.get(channel_id) if channel_id else None
    if not connection:
        # TODO: We are currently not handling StasisEnd on ExternalMedia
        if any(
            conn.em_channel_id and conn.em_channel_id == channel_id
            for conn in self._channel_connections.values()
        ):
            logger.debug(f"channelID: {channel_id} ExternalMedia StasisEnd - Ignoring")
        return

    # Publish StasisEnd event to worker immediately
    worker_id = self._get_worker_for_channel(channel_id)
    if worker_id:
        event = {
            "type": "stasis_end",
            "channel_id": channel_id,
            "reason": EndTaskReason.USER_HANGUP.value,
        }
        try:
            await self.redis.publish(
                RedisChannels.worker_events(worker_id), json.dumps(event)
            )
            logger.info(f"channelID: {channel_id} Published StasisEnd event")
        except redis.exceptions.RedisError as e:
            # Keep going: the connection still has to be torn down locally
            logger.error(
                f"channelID: {channel_id} Failed to publish StasisEnd event: {e}"
            )
    else:
        # An empty worker id means the call never reached a worker, so there
        # is no worker channel to publish to
        logger.warning(
            f"channelID: {channel_id} No worker assigned - skipping StasisEnd publish"
        )

    # Notify the connection about channel end
    await connection.notify_channel_end()

    # Mark channel as disposed
    if channel_id in self._channel_disposed:
        self._channel_disposed[channel_id] = True
        # Check if both flags are set to cleanup
        await self._check_and_cleanup_channel(channel_id)
|
||||
|
||||
async def _on_stasis_call(
    self, connection: ARIManagerConnection, call_context_vars: dict
):
    """Bring up ``connection`` and announce the call to a worker via Redis.

    After ``setup_call`` succeeds the channel is marked active, a
    ``stasis_start`` event describing the connection is published to the
    selected worker, and a task monitoring that channel's commands is
    started. Any exception is logged and swallowed.
    """
    try:
        # Creates the bridge and the external media channel
        await connection.setup_call()
        if not connection.is_connected():
            logger.warning("Connection is not connected, skipping")
            return

        # These identifiers are read only once the bridge exists
        channel_id = connection.caller_channel_id
        em_channel_id = connection.em_channel_id
        bridge_id = connection.bridge_id

        self._active_channels.add(channel_id)

        event = dict(
            type="stasis_start",
            channel_id=channel_id,
            caller_channel_id=channel_id,
            em_channel_id=em_channel_id,
            bridge_id=bridge_id,
            local_addr=list(connection.local_addr),
            remote_addr=list(connection.remote_addr),
            call_context_vars=call_context_vars,
        )

        worker_id = await self._select_worker()
        if worker_id is None:
            logger.error(f"channelID: {channel_id} No active workers available")
            await connection.disconnect()
            return

        # Remember which worker owns the channel for later events
        self._channel_to_worker[channel_id] = worker_id
        worker_channel = RedisChannels.worker_events(worker_id)

        await self.redis.publish(worker_channel, json.dumps(event))
        logger.info(
            f"channelID: {channel_id} Published stasis_start event to worker {worker_id}"
        )

        # Listen for commands the worker sends back for this channel
        monitor = asyncio.create_task(
            self._monitor_channel_commands(channel_id, connection)
        )
        self._tasks[channel_id] = monitor
    except Exception as e:
        logger.exception(f"Error handling stasis call: {e}")
|
||||
|
||||
async def _get_and_allocate_port_atomic(self, channel_id: str) -> int:
    """Reserve a port for ``channel_id`` through a single Redis Lua script.

    The script returns the port already recorded for the channel if there is
    one, otherwise the first port of ``self._port_range`` with no entry in
    ``port_channels``, recording it in ``channel_ports``, ``port_channels``
    and ``channel_allocation_time``. When no port is free, orphaned ports
    are cleaned up and the script is run once more.

    Raises:
        RuntimeError: if no port is available even after the cleanup.
    """
    # Lua script for atomic port allocation
    lua_script = """
    local port_range_start = tonumber(ARGV[1])
    local port_range_end = tonumber(ARGV[2])
    local port_range_step = tonumber(ARGV[3])
    local channel_id = KEYS[1]
    local timestamp = ARGV[4]

    -- Check if channel already has a port allocated
    local existing_port = redis.call('HGET', 'channel_ports', channel_id)
    if existing_port then
        return tonumber(existing_port)
    end

    -- Find first available port
    for port = port_range_start, port_range_end, port_range_step do
        local port_str = tostring(port)
        local exists = redis.call('HEXISTS', 'port_channels', port_str)
        if exists == 0 then
            -- Atomically allocate the port
            redis.call('HSET', 'channel_ports', channel_id, port)
            redis.call('HSET', 'port_channels', port_str, channel_id)
            redis.call('HSET', 'channel_allocation_time', channel_id, timestamp)
            return port
        end
    end

    return -1  -- No ports available
    """

    port_start = min(self._port_range)
    port_end = max(self._port_range)
    port_step = self._port_range.step
    timestamp = int(time.time())

    async def run_script():
        # One key (the channel id) followed by the four ARGV values
        return await self.redis.eval(
            lua_script,
            1,
            channel_id,
            port_start,
            port_end,
            port_step,
            timestamp,
        )

    port = await run_script()
    if port == -1:
        # Range exhausted: reclaim orphaned ports, then try exactly once more
        await self._cleanup_orphaned_ports()
        port = await run_script()
        if port == -1:
            raise RuntimeError("No available ports in configured range after cleanup")

    logger.debug(f"Atomically allocated port {port} for channel {channel_id}")
    return port
|
||||
|
||||
async def _release_port_for_channel(self, channel_id: str):
    """Give back the port recorded for ``channel_id``, if any.

    A single Lua script looks up the channel's port and removes the matching
    entries from ``channel_ports``, ``port_channels`` and
    ``channel_allocation_time`` together, so a release is never partial.
    """
    lua_script = """
    local channel_id = KEYS[1]

    -- Get the port allocated to this channel
    local port = redis.call('HGET', 'channel_ports', channel_id)

    if port then
        -- Atomically clean up all related entries
        redis.call('HDEL', 'channel_ports', channel_id)
        redis.call('HDEL', 'port_channels', port)
        redis.call('HDEL', 'channel_allocation_time', channel_id)
        return port
    end

    return nil
    """

    released = await self.redis.eval(lua_script, 1, channel_id)

    # A falsy reply means the channel had no port on record
    if not released:
        logger.debug(f"No port was allocated for channel {channel_id}")
        return
    logger.debug(f"Atomically released port {released} for channel {channel_id}")
|
||||
|
||||
async def _discover_workers(self):
    """Refresh ``self._active_workers`` from Redis every five seconds.

    Runs while ``self._running`` is true. A worker counts as active when its
    ``worker_active`` entry exists, parses as JSON and reports the status
    ``"ready"``. An error during a refresh is logged and the previous list
    is kept; cancelling the task ends the loop quietly.
    """
    try:
        while self._running:
            try:
                ready_workers = []
                for raw_id in await self.redis.smembers(RedisKeys.workers_set()):
                    worker_id = raw_id.decode() if isinstance(raw_id, bytes) else raw_id
                    worker_data = await self.redis.get(
                        RedisKeys.worker_active(worker_id)
                    )
                    if not worker_data:
                        # No entry stored for this worker
                        continue

                    try:
                        info = json.loads(worker_data)
                    except json.JSONDecodeError:
                        logger.warning(f"Invalid worker data for {worker_id}")
                        continue

                    # Only include workers that are ready (not draining)
                    if info.get("status") == "ready":
                        ready_workers.append(worker_id)

                # Swap the whole list in one assignment
                self._active_workers = ready_workers
                logger.info(f"Discovered {len(ready_workers)} active workers")
            except Exception as e:
                logger.error(f"Error discovering workers: {e}")

            # Check every 5 seconds
            await asyncio.sleep(5)
    except asyncio.CancelledError:
        logger.debug("Worker discovery task cancelled")
|
||||
|
||||
async def _select_worker(self) -> Optional[str]:
    """Pick the next worker in round-robin order.

    Returns:
        A worker ID from ``self._active_workers``, or ``None`` when no
        worker is currently known.
    """
    if not self._active_workers:
        return None

    try:
        # The counter lives in Redis so the rotation survives restarts
        ticket = await self.redis.incr(RedisKeys.round_robin_index())
        candidates = self._active_workers
        return candidates[(ticket - 1) % len(candidates)]
    except Exception as e:
        logger.error(f"Error selecting worker: {e}")
        # Fallback to first worker if Redis operation fails
        if self._active_workers:
            return self._active_workers[0]
        return None
|
||||
|
||||
def _get_worker_for_channel(self, channel_id: str) -> str:
    """Return the worker ID recorded for ``channel_id``.

    Returns:
        The assigned worker ID, or an empty string when the channel has no
        entry in ``self._channel_to_worker``.
    """
    assignments = self._channel_to_worker
    return assignments.get(channel_id, "")
|
||||
|
||||
async def _monitor_channel_commands(
    self, channel_id: str, connection: ARIManagerConnection
):
    """Subscribe to this channel's command topic and dispatch each command.

    The pubsub handle is stored in ``self._pubsubs`` so that cleanup code
    can close it. Every ``message`` payload is parsed with
    ``parse_command`` and handed to ``_handle_worker_command``; a failure
    while handling one message is logged and the loop continues.

    Args:
        channel_id: Caller channel whose commands are monitored.
        connection: Connection the commands are executed against.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """
    # TODO: Not sure if it is a good idea to monitor commands for every
    # channel using pubsub. What happens if there are more calls than the
    # number of TCP connections Redis can support? We could do something
    # similar to Campaign Orchestrator: subscribe to one channel and carry
    # the commands for every call there.
    command_channel = RedisChannels.channel_commands(channel_id)

    try:
        pubsub = self.redis.pubsub()
        await pubsub.subscribe(command_channel)

        # Store the pubsub connection for cleanup
        self._pubsubs[channel_id] = pubsub

        logger.debug(f"channelID: {channel_id} Monitoring commands for channel")

        async for message in pubsub.listen():
            # Subscription bookkeeping messages carry no command
            if message["type"] != "message":
                continue

            try:
                command = parse_command(message["data"])
                if not command:
                    logger.warning(
                        f"Failed to parse command for {channel_id}: {message['data']}"
                    )
                    continue

                await self._handle_worker_command(channel_id, command, connection)
            except Exception as e:
                logger.exception(f"Error handling command for {channel_id}: {e}")

    except asyncio.CancelledError:
        logger.debug(f"channelID: {channel_id} Command monitor cancelled")
        raise  # Re-raise to maintain proper cancellation semantics
    except (ConnectionError, redis.exceptions.ConnectionError):
        # We close the pubsub before cancelling the task. So, the code
        # flow will arrive here
        pass
    except Exception as e:
        logger.exception(f"Error in command monitor for {channel_id}: {e}")
|
||||
|
||||
async def _handle_worker_command(
    self,
    channel_id: str,
    command: BaseWorkerToARIManagerCommand,
    connection: ARIManagerConnection,
):
    """Execute one command received from a worker.

    Args:
        channel_id: Channel the command was published for.
        command: Parsed command; anything other than a disconnect,
            transfer or socket-closed command is only logged.
        connection: Connection used to carry out disconnects and transfers.
    """
    if isinstance(command, DisconnectCommand):
        logger.info(
            f"channelID: {channel_id} Worker requested disconnect: {command.reason}"
        )
        await connection.disconnect(command.reason)
        return

    if isinstance(command, TransferCommand):
        logger.info(f"channelID: {channel_id} Worker requested transfer")
        await connection.transfer(command.context)
        return

    if isinstance(command, SocketClosedCommand):
        logger.info(f"channelID: {channel_id} Worker notified socket closed")

        # Mark socket as closed (only for channels that are being tracked)
        if channel_id in self._socket_closed:
            self._socket_closed[channel_id] = True

        # Release port immediately
        await self._release_port_for_channel(channel_id)

        # Check if both flags are set to cleanup
        await self._check_and_cleanup_channel(channel_id)
        return

    logger.warning(f"channelID: {channel_id} Received unknown command: {command}")
|
||||
|
||||
async def _check_and_cleanup_channel(self, channel_id: str):
    """Release all per-channel state once the call is fully finished.

    Cleanup only happens when both the "channel disposed" and the
    "socket closed" flags are set for ``channel_id``. It then drops the
    channel from the tracking collections, closes its pubsub, stops the
    command monitor task and removes the flags.

    This coroutine is also reached from inside the monitor task itself
    (``_monitor_channel_commands`` -> ``_handle_worker_command``). In that
    case the task must not cancel and await itself; it is left to finish
    on its own once its pubsub has been closed.
    """
    channel_disposed = self._channel_disposed.get(channel_id, False)
    socket_closed = self._socket_closed.get(channel_id, False)

    logger.debug(
        f"channelID: {channel_id} Check cleanup - disposed: {channel_disposed}, socket_closed: {socket_closed}"
    )

    if not (channel_disposed and socket_closed):
        return

    # Remove from active channels and connections
    self._active_channels.discard(channel_id)
    self._channel_connections.pop(channel_id, None)

    # Close pubsub connection first (before cancelling task)
    if channel_id in self._pubsubs:
        pubsub = self._pubsubs[channel_id]
        try:
            command_channel = RedisChannels.channel_commands(channel_id)
            await pubsub.unsubscribe(command_channel)
            await pubsub.aclose()
            logger.debug(
                f"channelID: {channel_id} Closed pubsub connection in cleanup"
            )
        except Exception as e:
            logger.warning(f"Error closing pubsub for {channel_id}: {e}")
        finally:
            # pop() tolerates the entry having been removed while we awaited
            self._pubsubs.pop(channel_id, None)

    # Cancel command monitor task
    if channel_id in self._tasks:
        task = self._tasks[channel_id]

        if task is asyncio.current_task():
            # We are running inside the monitor task: cancelling and
            # awaiting ourselves would only inject a CancelledError that
            # has to be swallowed again. The monitor loop ends by itself
            # now that its pubsub is closed.
            logger.debug(
                f"channelID: {channel_id} Cleanup running inside monitor task, not cancelling it"
            )
        elif not task.done():
            # Task is still running, cancel it
            task.cancel()
            try:
                # Wait for task to complete
                await task
                logger.debug(f"channelID: {channel_id} Task completed after cancel")
            except asyncio.CancelledError:
                logger.debug(f"channelID: {channel_id} Task cancelled successfully")
            except Exception as e:
                logger.warning(f"channelID: {channel_id} Task raised exception: {e}")
        else:
            # Task already completed
            logger.debug(f"channelID: {channel_id} Monitor task already completed")
            try:
                # Still await to get any exception that might have occurred
                await task
            except Exception as e:
                logger.warning(
                    f"channelID: {channel_id} Completed task had exception: {e}"
                )

        self._tasks.pop(channel_id, None)

    # Clean up the flag tracking
    self._channel_disposed.pop(channel_id, None)
    self._socket_closed.pop(channel_id, None)

    logger.info(f"channelID: {channel_id} Completed cleanup of all resources")
|
||||
|
||||
async def _cleanup_orphaned_ports(self):
    """Release port allocations that no live channel owns any more.

    With no active channels (the state at startup) every allocation found
    in the ``channel_ports`` hash is released. Otherwise only channels that
    this instance does not track are considered, and only when their
    recorded allocation time is old enough: more than an hour is reported
    as stale, more than five minutes as orphaned. Untracked channels with
    no recorded allocation time are left alone. Any error is logged, not
    raised.
    """
    try:
        # Get all channel-port mappings
        channel_ports = await self.redis.hgetall("channel_ports")
        if not channel_ports:
            return

        logger.info(
            f"Found {len(channel_ports)} existing port allocations, checking for orphans..."
        )

        now = int(time.time())
        stale_after_seconds = 3600  # 1 hour
        orphan_after_seconds = 300  # 5 minutes
        released = 0

        # On startup, we can safely assume any existing allocations are
        # orphaned since this is a fresh instance with no active channels yet
        nothing_tracked = not self._active_channels

        for channel_id, port in channel_ports.items():
            # During runtime, channels tracked by this instance are in use
            if not nothing_tracked and channel_id in self._active_channels:
                continue

            allocation_time = await self.redis.hget(
                "channel_allocation_time", channel_id
            )
            age = now - int(allocation_time) if allocation_time else None

            if nothing_tracked:
                age_str = f" (aged {age}s)" if age is not None else ""
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Cleaned up orphaned port {port} for channel {channel_id}{age_str}"
                )
                released += 1
            elif age is None:
                # No allocation time recorded; cannot judge its age
                continue
            elif age > stale_after_seconds:
                # Too old, clean up regardless
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Cleaned up stale port {port} for channel {channel_id} (aged {age}s)"
                )
                released += 1
            elif age > orphan_after_seconds:
                # Not tracked by this instance, might be orphaned.
                # For safety, only clean up if reasonably old (5 minutes)
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Cleaned up orphaned port {port} for untracked channel {channel_id}"
                )
                released += 1

        if released > 0:
            logger.info(f"Cleaned up {released} orphaned port allocations")

    except Exception as e:
        logger.exception(f"Error during orphaned port cleanup: {e}")
|
||||
|
||||
async def _periodic_cleanup(self):
    """Run the orphaned-port cleanup every 30 minutes while the manager runs.

    The loop ends when ``self._running`` turns false or the task is
    cancelled; any other error is logged and the loop carries on.
    """
    interval_seconds = 1800  # 30 minutes

    while self._running:
        try:
            await asyncio.sleep(interval_seconds)
            # The manager may have been stopped while we slept
            if not self._running:
                continue
            logger.info("Running periodic orphaned port cleanup...")
            await self._cleanup_orphaned_ports()
        except asyncio.CancelledError:
            logger.debug("Periodic cleanup task cancelled")
            break
        except Exception as e:
            logger.exception(f"Error in periodic cleanup: {e}")
|
||||
|
||||
async def run(self):
    """Main run loop for ARI Manager.

    Sets up the ARI client supervisor, starts worker discovery, releases
    ports left over from earlier runs, starts the periodic cleanup task and
    then idles until ``self._running`` is cleared. The periodic cleanup
    task and the supervisor are always shut down on the way out, including
    when this coroutine is cancelled or fails.
    """
    self._running = True
    cleanup_task = None

    # Setup ARI connection with supervisor
    try:
        self._ari_client_supervisor = await setup_ari_client_supervisor(
            self.on_channel_start, self.on_channel_end
        )
        if not self._ari_client_supervisor:
            logger.error("Failed to setup ARI connection")
            return

        # Start worker discovery task
        self._worker_discovery_task = asyncio.create_task(self._discover_workers())

        # Wait a moment for initial worker discovery
        await asyncio.sleep(1)

        logger.info(
            f"ARI Manager started with {len(self._active_workers)} active workers"
        )

        # Clean up any orphaned ports from previous runs
        await self._cleanup_orphaned_ports()

        # Start periodic cleanup task
        cleanup_task = asyncio.create_task(self._periodic_cleanup())

        # Keep running until shutdown
        while self._running:
            await asyncio.sleep(1)

        logger.debug("ARIManager._running is false. Will cleanup and shutdown")

    except Exception as e:
        logger.exception(f"ARI Manager error: {e}")
    finally:
        # Stop the periodic cleanup here rather than on the happy path only,
        # so it does not outlive a cancelled or failed run loop.
        if cleanup_task is not None:
            cleanup_task.cancel()
            try:
                await cleanup_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.exception(f"Periodic cleanup task failed: {e}")

        if self._ari_client_supervisor:
            await self._ari_client_supervisor.close()
        logger.info("ARI Manager stopped")
|
||||
|
||||
async def shutdown(self):
    """Stop the manager and release what it holds.

    Order of operations: close the ARI supervisor, stop worker discovery,
    clear the running flag, release the port of every active channel,
    drop the per-channel flags, then cancel all monitor tasks and wait for
    them to finish.
    """
    logger.info("Shutting down ARI Manager...")

    # Close supervisor first to prevent reconnection attempts
    if self._ari_client_supervisor:
        await self._ari_client_supervisor.close()

    # Cancel worker discovery task
    discovery_task = self._worker_discovery_task
    if discovery_task:
        discovery_task.cancel()
        try:
            await discovery_task
        except asyncio.CancelledError:
            pass
        self._worker_discovery_task = None

    # Now set running to False
    self._running = False

    # Clean up all active channel ports before shutting down
    if self._active_channels:
        logger.info(f"Cleaning up {len(self._active_channels)} active channels...")
        # Iterate over a snapshot so the set may change underneath us
        for channel_id in tuple(self._active_channels):
            await self._release_port_for_channel(channel_id)
            logger.info(
                f"Released port for active channel {channel_id} during shutdown"
            )
        self._active_channels.clear()

    # Clear flag tracking
    self._channel_disposed.clear()
    self._socket_closed.clear()

    # Cancel all monitoring tasks, then wait for them to complete
    monitor_tasks = list(self._tasks.values())
    for monitor_task in monitor_tasks:
        monitor_task.cancel()
    if monitor_tasks:
        await asyncio.gather(*monitor_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
async def main():
    """Main entry point for ARI Manager service.

    Connects to Redis, runs an ``ARIManager`` and waits until either the
    manager stops on its own or SIGTERM/SIGINT arrives, in which case the
    manager is shut down gracefully. The Redis connection and the logging
    queue listener are always closed at the end.
    """
    # Setup Redis connection (named so it does not shadow the redis module)
    redis_client = await aioredis.from_url(REDIS_URL, decode_responses=True)

    # Create and run manager
    manager = ARIManager(redis_client)

    # Create a shutdown event for clean coordination
    shutdown_event = asyncio.Event()

    # Setup signal handlers on the loop that is running this coroutine
    loop = asyncio.get_running_loop()

    def signal_handler(signum):
        logger.info(f"Received shutdown signal {signum}")
        # Set the shutdown event which will trigger shutdown
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        # Extra positional arguments are passed through to the callback
        loop.add_signal_handler(sig, signal_handler, sig)

    # Run manager with shutdown monitoring
    manager_task = asyncio.create_task(manager.run())
    shutdown_task = asyncio.create_task(shutdown_event.wait())

    try:
        # Wait for either normal completion or shutdown signal
        done, pending = await asyncio.wait(
            [manager_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED
        )

        # If shutdown was triggered, perform graceful shutdown
        if shutdown_task in done:
            await manager.shutdown()
            # Cancel the manager task if still running
            if manager_task in pending:
                manager_task.cancel()
                try:
                    await manager_task
                except asyncio.CancelledError:
                    pass
    finally:
        # The waiter is still pending when the manager stopped by itself;
        # cancel it so no task is left dangling.
        if not shutdown_task.done():
            shutdown_task.cancel()
            await asyncio.gather(shutdown_task, return_exceptions=True)

        await redis_client.aclose()
        # --- Ensure Axiom logging listener is stopped gracefully ---
        if logging_queue_listener is not None:
            logging_queue_listener.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Configure logging: besides the existing sinks, write to a log file
    # under logs/ with a "10 MB" rotation setting.
    logger.add("logs/ari_manager.log", rotation="10 MB")
    # Run the service until main() returns
    asyncio.run(main())
|
||||
323
api/services/telephony/ari_manager_connection.py
Normal file
323
api/services/telephony/ari_manager_connection.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
"""ARI-specific Stasis connection for use by ARI Manager.
|
||||
|
||||
This connection has direct access to the ARI client and manages
|
||||
the actual Asterisk channels, bridges, and RTP setup.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
from api.services.telephony.ari_client import AsyncARIClient, Bridge, Channel
|
||||
from api.services.telephony.ari_client_singleton import ari_client_singleton
|
||||
|
||||
|
||||
class ARIManagerConnection(BaseObject):
    """ARI Manager's connection that directly controls Asterisk resources.

    This class is used only by the ARI Manager process and has full
    access to the ARI client for creating bridges, channels, etc.

    Only channel and bridge IDs are kept on the instance; the matching
    objects are fetched from the current ARI client whenever they are
    needed, so a reconnect does not leave stale references behind.
    """

    def __init__(
        self,
        caller_channel: Channel,
        host: str,
        port: int,
    ) -> None:
        """Initialize ARI Stasis connection.

        Args:
            caller_channel: The caller's channel object.
            host: Host address for RTP transport.
            port: Port number for RTP transport.
        """
        super().__init__()

        # External dependencies.
        self._host: str = host
        self._port: int = port

        # Store channel IDs instead of Channel objects to avoid stale references
        self.caller_channel_id: str = caller_channel.id
        self.em_channel_id: Optional[str] = None  # externalMedia channel ID

        # Store bridge ID to avoid stale references after reconnection
        self.bridge_id: Optional[str] = None

        # RTP addressing information
        self.local_addr = ("0.0.0.0", port)
        self.remote_addr = None  # filled in by _setup_call

        # Internal state.
        self._closed: bool = False
        self._is_connected: bool = False

    def is_connected(self) -> bool:
        """Check if the connection is established and not yet closed."""
        return self._is_connected and not self._closed

    @property
    def _ari(self) -> Optional[AsyncARIClient]:
        """Get the current ARI client from singleton."""
        return ari_client_singleton.get_client()

    def _live_client(self, resource: str) -> Optional[AsyncARIClient]:
        """Return the current ARI client if its session is open, else None.

        Args:
            resource: Description used in the warning, e.g. ``"channel <id>"``.
        """
        client = self._ari
        if not client:
            logger.warning(f"Cannot get {resource} - No ARI client available")
            return None
        # NOTE: relies on the client's private ``_session`` attribute
        if not client._session or client._session.closed:
            logger.warning(f"Cannot get {resource} - ARI session is closed")
            return None
        return client

    async def _get_channel(self, channel_id: str) -> Optional[Channel]:
        """Safely get a channel object by ID.

        Returns None if the channel doesn't exist or can't be fetched.
        """
        if not channel_id:
            return None
        try:
            client = self._live_client(f"channel {channel_id}")
            if client is None:
                return None
            return await client.channels.get(channelId=channel_id)
        except Exception as e:
            logger.warning(f"Could not get channel {channel_id} - {e}")
            return None

    async def _get_bridge(self, bridge_id: str) -> Optional[Bridge]:
        """Safely get a bridge object by ID.

        Returns None if the bridge doesn't exist or can't be fetched.
        """
        if not bridge_id:
            return None
        try:
            client = self._live_client(f"bridge {bridge_id}")
            if client is None:
                return None
            return await client.bridges.get(bridgeId=bridge_id)
        except Exception as e:
            logger.warning(f"Could not get bridge {bridge_id}: {e}")
            return None

    async def _cleanup_resources(self):
        """Clean up external media channel and bridge.

        Each step is best-effort: a failure is logged and the stored ID is
        kept, the other step still runs.
        """
        # Cleanup external media channel
        try:
            if self.em_channel_id:
                em_channel = await self._get_channel(self.em_channel_id)
                if em_channel:
                    await em_channel.hangup()
                    logger.debug(
                        f"channelID: {self.em_channel_id} Hung up external media"
                    )
                self.em_channel_id = None
        except Exception as exc:
            logger.warning(
                f"Failed to hang-up externalMedia channel: {self.em_channel_id} "
                f"Error: {exc}"
            )

        # Cleanup bridge
        try:
            if self.bridge_id:
                bridge = await self._get_bridge(self.bridge_id)
                if bridge:
                    await bridge.destroy()
                    logger.debug(f"bridgeID: {self.bridge_id} Destroyed bridge")
                self.bridge_id = None
        except Exception as exc:
            logger.warning(f"Failed to destroy bridge: {self.bridge_id} Error: {exc}")

    async def _sync_call_data(self, call_transfer_context: dict):
        """Sync call data to ARI_DATA_SYNCING_URI.

        Does nothing when the environment variable is unset. Otherwise the
        lead ID, disposition and a summary comment are appended to the
        configured URL as query parameters and sent with a POST request.
        Failures are logged, never raised.

        Args:
            call_transfer_context: Call variables, for example
                ``{'lead_id': '299154', 'disposition': 'VM', 'employment':
                'N/A', 'debts': 'N/A', 'time': '2025-08-07T13:16:02-04:00'}``.
        """
        # Imported locally so this class does not depend on a module-level import
        from urllib.parse import urlencode

        ari_data_uri = os.getenv("ARI_DATA_SYNCING_URI")
        if not ari_data_uri:
            return

        lead_id = call_transfer_context.get("lead_id")
        status = call_transfer_context.get("disposition")

        full_name = call_transfer_context.get("full_name", "")
        phone = call_transfer_context.get("phone", "")
        debts = call_transfer_context.get("debts", "")
        employment = call_transfer_context.get("employment", "")
        call_time = call_transfer_context.get("time", "")

        comment = f"Type:Qualified!NName:{full_name}!NPhone:{phone}!NDebts:{debts}!NCC:N/A!NDM:Yes!NEmployment:{employment}!NTime:{call_time}!NVendor Id:!NStatus:{status}"

        try:
            if lead_id and status:
                # Percent-encode the values: names, phone numbers and
                # timestamps can contain '&', '+', '#', '=' or spaces, which
                # would otherwise corrupt the query string.
                query = urlencode(
                    {"lead_id": lead_id, "status": status, "comments": comment}
                )
                # The configured URI is expected to already carry a query
                # string, hence '&' rather than '?'
                sync_url = f"{ari_data_uri}&{query}"

                logger.debug(
                    f"channelID: {self.caller_channel_id} Syncing data to ARI_DATA_SYNCING_URI: {sync_url}"
                )

                async with httpx.AsyncClient() as client:
                    response = await client.post(sync_url, timeout=10.0)
                    response.raise_for_status()
                    logger.info(
                        f"channelID: {self.caller_channel_id} Successfully synced data for lead_id: {lead_id} with status: {status}"
                    )
            else:
                logger.warning(
                    f"channelID: {self.caller_channel_id} Missing lead_id or status for syncing"
                )
        except Exception as e:
            logger.error(
                f"channelID: {self.caller_channel_id} Failed to sync data to ARI_DATA_SYNCING_URI: {e}"
            )

    async def disconnect(self, reason: str):
        """Instruct Asterisk to hang-up the call and perform cleanup.

        Args:
            reason: Why the call is being ended; used for logging only.
        """
        if self._closed:
            return

        # Lets mark it as closed so that when we receive StasisEnd, we don't
        # try to cleanup resource again
        self._closed = True

        # Clean up resources first
        await self._cleanup_resources()

        try:
            if self.caller_channel_id:
                caller_channel = await self._get_channel(self.caller_channel_id)
                if caller_channel:
                    logger.debug(
                        f"channelID: {self.caller_channel_id} Hanging up caller channel due to reason: {reason}"
                    )
                    await caller_channel.hangup()
        except Exception:
            logger.exception("Failed to hangup caller channel")

    async def transfer(self, call_transfer_context: dict):
        """Transfer the call by continuing in dialplan with extracted variables.

        Args:
            call_transfer_context: Call variables; logged and passed to
                ``_sync_call_data`` before the caller channel continues in
                the dialplan.
        """
        if self._closed:
            return

        # Lets mark it as closed so that when we receive StasisEnd, we don't
        # try to cleanup resource again
        self._closed = True

        try:
            # Clean up resources before transferring
            await self._cleanup_resources()

            if self.caller_channel_id:
                caller_channel = await self._get_channel(self.caller_channel_id)
                if caller_channel:
                    logger.debug(
                        f"channelID: {self.caller_channel_id} User qualified, continuing in dialplan "
                        f"REMOTE_DISPO_CALL_VARIABLES: {json.dumps(call_transfer_context)}"
                    )

                    # Sync data to ARI_DATA_SYNCING_URI
                    await self._sync_call_data(
                        call_transfer_context=call_transfer_context
                    )

                    await caller_channel.continueInDialplan()
        except Exception:
            logger.exception("Failed to transfer caller channel")

    async def setup_call(self):
        """Setup the bridge and external media channel.

        This must be called after initialization to establish the connection.
        Errors are logged rather than raised; check ``is_connected()``
        afterwards.
        """
        await self._setup_call(self._host, self._port)

    async def _setup_call(self, host: str, port: int):
        """Create externalMedia + bridge and mark the call as connected.

        On any failure the error is logged and whatever was created so far
        is cleaned up; ``is_connected()`` then stays false.
        """
        try:
            em_channel_id = str(uuid.uuid4())
            logger.debug(
                f"channelID: {em_channel_id} Creating externalMedia channel on {host}:{port}"
            )

            client = self._ari
            if not client:
                raise RuntimeError("No ARI client available")

            em_channel = await client.channels.externalMedia(
                app=client.app,
                channelId=em_channel_id,
                external_host=f"{host}:{port}",
                format="ulaw",
                direction="both",
            )

            # Store the channel ID
            self.em_channel_id = em_channel.id

            # Create a mixing bridge and add both legs.
            bridge = await client.bridges.create(type="mixing")
            self.bridge_id = bridge.id
            # Add channels individually as AsyncARIClient expects single channel per call
            await bridge.addChannel(channel=self.caller_channel_id)
            await bridge.addChannel(channel=self.em_channel_id)

            # TODO: Figure out how can we get the remote public IP. Till then
            # just pick it from the environment variable
            # Get RTP addressing information
            # ip = await em_channel.getChannelVar(
            #     variable="UNICASTRTP_LOCAL_ADDRESS"
            # )
            rtp_port_var = await em_channel.getChannelVar(
                variable="UNICASTRTP_LOCAL_PORT"
            )

            self.remote_addr = (
                os.environ.get("ASTERISK_REMOTE_IP"),
                int(rtp_port_var["value"]),
            )

            logger.debug(
                f"channelID: {self.caller_channel_id} ARIManagerConnection connection resources ready "
                f"(bridgeID: {self.bridge_id}), (emChannelID: {self.em_channel_id}) "
                f"remote address: {self.remote_addr}, local address: {self.local_addr}"
            )

            self._is_connected = True

        except Exception as exc:
            logger.exception(f"Error setting up ARIManagerConnection: {exc}")
            await self._cleanup_resources()

    async def notify_channel_end(self):
        """Notify that a channel has ended. Received after we get StasisEnd on the caller channel"""
        if self._closed:
            return

        self._closed = True
        self._is_connected = False

        # Cleanup resources using the shared method
        await self._cleanup_resources()

    def __repr__(self):
        """Return string representation of connection."""
        return (
            f"<ARIManagerConnection id={self.id} caller={self.caller_channel_id} "
            f"em={self.em_channel_id} bridge={self.bridge_id} state={'closed' if self._closed else 'open'}>"
        )
|
||||
184
api/services/telephony/stasis_event_protocol.py
Normal file
184
api/services/telephony/stasis_event_protocol.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
"""Redis communication protocol for distributed ARI architecture.
|
||||
|
||||
Defines message formats and helpers for ARI Manager <-> Worker communication.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class EventType(str, Enum):
    """Types of events sent from ARI Manager to Workers.

    Members are also ``str`` instances, so a member compares equal to its
    plain string value. ``parse_event`` in this module builds typed event
    objects only for ``STASIS_START`` and ``STASIS_END``.
    """

    STASIS_START = "stasis_start"
    STASIS_END = "stasis_end"
    CHANNEL_UPDATE = "channel_update"
    ERROR = "error"
|
||||
|
||||
|
||||
class CommandType(str, Enum):
    """Types of commands sent from Workers to ARI Manager.

    Members are also ``str`` instances, so a member compares equal to its
    plain string value. ``parse_command`` in this module builds typed
    command objects for ``DISCONNECT``, ``TRANSFER`` and ``SOCKET_CLOSED``;
    no command class for ``UPDATE_STATE`` is defined here.
    """

    DISCONNECT = "disconnect"
    TRANSFER = "transfer"
    UPDATE_STATE = "update_state"
    SOCKET_CLOSED = "socket_closed"
|
||||
|
||||
|
||||
@dataclass
class BaseWorkerToARIManagerCommand:
    """Common fields and JSON helpers for every Worker -> ARI Manager command."""

    type: str
    channel_id: str = ""

    def to_json(self) -> str:
        """Serialise all dataclass fields to a JSON object string."""
        payload = asdict(self)
        return json.dumps(payload)

    @classmethod
    def from_json(cls, data: str):
        """Build an instance from a JSON object whose keys are field names."""
        fields = json.loads(data)
        return cls(**fields)
|
||||
|
||||
|
||||
@dataclass
class StasisStartEvent:
    """Event describing a newly bridged call that is ready for a worker."""

    type: str = EventType.STASIS_START
    channel_id: str = ""
    caller_channel_id: str = ""
    em_channel_id: Optional[str] = None
    bridge_id: Optional[str] = None
    local_addr: List[Any] = None  # [host, port]
    remote_addr: Optional[List[Any]] = None  # [host, port] with UNICASTRTP_LOCAL_PORT
    call_context_vars: Dict[str, Any] = None

    def __post_init__(self):
        # None stands in for "not given": every instance gets its own
        # fresh container instead of sharing a mutable default.
        if self.call_context_vars is None:
            self.call_context_vars = {}
        if self.local_addr is None:
            self.local_addr = []

    def to_json(self) -> str:
        """Serialise all dataclass fields to a JSON object string."""
        payload = asdict(self)
        return json.dumps(payload)

    @classmethod
    def from_json(cls, data: str) -> "StasisStartEvent":
        """Build an event from a JSON object whose keys are field names."""
        fields = json.loads(data)
        return cls(**fields)
|
||||
|
||||
|
||||
@dataclass
class StasisEndEvent:
    """Event announcing that a call has ended."""

    type: str = EventType.STASIS_END
    channel_id: str = ""
    reason: Optional[str] = None

    def to_json(self) -> str:
        """Serialise all dataclass fields to a JSON object string."""
        payload = asdict(self)
        return json.dumps(payload)

    @classmethod
    def from_json(cls, data: str) -> "StasisEndEvent":
        """Build an event from a JSON object whose keys are field names."""
        fields = json.loads(data)
        return cls(**fields)
|
||||
|
||||
|
||||
@dataclass
class DisconnectCommand(BaseWorkerToARIManagerCommand):
    """Command to disconnect a call.

    Inherits ``channel_id`` and the JSON helpers from the base command.
    """

    # Overrides the base field with a default so callers need not pass it
    type: str = CommandType.DISCONNECT
    # Free-form explanation of why the disconnect was requested
    reason: str = "worker_requested"
|
||||
|
||||
|
||||
@dataclass
class TransferCommand(BaseWorkerToARIManagerCommand):
    """Command to transfer a call.

    Inherits ``channel_id`` and the JSON helpers from the base command.
    """

    # Overrides the base field with a default so callers need not pass it
    type: str = CommandType.TRANSFER
    # Transfer variables; None is replaced by a fresh dict in __post_init__
    context: Dict[str, Any] = None

    def __post_init__(self):
        # Give every instance its own dict rather than a shared default
        if self.context is None:
            self.context = {}
|
||||
|
||||
|
||||
@dataclass
class SocketClosedCommand(BaseWorkerToARIManagerCommand):
    """Command to notify that RTP sockets have been closed.

    Carries no data beyond the inherited ``channel_id``.
    """

    # Overrides the base field with a default so callers need not pass it
    type: str = CommandType.SOCKET_CLOSED
|
||||
|
||||
|
||||
class RedisChannels:
    """Builders for the Redis pub/sub channel names used by this protocol."""

    @staticmethod
    def worker_events(worker_id: str) -> str:
        """Name of the channel carrying events for one worker."""
        return "ari:events:worker:{}".format(worker_id)

    @staticmethod
    def channel_commands(channel_id: str) -> str:
        """Name of the channel carrying commands for one call channel."""
        return "ari:commands:{}".format(channel_id)

    @staticmethod
    def channel_updates(channel_id: str) -> str:
        """Name of the channel carrying state updates for one call."""
        return "ari:updates:{}".format(channel_id)
|
||||
|
||||
|
||||
class RedisKeys:
    """Builders for the Redis keys used for worker registration and discovery."""

    # Fixed key names, kept in one place
    _WORKERS_SET = "workers:set"
    _ROUND_ROBIN_INDEX = "workers:round_robin:index"

    @staticmethod
    def worker_active(worker_id: str) -> str:
        """Key holding the status and metadata of one worker."""
        return "workers:active:{}".format(worker_id)

    @staticmethod
    def workers_set() -> str:
        """Key of the set containing all registered worker IDs."""
        return RedisKeys._WORKERS_SET

    @staticmethod
    def round_robin_index() -> str:
        """Key of the counter used for round-robin worker selection."""
        return RedisKeys._ROUND_ROBIN_INDEX
|
||||
|
||||
|
||||
def parse_event(data: str) -> Any:
    """Decode a Redis event message.

    Returns:
        A ``StasisStartEvent`` or ``StasisEndEvent`` when the message's
        ``type`` names one of those, the decoded JSON value unchanged for
        any other type, or ``None`` when decoding or construction fails.
    """
    try:
        payload = json.loads(data)
        kind = payload.get("type")

        if kind == EventType.STASIS_START:
            return StasisStartEvent(**payload)
        if kind == EventType.STASIS_END:
            return StasisEndEvent(**payload)
        # Unrecognised type: hand back the raw payload
        return payload
    except Exception:
        return None
|
||||
|
||||
|
||||
def parse_command(data: str) -> Any:
    """Decode a Redis command message.

    Returns:
        A ``DisconnectCommand``, ``TransferCommand`` or
        ``SocketClosedCommand`` when the message's ``type`` names one of
        those, the decoded JSON value unchanged for any other type, or
        ``None`` when decoding or construction fails.
    """
    try:
        payload = json.loads(data)
        kind = payload.get("type")

        if kind == CommandType.DISCONNECT:
            return DisconnectCommand(**payload)
        if kind == CommandType.TRANSFER:
            return TransferCommand(**payload)
        if kind == CommandType.SOCKET_CLOSED:
            return SocketClosedCommand(**payload)
        # Unrecognised type: hand back the raw payload
        return payload
    except Exception:
        return None
|
||||
361
api/services/telephony/stasis_rtp_client.py
Normal file
361
api/services/telephony/stasis_rtp_client.py
Normal file
|
|
@ -0,0 +1,361 @@
|
|||
"""Low-level RTP transport for Asterisk externalMedia sessions.
|
||||
|
||||
stasis_rtp_client.py
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
* Sends and receives **proper RTP/UDP** (PT 0 PCMU/μ-law).
|
||||
* Uses 20 ms frames (160 bytes payload) by default; automatically
|
||||
chunks or concatenates data so timestamps stay correct.
|
||||
* Verifies the RTP header on the receive path (SSRC and PT).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import secrets
|
||||
import socket
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────── helpers
|
||||
|
||||
|
||||
# Fixed 12-byte RTP header, network byte order:
#   B  V/P/X/CC   B  M/PT   H  sequence   I  timestamp   I  SSRC
_RTP_HDR = struct.Struct("!BBHII")  # v/p/x/cc, m/pt, seq, ts, ssrc
_PT_PCMU = 0  # static payload type for μ-law


class _RTPEncoder:
    """Builds PCMU RTP headers for the packets we SEND to Asterisk.

    Attributes:
        ssrc: Random 32-bit synchronisation source identifier for our stream.
        seq: 16-bit sequence number of the next packet (random start value).
        ts: 32-bit RTP timestamp of the next packet, advanced by the payload
            length after every packet.
    """

    def __init__(self):
        self.ssrc = secrets.randbits(32)
        self.seq = secrets.randbits(16)
        self.ts = 0  # incremented by #payload bytes

    def pack(self, payload: bytes, mark=False) -> bytes:
        """Prefix *payload* with an RTP header and advance the stream counters.

        Args:
            payload: μ-law audio bytes for this packet.
            mark: When true, the RTP marker bit is set in the header.

        Returns:
            The 12-byte header followed by *payload*.
        """
        b0 = 0x80  # V=2, no padding, no extension, CC=0
        b1 = (0x80 if mark else 0x00) | _PT_PCMU
        hdr = _RTP_HDR.pack(b0, b1, self.seq, self.ts, self.ssrc)
        self.seq = (self.seq + 1) & 0xFFFF
        # 1 sample/byte @ 8 kHz. The header field is 32 bits wide, so the
        # counter must wrap; an unbounded value would make the next
        # _RTP_HDR.pack() raise struct.error on a long-running stream.
        self.ts = (self.ts + len(payload)) & 0xFFFFFFFF
        return hdr + payload
|
||||
|
||||
|
||||
class _RTPDecoder:
    """Very forgiving RTP decoder.

    The SSRC of the first packet that passes the version and payload-type
    checks is remembered; afterwards every packet must carry that same SSRC
    (and the PCMU payload type). ``unpack`` returns *None* for any packet
    that should be ignored.

    NOTE(review): only the fixed 12-byte header is stripped; CSRC entries,
    header extensions and padding are not interpreted.
    """

    def __init__(self):
        self.peer_ssrc: int | None = None  # learned from first packet

    def unpack(self, packet: bytes) -> bytes | None:
        """Validate the RTP header of *packet* and return its payload.

        Args:
            packet: One received datagram.

        Returns:
            The bytes following the fixed header, or *None* when the packet is
            too short, is not RTP version 2, is not payload type 0 (PCMU), or
            comes from a different SSRC than the one latched earlier.
        """
        header_len = _RTP_HDR.size
        if len(packet) < header_len:
            return None

        first, second, _seq, _ts, ssrc = _RTP_HDR.unpack_from(packet)
        is_rtp_v2 = (first & 0xC0) == 0x80
        is_pcmu = (second & 0x7F) == _PT_PCMU  # marker bit masked off
        if not (is_rtp_v2 and is_pcmu):
            return None

        if self.peer_ssrc is None:
            # latch on first good packet
            self.peer_ssrc = ssrc
        elif ssrc != self.peer_ssrc:
            # stray stream – drop
            return None

        return packet[header_len:]
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────── client
|
||||
|
||||
|
||||
class StasisRTPClient:
    """Low-level wrapper around StasisRTPConnection.

    Owns the two UDP sockets used for one call (one bound locally for
    receiving, one connected to the remote address for sending) and wraps /
    unwraps the RTP header around μ-law payloads.

    Public API
    ──────────
    • await setup(start_frame)          registers one user of the client
                                        (see ``disconnect``)
    • await connect()
    • async for payload in receive():   # μ-law bytes (20 ms each)
          …
    • await send(data)                  # any length; will be chunked
    • await disconnect()
    """

    _FRAME_SIZE = 160  # 20 ms @ 8 kHz PCMU

    def __init__(
        self,
        connection: "StasisRTPConnection",
        callbacks: "StasisRTPCallbacks",
    ):
        """Initialize Stasis RTP client.

        Registers "connected" and "disconnected" handlers on *connection*.

        Args:
            connection: RTP connection parameters.
            callbacks: Callback handlers for transport events.
        """
        # Local import: the module-level typing import does not include Any,
        # and it is only needed for the handler annotations below.
        from typing import Any

        self._connection = connection
        self._callbacks = callbacks
        self._encoder = _RTPEncoder()
        self._decoder = _RTPDecoder()

        self._recv_sock: Optional[socket.socket] = None
        self._send_sock: Optional[socket.socket] = None
        self._closing = False
        self._recv_sock_ready = asyncio.Event()  # Signal when recv socket is ready
        self._leave_counter = 0  # Track input/output transport usage
        self._fallback_disconnect_timer: Optional[asyncio.Task] = (
            None  # Safety timer for disconnect
        )

        # ── wire event handlers to the connection ────────────────
        @self._connection.event_handler("connected")
        async def _on_connected(_: Any):
            await self._setup_sockets()
            await self._callbacks.on_client_connected(
                self._connection.caller_channel_id
            )

        @self._connection.event_handler("disconnected")
        async def _on_disconnected(_: Any, reason: str):
            # Cancel the safety timer if it exists. We start the safety timer when
            # sending disconnect or transfer from the engine, i.e. when the disconnect()
            # method of the StasisRTPClient is called during wind down of the pipeline.
            # We start the timer so that if we don't get the remote hangup in a given
            # duration, we will call the client disconnected handler.
            if (
                self._fallback_disconnect_timer
                and not self._fallback_disconnect_timer.done()
            ):
                self._fallback_disconnect_timer.cancel()
                self._fallback_disconnect_timer = None

            if not self._closing:
                # Mark the client as closing, so that when the pipeline is
                # cancelled or getting closed, we don't try to start the fallback
                # disconnect timer and return safely from disconnect
                self._closing = True

                await self._callbacks.on_client_disconnected(
                    self._connection.caller_channel_id, reason
                )

    # ─── public helpers ──────────────────────────────────────────

    async def setup(self, _):
        """Register one user of this client.

        Each call increments the leave counter; ``disconnect`` only tears the
        sockets down once it has been called as many times as ``setup``.
        The argument is ignored.
        """
        self._leave_counter += 1

    async def connect(self):
        """Ask the underlying connection to connect, unless it already is."""
        if self._connection.is_connected():
            return
        await self._connection.connect()

    async def disconnect(
        self,
        reason: str = EndTaskReason.UNKNOWN.value,
        call_transfer_context: Optional[dict] = None,  # Kept for backward compatibility
    ):
        """Disconnect from the RTP socket.

        Decrements the leave counter and returns early while other users are
        still registered. On the last call the sockets are closed and, unless
        the client is already closing, a disconnect is requested from the
        connection and a 5 second fallback timer is started that reports the
        disconnect through the callbacks if no "disconnected" event arrives.

        Args:
            reason: Why the call is ending. For
                ``EndTaskReason.USER_QUALIFIED`` no disconnect is requested
                from the connection.
            call_transfer_context: Unused; accepted for backward
                compatibility only.
        """
        # Decrement leave counter when disconnect is called
        logger.debug(f"StasisRTPClient.disconnect leave_counter: {self._leave_counter}")
        self._leave_counter -= 1
        if self._leave_counter > 0:
            # Early return - InputTransport called first, OutputTransport will call later
            return

        # Only proceed when counter reaches 0 (OutputTransport's call)
        # Close sockets
        logger.debug("Going to close sockets")
        await self._close_sockets()

        if self._closing:
            # We might have received the disconnected callback from the
            # StasisRTPConnection due to user hangup. The sockets have just been
            # closed above and the disconnected callback has already been
            # delivered, so there is nothing left to do.
            return
        self._closing = True

        # Create a safety timer that will call on_client_disconnected if we don't
        # get StasisEnd from the dialer within 5 seconds. StasisEnd is needed to
        # trigger on_client_disconnected handler in the event_handlers
        async def _fallback_disconnect_timeout():
            await asyncio.sleep(5.0)
            logger.warning(
                "Disconnect event not received within 5 seconds, calling on_client_disconnected as fallback"
            )
            # Pass the reason as well: the callback is declared as taking
            # (channel_id, reason), same as the regular disconnected path.
            await self._callbacks.on_client_disconnected(
                self._connection.caller_channel_id, reason
            )

        self._fallback_disconnect_timer = asyncio.create_task(
            _fallback_disconnect_timeout()
        )

        # Only call disconnect if not a transfer (transfer already handled in PipecatEngine)
        # NOTE: Transfer now happens immediately in PipecatEngine.send_end_task_frame()
        if reason != EndTaskReason.USER_QUALIFIED.value:
            try:
                await self._connection.disconnect(reason)
            except Exception as exc:
                logger.error(f"Failed to disconnect RTP connection: {exc}")
        else:
            logger.debug(
                "Skipping disconnect call for USER_QUALIFIED - transfer already initiated by engine"
            )

    # ─── socket management ──────────────────────────────────────

    async def _setup_sockets(self):
        """Create the receive and send UDP sockets if they do not exist yet.

        Raises:
            OSError: If binding or connecting a socket fails. The failing
                socket is closed before the error propagates.
        """
        if self._recv_sock and self._send_sock:
            return

        logger.debug(
            f"Setting up Sockets - local {self._connection.local_addr}, remote: {self._connection.remote_addr}"
        )

        # receive socket – bind to local address provided by connection
        if not self._recv_sock:
            rs = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                rs.setblocking(False)
                rs.bind(self._connection.local_addr)
            except Exception:
                rs.close()  # do not leak the descriptor when bind fails
                raise
            self._recv_sock = rs
            self._recv_sock_ready.set()  # Signal that recv socket is ready

        # send socket – connect to remote (Asterisk) address
        if not self._send_sock:
            ss = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                ss.setblocking(False)
                ss.connect(self._connection.remote_addr)
            except Exception:
                ss.close()  # do not leak the descriptor when connect fails
                raise
            self._send_sock = ss

        logger.debug(
            f"Socket setup complete - recv_fd: {self._recv_sock.fileno()}, send_fd: {self._send_sock.fileno()}"
        )

    async def _close_sockets(self):
        """Safely close sockets with proper error handling.

        Always notifies the connection afterwards, even when no socket was
        open.
        """
        for sock_name, sock in [("recv", self._recv_sock), ("send", self._send_sock)]:
            if sock:
                try:
                    # Shutdown the socket first to break any pending operations
                    sock.shutdown(socket.SHUT_RDWR)
                except OSError:
                    # Socket might already be closed or in a bad state
                    pass
                try:
                    sock.close()
                except Exception as exc:
                    logger.debug(f"Error closing {sock_name} socket: {exc}")

        self._recv_sock = None
        self._send_sock = None
        self._recv_sock_ready.clear()  # Reset the event for potential reconnection

        # Notify the connection that sockets are closed so ARI Manager can clean up ports
        await self._connection.notify_sockets_closed()

        logger.debug("Closed sockets in StasisRTPClient")

    # ─── receive path ────────────────────────────────────────────

    async def receive(self) -> AsyncIterator[bytes]:
        """Async generator yielding μ-law frames (exactly 160 bytes each).

        Waits until the receive socket exists, then reads datagrams until the
        client is closing, the socket fails or the task is cancelled. Packets
        rejected by the decoder (bad header, wrong payload type, SSRC other
        than the peer's latched one) are dropped silently; packets whose
        payload is not exactly one frame are dropped with a warning.
        """
        loop = asyncio.get_running_loop()

        # Wait for recv socket to be created
        try:
            await self._recv_sock_ready.wait()
        except asyncio.CancelledError:
            return

        logger.debug("Going to receive from the socket now")

        while not self._closing:
            sock = self._recv_sock
            if sock is None:
                # Sockets were torn down while we were between reads.
                break
            try:
                # each loop gets 172 bytes UDP packet, which is 160 bytes of
                # audio data (Asterisk sends 20ms audio chunks with 8k sample rate)
                # and 12 bytes of RTP header
                data = await loop.sock_recv(sock, 2048)
            except asyncio.CancelledError:
                logger.debug("RTP receive task cancelled")
                break
            except OSError as exc:
                logger.warning(f"RTP receive failed (socket closed): {exc}")
                break
            except Exception as exc:
                logger.debug(f"Unexpected error in receive: {exc}")
                break

            payload = self._decoder.unpack(data)
            if payload is None:
                continue  # header failed validation

            # In practice Asterisk sends 20 ms frames – check just in case.
            if len(payload) != self._FRAME_SIZE:
                logger.warning(f"Dropping non-20 ms packet len={len(payload)}")
                continue
            yield payload

    # ─── send path ───────────────────────────────────────────────

    async def send(self, data: bytes):
        """Send μ-law data of arbitrary length.

        Splits the data into 160-byte chunks (padding the last one) before
        RTP-wrapping. The marker bit is set on the first packet of every
        call. Does nothing when the client is closing or has no send socket;
        stops at the first send error.
        """
        if self._closing or not self._send_sock:
            return
        loop = asyncio.get_running_loop()

        # chunk/pad to 160-byte frames
        chunks = self._chunk_ulaw(data, self._FRAME_SIZE)
        for i, chunk in enumerate(chunks):
            sock = self._send_sock
            if sock is None:
                # Sockets were closed while an earlier chunk was being sent.
                break
            mark = i == 0  # set marker on the first packet of talk-spurt
            packet = self._encoder.pack(chunk, mark=mark)
            try:
                await loop.sock_sendall(sock, packet)
            except OSError as exc:
                logger.warning(f"RTP send failed (socket closed): {exc}")
                break
            except Exception as exc:
                logger.error(f"RTP send failed: {exc}")
                break

    def _chunk_ulaw(self, buf: bytes, size: int) -> list[bytes]:
        """Split μ-law bytes into chunks of exactly *size* bytes.

        • If buf length is not a multiple of *size*, pad the last chunk with 0xFF
          (silence). That keeps timestamps monotonic.

        Returns:
            The list of chunks; empty when *buf* is empty.
        """
        if not buf:
            return []
        if len(buf) % size:
            pad = size - (len(buf) % size)
            buf += b"\xff" * pad
        return [buf[i : i + size] for i in range(0, len(buf), size)]

    # ─── properties ──────────────────────────────────────────────

    @property
    def is_connected(self) -> bool:
        """Check if client is connected (connection up and not closing)."""
        return self._connection.is_connected() and not self._closing

    @property
    def is_closing(self) -> bool:
        """Check if client is closing."""
        return self._closing
|
||||
182
api/services/telephony/stasis_rtp_connection.py
Normal file
182
api/services/telephony/stasis_rtp_connection.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
"""Stasis RTP connection for worker processes.
|
||||
|
||||
This connection works without direct ARI access and communicates with
|
||||
the ARI Manager via Redis for all control operations.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.telephony.stasis_event_protocol import (
|
||||
DisconnectCommand,
|
||||
RedisChannels,
|
||||
SocketClosedCommand,
|
||||
TransferCommand,
|
||||
)
|
||||
|
||||
|
||||
class StasisRTPConnection(BaseObject):
    """Worker-side connection that communicates with ARI Manager via Redis.

    This class provides the same API as the original StasisRTPConnection but
    without direct ARI client access. All channel operations (disconnect,
    transfer, socket-closed notification) are published as commands on the
    channel's Redis command channel.
    """

    _SUPPORTED_EVENTS = [
        "connecting",
        "connected",
        "disconnected",
        "closed",
        "failed",
        "new",
    ]

    def __init__(
        self,
        redis_client: aioredis.Redis,
        channel_id: str,
        caller_channel_id: str,
        em_channel_id: Optional[str],
        bridge_id: Optional[str],
        local_addr: Optional[Tuple[str, int]],
        remote_addr: Optional[Tuple[str, int]],
        workflow_run_id: Optional[int] = None,
    ):
        """Initialize distributed connection with pre-established details.

        Args:
            redis_client: Redis client for communication
            channel_id: Primary channel ID for this connection
            caller_channel_id: Caller's channel ID
            em_channel_id: External media channel ID
            bridge_id: Bridge ID (already created by ARI Manager)
            local_addr: Local RTP address (host, port)
            remote_addr: Remote RTP address with UNICASTRTP_LOCAL_PORT
            workflow_run_id: Workflow run ID for logging context
        """
        super().__init__()

        # Identity of the call this connection belongs to
        self.redis = redis_client
        self.channel_id = channel_id
        self.caller_channel_id = caller_channel_id
        self.em_channel_id = em_channel_id
        self.bridge_id = bridge_id
        self.workflow_run_id = workflow_run_id

        # RTP addressing (same as StasisRTPConnection)
        self.local_addr = local_addr
        self.remote_addr = remote_addr

        # State tracking.
        # _closed_by_stasis_end should only be set True after we get
        # StasisEnd from the transport.
        self._closed_by_stasis_end = False
        self._connect_invoked = False

        # Register event handlers
        for event_name in self._SUPPORTED_EVENTS:
            self._register_event_handler(event_name)

        logger.debug(
            f"channelID: {channel_id} StasisRTPConnection created: "
            f"bridgeID: {bridge_id}, local_addr={local_addr}, remote_addr={remote_addr}"
        )

    async def connect(self):
        """Signal readiness to start the call.

        Since the bridge is already established by ARI Manager, the
        "connected" event is triggered immediately, unless the connection
        has already been closed, in which case only a warning is logged.
        """
        self._connect_invoked = True
        if not self.is_connected():
            logger.warning(
                "StasisRTPConnection is not connected - did not call connected handler"
            )
            return
        await self._call_event_handler("connected")

    async def disconnect(self, reason: str):
        """Request disconnection via Redis command to ARI Manager.

        Usually called when there is a disconnect triggered by the workflow.
        Does nothing when the call was already closed by the remote side.
        """
        # Already received user hangup via StasisEnd: nothing to request.
        if self._closed_by_stasis_end:
            return

        logger.info(f"channelID: {self.channel_id} Requesting disconnect: {reason}")

        # Send disconnect command to ARI Manager
        commands_channel = RedisChannels.channel_commands(self.channel_id)
        disconnect_cmd = DisconnectCommand(channel_id=self.channel_id, reason=reason)
        await self.redis.publish(commands_channel, disconnect_cmd.to_json())

    async def transfer(self, call_transfer_context: dict):
        """Request call transfer via Redis command to ARI Manager.

        Does nothing when the call was already closed by the remote side.
        """
        # Already received user hangup via StasisEnd: nothing to request.
        if self._closed_by_stasis_end:
            return

        logger.info(f"channelID: {self.channel_id} Requesting transfer")

        # Send transfer command to ARI Manager
        commands_channel = RedisChannels.channel_commands(self.channel_id)
        transfer_cmd = TransferCommand(
            channel_id=self.channel_id, context=call_transfer_context
        )
        await self.redis.publish(commands_channel, transfer_cmd.to_json())

    async def notify_sockets_closed(self):
        """Notify ARI Manager that RTP sockets have been closed."""
        logger.info(
            f"channelID: {self.channel_id} Notifying ARI Manager that sockets are closed"
        )

        # Send socket_closed command to ARI Manager
        commands_channel = RedisChannels.channel_commands(self.channel_id)
        closed_cmd = SocketClosedCommand(channel_id=self.channel_id)
        await self.redis.publish(commands_channel, closed_cmd.to_json())

    def is_connected(self) -> bool:
        """Check if connection is established.

        Returns True once connect() has been called and connection is not closed.
        """
        return self._connect_invoked and not self._closed_by_stasis_end

    async def handle_remote_disconnect(self, reason: str = EndTaskReason.UNKNOWN.value):
        """Handle disconnection initiated by ARI Manager (the user hung up).

        Marks the connection closed and, when connect() was invoked before,
        fires the "disconnected" event with *reason*. Repeated calls are
        ignored.
        """
        if self._closed_by_stasis_end:
            return

        self._closed_by_stasis_end = True

        if self._connect_invoked:
            # Unless self._connect_invoked is True, the event handlers won't be registered.
            # We only register the event handler of the client when the transports are
            # initiated during pipeline initialisation. Any caller must check and wait
            # for _connect_invoked before calling the method.
            await self._call_event_handler("disconnected", reason)
        else:
            logger.warning(
                f"ChannelID: {self.channel_id} Got remote disconnect before connection was invoked"
            )

        logger.info(
            f"channelID: {self.channel_id} StasisRTPConnection disconnected: {reason}"
        )

    def __repr__(self):
        """String representation of connection."""
        state = "closed" if self._closed_by_stasis_end else "open"
        return (
            f"<StasisRTPConnection id={self.id} channel={self.channel_id} "
            f"caller={self.caller_channel_id} em={self.em_channel_id} "
            f"state={state}>"
        )
|
||||
120
api/services/telephony/stasis_rtp_serializer.py
Normal file
120
api/services/telephony/stasis_rtp_serializer.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
"""Stasis RTP frame serializer.
|
||||
|
||||
This serializer converts between Pipecat frames and the raw μ-law RTP payload
|
||||
stream expected by an Stasis *External Media* channel.
|
||||
|
||||
The serializer:
|
||||
|
||||
* Down-samples PCM to 8-kHz μ-law for **outgoing** audio (:class:`AudioRawFrame`).
|
||||
* Up-samples μ-law to the pipeline's native rate for **incoming** audio.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.audio.utils import create_default_resampler, pcm_to_ulaw, ulaw_to_pcm
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class StasisRTPFrameSerializer(FrameSerializer):
    """Serializer for Asterisk External Media streams (raw μ-law)."""

    class InputParams(BaseModel):
        """Configuration parameters.

        Attributes:
        ----------
        stasis_sample_rate : int, default 8000
            The sample-rate used by Stasis when sending μ-law (PCMU).
        sample_rate : Optional[int]
            Override for the pipeline's *input* sample-rate. When omitted the
            value from the :class:`StartFrame` is used.
        """

        stasis_sample_rate: int = 8000
        sample_rate: Optional[int] = None

    def __init__(self, params: Optional[InputParams] = None):
        """Initialize Stasis RTP frame serializer.

        Args:
            params: Optional configuration parameters for the serializer.
                Defaults are used when omitted.
        """
        self._params = params if params else self.InputParams()

        # Wire rate is known up front; the pipeline rate is filled in *setup*.
        self._stasis_sample_rate = self._params.stasis_sample_rate
        self._sample_rate = 0

        # Resampler shared between encode / decode paths
        self._resampler = create_default_resampler()

    @property
    def type(self) -> FrameSerializerType:
        """Stasis uses raw bytes → BINARY."""
        return FrameSerializerType.BINARY

    async def setup(self, frame: StartFrame):
        """Remember the pipeline sample rate.

        The configured override wins; otherwise the StartFrame's
        ``audio_in_sample_rate`` is used.
        """
        configured_rate = self._params.sample_rate
        self._sample_rate = configured_rate or frame.audio_in_sample_rate

    async def serialize(self, frame: Frame) -> bytes | str | None:
        """Convert a Pipecat frame to a wire payload.

        Only :class:`AudioRawFrame` instances are translated; all other frame
        types are silently ignored, allowing higher-level transports to deal
        with them as needed.

        Returns:
            The μ-law encoded bytes, or ``None`` for non-audio frames and
            when encoding fails (the failure is logged).
        """
        if not isinstance(frame, AudioRawFrame):
            # Non-audio frames are not transmitted on the media path
            return None

        try:
            # Pipeline PCM → 8-kHz μ-law, returned as raw bytes
            return await pcm_to_ulaw(
                frame.audio,
                frame.sample_rate,
                self._stasis_sample_rate,
                self._resampler,
            )
        except Exception as exc:  # pragma: no cover – robustness
            logger.error(
                f"StasisRTPFrameSerializer.serialize: encode failed: {exc}"
            )
            return None

    async def deserialize(self, data: bytes | str) -> Frame | None:
        """Convert wire payloads to Pipecat frames.

        The Stasis media socket delivers bare μ-law bytes, therefore *data*
        must be *bytes*. Any *str* is ignored.

        Returns:
            A mono :class:`InputAudioRawFrame` at the pipeline sample rate,
            or ``None`` for non-bytes input and when decoding fails (the
            failure is logged).
        """
        if not isinstance(data, (bytes, bytearray)):
            return None

        try:
            decoded_pcm = await ulaw_to_pcm(
                bytes(data),
                self._stasis_sample_rate,
                self._sample_rate,
                self._resampler,
            )
            audio_frame = InputAudioRawFrame(
                audio=decoded_pcm,
                sample_rate=self._sample_rate,
                num_channels=1,
            )
            return audio_frame
        except Exception as exc:  # pragma: no cover
            logger.error(f"StasisRTPFrameSerializer.deserialize: decode failed: {exc}")
            return None
|
||||
324
api/services/telephony/stasis_rtp_transport.py
Normal file
324
api/services/telephony/stasis_rtp_transport.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
# transports/ari_external_media.py (new file)
|
||||
|
||||
"""Stasis RTP transport for Asterisk External Media integration."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import (
|
||||
BaseOutputTransport,
|
||||
TransportClientNotConnectedException,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.services.telephony.stasis_rtp_client import StasisRTPClient
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
|
||||
|
||||
class StasisRTPTransportParams(TransportParams):
    """Transport parameters for Stasis RTP transport.

    Extends the base transport parameters with the frame serializer used by
    the input and output transports.
    """

    # Required: converts between pipeline frames and wire payloads.
    serializer: FrameSerializer
|
||||
|
||||
|
||||
class StasisRTPCallbacks(BaseModel):
    """Async callbacks through which the RTP client reports transport events."""

    # Awaited with a channel id when the client has connected.
    on_client_connected: Callable[[str], Awaitable[None]]

    # Awaited with a channel id and an optional disconnect reason.
    on_client_disconnected: Callable[[str, Optional[str]], Awaitable[None]]

    # Awaited with a channel id; presumably when the client is closed
    # (TODO confirm against the code that invokes it).
    on_client_closed: Callable[[str], Awaitable[None]]
|
||||
|
||||
|
||||
# ------------------------------------------------ Input Transport -------------------------
|
||||
|
||||
"""
|
||||
The transport calls the client's receive() to read the audio from the socket. This happens in the self._receive_audio task.
The audio frames are then pushed to _audio_in_queue using the push_audio_frame method. The _audio_task_handler then processes
the frames from the _audio_in_queue, pushes them to the VAD analyzer and the turn analyzer, and pushes the audio
further downstream to STT.

The BaseInputTransport pipeline is responsible for:
- Resampling the audio to the correct sample rate
- Applying the audio filter
- Pushing the audio frames to the VAD analyzer
- Pushing the audio frames to the turn analyzer
- Pushing the audio frames to the bot interruption analyzer
- Pushing the audio frames down the pipeline to STT

The stop method is called from process_frame of the BaseInputTransport. super().stop() stops the _audio_task_handler. It then
calls _client.disconnect. The transport's callbacks are passed to the client via StasisRTPCallbacks.
|
||||
"""
|
||||
|
||||
|
||||
class StasisRTPInputTransport(BaseInputTransport):
    """Input transport for receiving audio over Stasis RTP."""

    def __init__(
        self,
        transport: BaseTransport,
        client: StasisRTPClient,
        params: StasisRTPTransportParams,
        **kwargs,
    ):
        """Initialize Stasis RTP input transport.

        Args:
            transport: Parent transport instance.
            client: Stasis RTP client for socket communication.
            params: Transport parameters including serializer.
            **kwargs: Additional keyword arguments for BaseInputTransport.
        """
        super().__init__(params, **kwargs)
        self._transport = transport
        self._client = client
        self._params = params

        # Background task that pumps audio from the client into the pipeline.
        self._receive_task: Optional[asyncio.Task] = None

    async def start(self, frame: StartFrame):
        """Start the input transport.

        Registers with the client, configures the serializer, connects the
        client, starts the receive task (once) and marks the transport ready.
        """
        await super().start(frame)

        await self._client.setup(frame)
        await self._params.serializer.setup(frame)

        # Ensure underlying connection is established and socket ready.
        await self._client.connect()

        if self._receive_task is None or not self._receive_task:
            self._receive_task = self.create_task(self._receive_audio())

        await self.set_transport_ready(frame)

    async def _stop_tasks(self):
        """Cancel the receive task, if one is running."""
        if not self._receive_task:
            return
        await self.cancel_task(self._receive_task)
        self._receive_task = None

    async def stop(self, frame: EndFrame):
        """Stop the input transport and disconnect the client (EndFrame)."""
        await super().stop(frame)
        await self._stop_tasks()

        # Reason and transfer context travel in the frame metadata.
        end_reason = frame.metadata.get("reason", EndTaskReason.UNKNOWN.value)
        transfer_context = frame.metadata.get("call_transfer_context", {})
        await self._client.disconnect(end_reason, transfer_context)
        logger.debug("Successfully disconnected from StasisRTPClient")

    async def cancel(self, frame: CancelFrame):
        """Cancel the input transport and disconnect the client (CancelFrame)."""
        await super().cancel(frame)
        await self._stop_tasks()

        # Reason and transfer context travel in the frame metadata.
        cancel_reason = frame.metadata.get(
            "reason", EndTaskReason.SYSTEM_CANCELLED.value
        )
        transfer_context = frame.metadata.get("call_transfer_context", {})
        await self._client.disconnect(cancel_reason, transfer_context)

    async def _receive_audio(self):
        """Pump payloads from the client into the pipeline until the stream ends.

        Each payload is deserialized; audio frames go through
        ``push_audio_frame``, anything else through ``push_frame``. Errors are
        logged and end the loop.
        """
        deserialize = self._params.serializer.deserialize
        try:
            async for payload in self._client.receive():
                decoded = await deserialize(payload)
                if not decoded:
                    continue

                if isinstance(decoded, InputAudioRawFrame):
                    await self.push_audio_frame(decoded)
                else:
                    await self.push_frame(decoded)
        except Exception as exc:
            logger.error(f"StasisRTPInputTransport receive error: {exc}")

    # No app-messages in RTP path, but keep compatibility
    async def push_app_message(self, message):
        """Push app message (not supported in RTP transport; only logged)."""
        logger.debug("StasisRTPInputTransport received app message ignored (RTP only)")
|
||||
|
||||
|
||||
# ------------------------------------------------ Output Transport ------------------------
|
||||
|
||||
|
||||
class StasisRTPOutputTransport(BaseOutputTransport):
|
||||
"""Output transport for sending audio over Stasis RTP."""
|
||||
|
||||
def __init__(
    self,
    transport: BaseTransport,
    client: StasisRTPClient,
    params: StasisRTPTransportParams,
    **kwargs,
):
    """Initialize Stasis RTP output transport.

    Args:
        transport: Parent transport instance.
        client: Stasis RTP client for socket communication.
        params: Transport parameters including serializer.
        **kwargs: Additional keyword arguments for BaseOutputTransport.
    """
    super().__init__(params, **kwargs)

    self._params = params
    self._client = client
    self._transport = transport

    # Pacing state for outgoing audio so buffers are not dumped instantly;
    # the interval is computed in start().
    self._next_send_time: float = 0
    self._send_interval: float = 0
|
||||
|
||||
async def start(self, frame: StartFrame):
    """Start the output transport.

    Registers with the client, configures the serializer, derives the
    pacing interval from ``audio_out_10ms_chunks`` and marks the transport
    ready.
    """
    await super().start(frame)

    await self._client.setup(frame)
    await self._params.serializer.setup(frame)

    # chunks * 10 gives milliseconds; dividing by 1000 presumably yields
    # the interval in seconds.
    chunk_count = self._params.audio_out_10ms_chunks
    self._send_interval = chunk_count * 10 / 1000

    await self.set_transport_ready(frame)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the output transport."""
|
||||
await super().stop(frame)
|
||||
|
||||
# Call disconnect on the client when EndFrame is encountered
|
||||
# The client will check its _leave_counter and decide whether to close sockets
|
||||
await self._client.disconnect(
|
||||
frame.metadata.get("reason", EndTaskReason.UNKNOWN.value),
|
||||
frame.metadata.get("call_transfer_context", {}),
|
||||
)
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the output transport."""
|
||||
await super().cancel(frame)
|
||||
# Call disconnect on the client when CancelFrame is encountered
|
||||
await self._client.disconnect(
|
||||
frame.metadata.get("reason", EndTaskReason.SYSTEM_CANCELLED.value),
|
||||
frame.metadata.get("call_transfer_context", {}),
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self, frame: TransportMessageFrame | TransportMessageUrgentFrame
|
||||
):
|
||||
"""Send message frame (not supported in RTP transport)."""
|
||||
# RTP path has no generic message channel; ignore.
|
||||
pass
|
||||
|
||||
async def write_audio_frame(self, frame: OutputAudioRawFrame):
|
||||
"""Write audio frame to RTP stream."""
|
||||
if self._client.is_closing:
|
||||
raise TransportClientNotConnectedException()
|
||||
|
||||
if not self._client.is_connected:
|
||||
# If not connected yet, just simulate playback delay.
|
||||
await self._write_audio_sleep()
|
||||
return
|
||||
|
||||
payload = await self._params.serializer.serialize(frame)
|
||||
if payload:
|
||||
await self._client.send(payload)
|
||||
|
||||
await self._write_audio_sleep()
|
||||
|
||||
async def _write_audio_sleep(self):
|
||||
"""Simulates real-time audio playback timing by introducing controlled delays.
|
||||
|
||||
This method implements a clock simulation to pace audio transmission at realistic
|
||||
intervals. Without this pacing, audio frames would be sent as fast as possible,
|
||||
which could overwhelm receivers or cause buffering issues.
|
||||
|
||||
The method:
|
||||
1. Calculates how long to sleep based on when the next frame should be sent
|
||||
2. Sleeps for the calculated duration (or 0 if we're already behind schedule)
|
||||
3. Updates _next_send_time for the next audio chunk
|
||||
|
||||
The _send_interval is computed as: (audio_chunk_size / sample_rate) / 2
|
||||
This creates timing that simulates how an actual audio device would output
|
||||
audio at the proper rate (e.g., every 10ms for 10ms audio chunks).
|
||||
"""
|
||||
current_time = time.monotonic()
|
||||
sleep_duration = max(0, self._next_send_time - current_time)
|
||||
await asyncio.sleep(sleep_duration)
|
||||
if sleep_duration == 0:
|
||||
self._next_send_time = time.monotonic() + self._send_interval
|
||||
else:
|
||||
self._next_send_time += self._send_interval
|
||||
|
||||
|
||||
class StasisRTPTransport(BaseTransport):
    """Transport facade that ties one Stasis RTP client to an input/output pair.

    A single ``StasisRTPClient`` is created and shared by the input and
    output transports. Client callbacks are re-emitted as the transport
    events ``on_client_connected``, ``on_client_disconnected`` and
    ``on_client_closed``.
    """

    def __init__(
        self,
        stasis_connection: StasisRTPConnection,
        params: StasisRTPTransportParams,
        input_name: Optional[str] = None,
        output_name: Optional[str] = None,
    ):
        """Build the client and both transport halves.

        Args:
            stasis_connection: Connection object handed to the RTP client.
            params: Transport parameters shared by input and output.
            input_name: Optional name for the input transport.
            output_name: Optional name for the output transport.
        """
        super().__init__(input_name=input_name, output_name=output_name)

        self._params = params

        self._client = StasisRTPClient(
            stasis_connection,
            StasisRTPCallbacks(
                on_client_connected=self._on_client_connected,
                on_client_disconnected=self._on_client_disconnected,
                on_client_closed=self._on_client_closed,
            ),
        )

        self._input = StasisRTPInputTransport(
            self, self._client, self._params, name=self._input_name
        )
        self._output = StasisRTPOutputTransport(
            self, self._client, self._params, name=self._output_name
        )

        # Events that users of the transport may attach handlers to.
        for event_name in (
            "on_client_connected",
            "on_client_disconnected",
            "on_client_closed",
        ):
            self._register_event_handler(event_name)

    def input(self) -> StasisRTPInputTransport:
        """Return the input transport created in ``__init__``."""
        return self._input

    def output(self) -> StasisRTPOutputTransport:
        """Return the output transport created in ``__init__``."""
        return self._output

    # ------------------------------------------------ event adapters ----------
    async def _on_client_connected(self, chan_id: str):
        """Forward the client's connected callback as a transport event."""
        await self._call_event_handler("on_client_connected", chan_id)

    async def _on_client_disconnected(self, chan_id: str, reason: Optional[str] = None):
        """Forward the client's disconnected callback, including its reason."""
        await self._call_event_handler("on_client_disconnected", chan_id, reason)

    async def _on_client_closed(self, chan_id: str):
        """Forward the client's closed callback as a transport event."""
        await self._call_event_handler("on_client_closed", chan_id)
|
||||
105
api/services/telephony/test_asyncari_ping.py
Normal file
105
api/services/telephony/test_asyncari_ping.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Test script to verify asyncari ping functionality."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the asyncari src to Python path for testing
|
||||
asyncari_path = Path(__file__).parent.parent.parent.parent.parent / "asyncari" / "src"
|
||||
sys.path.insert(0, str(asyncari_path))
|
||||
|
||||
import asyncari
|
||||
from loguru import logger
|
||||
|
||||
|
||||
async def test_ping():
    """Connect to ARI and exercise REST ping, WebSocket ping and keep_alive.

    Connection settings come from the ``ARI_STASIS_*`` environment
    variables, with local-Asterisk defaults. The keep_alive task runs for
    20 seconds with a 5 second interval before it is cancelled.

    Returns:
        bool: ``True`` if the whole sequence ran, ``False`` if any step
        raised (the exception is logged).
    """
    # Configure from environment or use defaults
    base_url = os.getenv("ARI_STASIS_ENDPOINT", "http://localhost:8088")
    credentials = {
        "username": os.getenv("ARI_STASIS_USER", "asterisk"),
        "password": os.getenv("ARI_STASIS_USER_PASSWORD", "asterisk"),
    }
    apps = os.getenv("ARI_STASIS_APP_NAME", "test-app")

    logger.info(f"Connecting to ARI at {base_url}")

    try:
        async with asyncari.connect(
            base_url=base_url, apps=apps, **credentials
        ) as client:
            logger.info("Connected to ARI")

            # REST API ping
            logger.info("Testing REST API ping...")
            rest_result = await client.asterisk.ping()
            logger.info(f"REST API ping successful: {rest_result}")

            # WebSocket ping (expected to work through the wrapper)
            logger.info("Testing WebSocket ping...")
            for socket in client.websockets:
                try:
                    await socket.ping()
                except AttributeError:
                    logger.error("WebSocket doesn't have ping() method")
                except Exception as e:
                    logger.error(f"WebSocket ping failed: {e}")
                else:
                    logger.info("WebSocket ping() called successfully (no-op)")

            # keep_alive helper, imported late so the sys.path tweak at module
            # import time has already run.
            from ari_client_manager import keep_alive

            logger.info("Starting keep_alive task...")
            pinger = asyncio.create_task(keep_alive(client, interval=5.0))

            # Long enough to observe several pings.
            await asyncio.sleep(20)

            pinger.cancel()
            try:
                await pinger
            except asyncio.CancelledError:
                logger.info("keep_alive task cancelled")

            logger.info("Test completed successfully!")

    except Exception as e:
        logger.exception(f"Test failed: {e}")
        return False

    return True
|
||||
|
||||
|
||||
async def test_with_manager():
    """Start the ARI client supervisor, let it run for 30 seconds, then close it.

    ``ENABLE_ARI_STASIS`` is forced to ``"true"`` in the process environment
    before the supervisor is created. If no supervisor is returned, an error
    is logged and nothing else happens.
    """
    from ari_client_manager import setup_ari_client_supervisor

    async def on_stasis_call(client, channel, context_vars):
        # Minimal call handler: only record that a call arrived.
        logger.info(f"Received call: {channel.id}")

    # Enable ARI Stasis for testing
    os.environ["ENABLE_ARI_STASIS"] = "true"

    supervisor = await setup_ari_client_supervisor(on_stasis_call)
    if not supervisor:
        logger.error("Failed to start supervisor")
        return

    logger.info("ARI Stasis supervisor started with ping support")

    # Run for 30 seconds
    await asyncio.sleep(30)

    await supervisor.close()
    logger.info("Supervisor closed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # "manager" as the first CLI argument selects the supervisor test; any
    # other invocation runs the direct ping test. `sys` is already imported
    # at the top of the module.
    use_manager = sys.argv[1:2] == ["manager"]
    asyncio.run(test_with_manager() if use_manager else test_ping())
|
||||
83
api/services/telephony/test_real_ping.py
Normal file
83
api/services/telephony/test_real_ping.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Test script to verify real WebSocket ping frames are being sent."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the asyncari src to Python path
|
||||
asyncari_path = Path(__file__).parent.parent.parent.parent.parent / "asyncari" / "src"
|
||||
sys.path.insert(0, str(asyncari_path))
|
||||
|
||||
import asyncari
|
||||
from loguru import logger
|
||||
|
||||
# Enable debug logging to see ping frames
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def test_real_ping():
    """Connect to ARI, inspect each WebSocket and send it a ping with a payload.

    For every WebSocket the type and (if present) the wrapped ``_websocket``
    object are logged before ``ping(b"test-ping-123")`` is attempted. The
    ``keep_alive`` helper is then run for 10 seconds at a 3 second interval.
    Any exception is logged rather than raised.
    """
    # Configure from environment or use defaults
    base_url = os.getenv("ARI_STASIS_ENDPOINT", "http://localhost:8088")
    credentials = {
        "username": os.getenv("ARI_STASIS_USER", "asterisk"),
        "password": os.getenv("ARI_STASIS_USER_PASSWORD", "asterisk"),
    }
    apps = os.getenv("ARI_STASIS_APP_NAME", "test-app")

    logger.info(f"Connecting to ARI at {base_url}")

    try:
        async with asyncari.connect(
            base_url=base_url, apps=apps, **credentials
        ) as client:
            logger.info("Connected to ARI")

            for socket in client.websockets:
                socket_type = type(socket)
                wrapper_active = "WebSocketWrapper" in str(socket_type)
                logger.info(f"WebSocket type: {socket_type}")
                logger.info(f"WebSocket wrapper active: {wrapper_active}")

                # Look one level down when the object wraps another socket.
                if hasattr(socket, "_websocket"):
                    wrapped = socket._websocket
                    logger.info(f"Inner WebSocket type: {type(wrapped)}")
                    logger.info(f"Has _connection: {hasattr(wrapped, '_connection')}")
                    logger.info(f"Has _sock: {hasattr(wrapped, '_sock')}")

                logger.info("Sending test ping...")
                try:
                    await socket.ping(b"test-ping-123")
                except Exception as e:
                    logger.error(f"Ping failed: {e}")
                else:
                    logger.info("Ping sent successfully!")

            logger.info("\nTesting keep_alive function...")
            from ari_client_manager import keep_alive

            pinger = asyncio.create_task(keep_alive(client, interval=3.0))

            # Long enough to observe multiple pings.
            await asyncio.sleep(10)

            pinger.cancel()
            try:
                await pinger
            except asyncio.CancelledError:
                pass

            logger.info("Test completed!")

    except Exception as e:
        logger.exception(f"Test failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the ping check once on a fresh event loop.
    asyncio.run(test_real_ping())
|
||||
207
api/services/telephony/twilio.py
Normal file
207
api/services/telephony/twilio.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
import random
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from twilio.request_validator import RequestValidator
|
||||
|
||||
from api.constants import (
|
||||
BACKEND_API_ENDPOINT,
|
||||
TWILIO_ACCOUNT_SID,
|
||||
TWILIO_AUTH_TOKEN,
|
||||
TWILIO_DEFAULT_FROM_NUMBER,
|
||||
)
|
||||
from api.db import db_client
|
||||
|
||||
|
||||
class TwilioService:
    """Async helper around the Twilio REST API (2010-04-01 ``Calls`` resource).

    Credentials and the default caller number are read from ``api.constants``
    when the service is constructed. HTTP calls use ``aiohttp`` with basic
    auth; webhook signatures are checked with Twilio's ``RequestValidator``.
    """

    def __init__(self):
        """Capture Twilio settings and derive the account base URL.

        Raises:
            ValueError: If ``TWILIO_DEFAULT_FROM_NUMBER``,
                ``TWILIO_ACCOUNT_SID`` or ``TWILIO_AUTH_TOKEN`` is unset or
                empty.
        """
        if not (
            TWILIO_DEFAULT_FROM_NUMBER and TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN
        ):
            # pydantic's ValidationError cannot be constructed from a plain
            # message (doing so raises TypeError), so a ValueError carries
            # the configuration error instead.
            raise ValueError(
                "Please set TWILIO_DEFAULT_FROM_NUMBER, TWILIO_ACCOUNT_SID, and "
                "TWILIO_AUTH_TOKEN environment variables to use TwilioService"
            )

        self.account_sid = TWILIO_ACCOUNT_SID
        self.auth_token = TWILIO_AUTH_TOKEN
        self.default_from_number = TWILIO_DEFAULT_FROM_NUMBER

        self.base_url = f"https://api.twilio.com/2010-04-01/Accounts/{self.account_sid}"

    @staticmethod
    def _require_backend_endpoint() -> str:
        """Return ``BACKEND_API_ENDPOINT`` or raise ``ValueError`` if it is unset."""
        if not BACKEND_API_ENDPOINT:
            raise ValueError(
                "Please set BACKEND_API_ENDPOINT environment variable to a tunnel or persistent URL"
            )
        return BACKEND_API_ENDPOINT

    async def get_organization_phone_numbers(self, organization_id: int) -> List[str]:
        """Return the Twilio caller numbers configured for an organization.

        The configuration value is expected to be a mapping whose ``"value"``
        entry is a non-empty list of numbers. Lookup errors are logged as
        warnings and treated the same as "nothing configured".

        Args:
            organization_id: The organization ID.

        Returns:
            The configured list, or ``[self.default_from_number]`` when
            nothing usable is configured.
        """
        try:
            # Imported here, as in the original code, rather than at module level.
            from api.enums import OrganizationConfigurationKey

            config = await db_client.get_configuration(
                organization_id,
                OrganizationConfigurationKey.TWILIO_PHONE_NUMBERS.value,
            )

            if config and config.value:
                phone_numbers = config.value.get("value", [])
                if isinstance(phone_numbers, list) and phone_numbers:
                    return phone_numbers
        except Exception as e:
            # Best effort: fall through to the default number below.
            logger.warning(
                f"Error getting phone numbers for org {organization_id}: {e}"
            )

        # Fall back to default from environment
        return [self.default_from_number]

    async def initiate_call(
        self,
        to_number: str,
        url_args: Optional[Dict[str, Any]] = None,
        workflow_run_id: Optional[int] = None,
        organization_id: Optional[int] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Start an outbound call through the Twilio Calls API.

        Args:
            to_number: The destination phone number.
            url_args: Query parameters appended to the TwiML URL. ``None``
                (the default) or an empty mapping adds no query string.
            workflow_run_id: When truthy, status callbacks for this run are
                requested from Twilio.
            organization_id: When truthy, the caller number is picked at
                random from the organization's configured numbers;
                otherwise the default number is used.
            **kwargs: Extra form fields for the Twilio API. They are merged
                last and therefore override the fields built here.

        Returns:
            The decoded JSON body of Twilio's response.

        Raises:
            ValueError: If ``BACKEND_API_ENDPOINT`` is not set.
            Exception: If Twilio answers with any status other than 201.
        """
        backend = self._require_backend_endpoint()
        endpoint = f"{self.base_url}/Calls.json"

        # TwiML URL, with optional query string.
        url: str = f"https://{backend}/api/v1/twilio/twiml"
        if url_args:
            url = f"{url}?{urlencode(url_args)}"

        logger.debug(f"Initiating call with URL: {url}")

        # Pick the caller number.
        if organization_id:
            phone_numbers = await self.get_organization_phone_numbers(organization_id)
            from_number = random.choice(phone_numbers)
            logger.info(
                f"Selected phone number {from_number} from {len(phone_numbers)} "
                f"available numbers for org {organization_id}"
            )
        else:
            from_number = self.default_from_number

        data: Dict[str, Any] = {"To": to_number, "From": from_number, "Url": url}

        # Status callbacks are only requested when there is a run to attach
        # them to.
        if workflow_run_id:
            data.update(
                {
                    "StatusCallback": (
                        f"https://{backend}/api/v1/twilio/status-callback/{workflow_run_id}"
                    ),
                    "StatusCallbackEvent": [
                        "initiated",
                        "ringing",
                        "answered",
                        "completed",
                    ],
                    "StatusCallbackMethod": "POST",
                }
            )

        # Caller-supplied fields win over the defaults above.
        data.update(kwargs)

        async with aiohttp.ClientSession() as session:
            auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
            async with session.post(endpoint, data=data, auth=auth) as response:
                if response.status != 201:
                    error_data = await response.json()
                    raise Exception(f"Failed to initiate call: {error_data}")

                return await response.json()

    async def get_start_call_twiml(
        self, workflow_id: int, user_id: int, workflow_run_id: int
    ) -> str:
        """Build the TwiML that connects the call to the media-stream WebSocket.

        Args:
            workflow_id: Workflow ID placed in the stream URL.
            user_id: User ID placed in the stream URL.
            workflow_run_id: Workflow run ID placed in the stream URL.

        Returns:
            A TwiML document with a ``<Connect><Stream>`` pointing at the
            backend WebSocket, followed by a 40 second ``<Pause>``.

        Raises:
            ValueError: If ``BACKEND_API_ENDPOINT`` is not set.
        """
        backend = self._require_backend_endpoint()

        twiml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
<Response>
<Connect>
<Stream url="wss://{backend}/api/v1/twilio/ws/{workflow_id}/{user_id}/{workflow_run_id}"></Stream>
</Connect>
<Pause length="40"/>
</Response>"""
        return twiml_content

    async def get_call(self, call_sid: str) -> Dict[str, Any]:
        """Fetch one call record from Twilio.

        Args:
            call_sid: The SID of the call to retrieve.

        Returns:
            The decoded JSON body describing the call.

        Raises:
            Exception: If Twilio answers with any status other than 200.
        """
        endpoint = f"{self.base_url}/Calls/{call_sid}.json"

        async with aiohttp.ClientSession() as session:
            auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
            async with session.get(endpoint, auth=auth) as response:
                if response.status != 200:
                    error_data = await response.json()
                    raise Exception(f"Failed to get call: {error_data}")

                return await response.json()

    def verify_signature(
        self, url: str, params: Dict[str, Any], signature: str
    ) -> bool:
        """Check a webhook's ``X-Twilio-Signature`` with Twilio's SDK validator.

        Args:
            url: The full URL of the webhook.
            params: The POST parameters (form data) as a dictionary.
            signature: The ``X-Twilio-Signature`` header value.

        Returns:
            ``True`` if the signature is valid, ``False`` otherwise.
        """
        return RequestValidator(self.auth_token).validate(url, params, signature)
|
||||
371
api/services/telephony/worker_event_subscriber.py
Normal file
371
api/services/telephony/worker_event_subscriber.py
Normal file
|
|
@ -0,0 +1,371 @@
|
|||
"""Worker Event Subscriber for distributed ARI architecture.
|
||||
|
||||
This component runs in each FastAPI worker process and subscribes to
|
||||
Redis events from the ARI Manager. It creates pipelines for assigned calls
|
||||
without any direct ARI connection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
from api.routes.stasis_rtp import on_stasis_call
|
||||
from api.services.telephony.stasis_event_protocol import (
|
||||
DisconnectCommand,
|
||||
RedisChannels,
|
||||
RedisKeys,
|
||||
StasisEndEvent,
|
||||
StasisStartEvent,
|
||||
parse_event,
|
||||
)
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
|
||||
|
||||
class WorkerEventSubscriber:
|
||||
"""Subscribes to ARI events from Redis and processes them in the worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: aioredis.Redis,
|
||||
on_stasis_call: Callable[[StasisRTPConnection, dict], Awaitable[None]],
|
||||
):
|
||||
self.redis = redis_client
|
||||
self.worker_id = str(uuid.uuid4()) # Generate unique worker ID
|
||||
self.on_stasis_call = on_stasis_call
|
||||
self._running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._active_connections: dict[str, StasisRTPConnection] = {}
|
||||
self._active_tasks: dict[str, asyncio.Task] = {}
|
||||
self._cleanup_tasks: dict[str, asyncio.Task] = {}
|
||||
self._shutting_down = False
|
||||
self._shutdown_event = asyncio.Event()
|
||||
|
||||
async def start(self):
|
||||
"""Start the event subscriber."""
|
||||
if self._task is None:
|
||||
self._running = True
|
||||
|
||||
# Register worker in Redis
|
||||
await self._register_worker()
|
||||
|
||||
# Start main event loop
|
||||
self._task = asyncio.create_task(
|
||||
self._run(), name=f"worker_subscriber_{self.worker_id}"
|
||||
)
|
||||
|
||||
# Start heartbeat task
|
||||
self._heartbeat_task = asyncio.create_task(
|
||||
self._heartbeat_loop(), name=f"worker_heartbeat_{self.worker_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Worker {self.worker_id} event subscriber started")
|
||||
|
||||
async def _register_worker(self):
|
||||
"""Register this worker in Redis."""
|
||||
worker_key = RedisKeys.worker_active(self.worker_id)
|
||||
worker_data = json.dumps({"status": "ready", "active_calls": 0})
|
||||
|
||||
# Set with TTL of 30 seconds (will be refreshed by heartbeat)
|
||||
await self.redis.setex(worker_key, 30, worker_data)
|
||||
|
||||
# Add to workers set
|
||||
await self.redis.sadd(RedisKeys.workers_set(), self.worker_id)
|
||||
|
||||
logger.info(f"Worker {self.worker_id} registered in Redis")
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""Send periodic heartbeats to Redis."""
|
||||
try:
|
||||
while self._running:
|
||||
# Update worker status with current active call count
|
||||
worker_key = RedisKeys.worker_active(self.worker_id)
|
||||
worker_data = json.dumps(
|
||||
{
|
||||
"status": "draining" if self._shutting_down else "ready",
|
||||
"active_calls": len(self._active_tasks),
|
||||
}
|
||||
)
|
||||
|
||||
# Refresh TTL to 30 seconds
|
||||
await self.redis.setex(worker_key, 30, worker_data)
|
||||
|
||||
# Wait 10 seconds before next heartbeat
|
||||
await asyncio.sleep(10)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"Worker {self.worker_id} heartbeat cancelled")
|
||||
except Exception as e:
|
||||
logger.exception(f"Worker {self.worker_id} heartbeat error: {e}")
|
||||
|
||||
async def graceful_shutdown(self, max_wait_seconds: int = 300):
|
||||
"""Gracefully shutdown the worker, waiting for calls to complete.
|
||||
|
||||
Args:
|
||||
max_wait_seconds: Maximum time to wait for calls to complete (default 5 minutes)
|
||||
"""
|
||||
logger.info(f"Worker {self.worker_id} starting graceful shutdown")
|
||||
|
||||
# Mark as shutting down to prevent new calls
|
||||
self._shutting_down = True
|
||||
|
||||
# Update status in Redis to 'draining'
|
||||
worker_key = RedisKeys.worker_active(self.worker_id)
|
||||
worker_data = json.dumps(
|
||||
{"status": "draining", "active_calls": len(self._active_tasks)}
|
||||
)
|
||||
await self.redis.setex(worker_key, 30, worker_data)
|
||||
|
||||
# Wait for active tasks to complete (with timeout)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while (
|
||||
self._active_tasks
|
||||
and (asyncio.get_event_loop().time() - start_time) < max_wait_seconds
|
||||
):
|
||||
active_count = len(self._active_tasks)
|
||||
logger.info(
|
||||
f"Worker {self.worker_id} waiting for {active_count} active calls to complete"
|
||||
)
|
||||
|
||||
# Update Redis with current status
|
||||
worker_data = json.dumps(
|
||||
{"status": "draining", "active_calls": active_count}
|
||||
)
|
||||
await self.redis.setex(worker_key, 30, worker_data)
|
||||
|
||||
# Wait a bit before checking again
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Force stop if timeout reached
|
||||
if self._active_tasks:
|
||||
logger.warning(
|
||||
f"Worker {self.worker_id} forcefully stopping {len(self._active_tasks)} active calls after timeout channel_ids: {list(self._active_tasks.keys())}"
|
||||
)
|
||||
|
||||
await self.stop()
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the event subscriber and deregister from Redis."""
|
||||
self._running = False
|
||||
|
||||
# Deregister from Redis
|
||||
await self._deregister_worker()
|
||||
|
||||
# Cancel all active call processing tasks
|
||||
for channel_id, task in list(self._active_tasks.items()):
|
||||
if not task.done():
|
||||
logger.info(f"Cancelling active call task for channel {channel_id}")
|
||||
task.cancel()
|
||||
|
||||
# Cancel all cleanup tasks
|
||||
for channel_id, task in list(self._cleanup_tasks.items()):
|
||||
if not task.done():
|
||||
logger.info(f"Cancelling cleanup task for channel {channel_id}")
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to complete
|
||||
all_tasks = list(self._active_tasks.values()) + list(
|
||||
self._cleanup_tasks.values()
|
||||
)
|
||||
if all_tasks:
|
||||
await asyncio.gather(*all_tasks, return_exceptions=True)
|
||||
|
||||
# Cancel heartbeat task
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
try:
|
||||
await self._heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info(f"Worker {self.worker_id} event subscriber stopped")
|
||||
|
||||
async def _deregister_worker(self):
|
||||
"""Remove this worker from Redis."""
|
||||
try:
|
||||
# Remove from active workers
|
||||
await self.redis.delete(RedisKeys.worker_active(self.worker_id))
|
||||
|
||||
# Remove from workers set
|
||||
await self.redis.srem(RedisKeys.workers_set(), self.worker_id)
|
||||
|
||||
logger.info(f"Worker {self.worker_id} deregistered from Redis")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deregistering worker {self.worker_id}: {e}")
|
||||
|
||||
async def _run(self):
|
||||
"""Main subscriber loop."""
|
||||
self._running = True
|
||||
channel = RedisChannels.worker_events(self.worker_id)
|
||||
pubsub = self.redis.pubsub()
|
||||
|
||||
try:
|
||||
await pubsub.subscribe(channel)
|
||||
logger.info(f"Worker {self.worker_id} subscribed to {channel}")
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
if message["type"] == "message":
|
||||
try:
|
||||
await self._handle_event(message["data"])
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling event: {e}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"Worker {self.worker_id} subscriber cancelled")
|
||||
except Exception as e:
|
||||
logger.exception(f"Worker {self.worker_id} subscriber error: {e}")
|
||||
finally:
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.aclose()
|
||||
|
||||
async def _handle_event(self, data: str):
|
||||
"""Handle an event from the ARI Manager."""
|
||||
event = parse_event(data)
|
||||
if not event:
|
||||
logger.warning(f"Failed to parse event: {data}")
|
||||
return
|
||||
|
||||
if isinstance(event, StasisStartEvent):
|
||||
await self._handle_stasis_start(event)
|
||||
elif isinstance(event, StasisEndEvent):
|
||||
await self._handle_stasis_end(event)
|
||||
else:
|
||||
logger.warning(
|
||||
f"channelID: {event.channel_id} Unhandled event type: {type(event)}"
|
||||
)
|
||||
|
||||
async def _handle_stasis_start(self, event: StasisStartEvent):
|
||||
"""Handle a new call assignment."""
|
||||
|
||||
channel_id = event.channel_id
|
||||
logger.info(
|
||||
f"channelID: {channel_id} Worker {self.worker_id} handling StasisStart"
|
||||
)
|
||||
|
||||
try:
|
||||
# Create StasisRTPConnection without ARI client
|
||||
connection = StasisRTPConnection(
|
||||
redis_client=self.redis,
|
||||
channel_id=channel_id,
|
||||
caller_channel_id=event.caller_channel_id,
|
||||
em_channel_id=event.em_channel_id,
|
||||
bridge_id=event.bridge_id,
|
||||
local_addr=tuple(event.local_addr) if event.local_addr else None,
|
||||
remote_addr=tuple(event.remote_addr) if event.remote_addr else None,
|
||||
)
|
||||
|
||||
# Store connection for cleanup
|
||||
self._active_connections[channel_id] = connection
|
||||
|
||||
# Create a background task to handle the call
|
||||
task = asyncio.create_task(
|
||||
self._process_call(connection, event.call_context_vars, channel_id),
|
||||
name=f"call_handler_{channel_id}",
|
||||
)
|
||||
self._active_tasks[channel_id] = task
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling StasisStart for {channel_id}: {e}")
|
||||
# Send disconnect command if setup fails
|
||||
await self._send_disconnect(channel_id, "setup_failed")
|
||||
|
||||
async def _process_call(
|
||||
self, connection: StasisRTPConnection, call_context_vars: dict, channel_id: str
|
||||
):
|
||||
"""Process a call in the background."""
|
||||
try:
|
||||
await self.on_stasis_call(connection, call_context_vars)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing call for {channel_id}: {e}")
|
||||
# Send disconnect command if call processing fails
|
||||
await self._send_disconnect(channel_id, "processing_failed")
|
||||
finally:
|
||||
# Clean up task reference
|
||||
if channel_id in self._active_tasks:
|
||||
del self._active_tasks[channel_id]
|
||||
|
||||
async def _process_cleanup(self, channel_id: str, reason: str):
|
||||
"""Process call cleanup in the background."""
|
||||
try:
|
||||
if channel_id in self._active_connections:
|
||||
connection: StasisRTPConnection = self._active_connections[channel_id]
|
||||
|
||||
# We must wait for the connection's invocation
|
||||
# before sending in remote disconnect. Otherwise,
|
||||
# the event handlers won't be registered and we won't
|
||||
# be able to call on_client_disconnected to cancel the
|
||||
# pipeline
|
||||
while not connection._connect_invoked:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Set the run_id context so that we can have it in logs
|
||||
if connection.workflow_run_id:
|
||||
set_current_run_id(connection.workflow_run_id)
|
||||
|
||||
await connection.handle_remote_disconnect(reason)
|
||||
del self._active_connections[channel_id]
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during cleanup for {channel_id}: {e}")
|
||||
finally:
|
||||
# Clean up task reference from cleanup tasks dictionary
|
||||
if channel_id in self._cleanup_tasks:
|
||||
del self._cleanup_tasks[channel_id]
|
||||
|
||||
async def _handle_stasis_end(self, event: StasisEndEvent):
|
||||
"""Handle call termination."""
|
||||
channel_id = event.channel_id
|
||||
logger.info(
|
||||
f"channelID: {channel_id} Worker {self.worker_id} handling StasisEnd, Reason: {event.reason}"
|
||||
)
|
||||
|
||||
# Create a background task to handle the cleanup
|
||||
if channel_id in self._active_connections:
|
||||
# Check if there's already a cleanup task for this channel
|
||||
if (
|
||||
channel_id not in self._cleanup_tasks
|
||||
or self._cleanup_tasks[channel_id].done()
|
||||
):
|
||||
# Lets start a new task, since we need to poll for
|
||||
# connection to be invoked from the pipeline before
|
||||
# caling remote disconnect
|
||||
task = asyncio.create_task(
|
||||
self._process_cleanup(channel_id, event.reason),
|
||||
name=f"cleanup_handler_{channel_id}",
|
||||
)
|
||||
self._cleanup_tasks[channel_id] = task
|
||||
else:
|
||||
logger.warning(
|
||||
f"channelID: {channel_id} Cleanup skipped - cleanup task still running"
|
||||
)
|
||||
|
||||
async def _send_disconnect(self, channel_id: str, reason: str):
    """Publish a disconnect command for ``channel_id`` to the ARI Manager.

    Args:
        channel_id: Channel the command is addressed to.
        reason: Reason carried inside the command.
    """
    disconnect = DisconnectCommand(channel_id=channel_id, reason=reason)
    # The command is published on the per-channel command channel in Redis.
    target_channel = RedisChannels.channel_commands(channel_id)
    payload = disconnect.to_json()
    await self.redis.publish(target_channel, payload)
|
||||
|
||||
|
||||
async def setup_worker_subscriber(
    redis_client: aioredis.Redis,
) -> WorkerEventSubscriber:
    """Create, start and return the worker event subscriber.

    Args:
        redis_client: Redis client handed to the subscriber.

    Returns:
        The subscriber, after its ``start()`` coroutine has completed.
    """
    worker_subscriber = WorkerEventSubscriber(redis_client, on_stasis_call)
    logger.info(
        f"Setting up worker event subscriber with ID {worker_subscriber.worker_id}"
    )
    await worker_subscriber.start()
    return worker_subscriber
|
||||
0
api/services/workflow/__init__.py
Normal file
0
api/services/workflow/__init__.py
Normal file
77
api/services/workflow/disposition_mapper.py
Normal file
77
api/services/workflow/disposition_mapper.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Utility module for applying disposition code mapping."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
|
||||
|
||||
async def apply_disposition_mapping(value: str, organization_id: Optional[int]) -> str:
    """Translate a disposition value through the organization's mapping.

    The mapping is read from the organization's DISPOSITION_CODE_MAPPING
    configuration, e.g. ``{"user_idle_max_duration_exceeded": "DAIR"}``.

    Args:
        value: The original disposition value to map.
        organization_id: The organization whose mapping should be used.

    Returns:
        The mapped value when the mapping contains ``value``; otherwise
        ``value`` unchanged. ``value`` is also returned unchanged when either
        argument is falsy, when no mapping is configured, or when the lookup
        raises (the error is logged, not propagated).
    """
    if not (organization_id and value):
        return value

    try:
        mapping = await db_client.get_configuration_value(
            organization_id,
            OrganizationConfigurationKey.DISPOSITION_CODE_MAPPING.value,
            default={},
        )

        if mapping:
            # Codes absent from the mapping pass through untouched.
            translated = mapping.get(value, value)
            if translated != value:
                logger.debug(
                    f"Mapped disposition code from '{value}' to '{translated}' "
                    f"for organization {organization_id}"
                )
            return translated

        return value
    except Exception as exc:
        # Best effort: a failed lookup must not break call teardown.
        logger.error(f"Error applying disposition mapping: {exc}")
        return value
|
||||
|
||||
|
||||
async def get_organization_id_from_workflow_run(
    workflow_run_id: Optional[int],
) -> Optional[int]:
    """Resolve the organization that owns a workflow run.

    Follows workflow run -> workflow -> user and returns that user's
    ``selected_organization_id``.

    Args:
        workflow_run_id: The workflow run ID.

    Returns:
        The organization ID, or ``None`` when the ID is falsy, when any link
        in the chain is missing, or when the lookup raises (the error is
        logged, not propagated).
    """
    if not workflow_run_id:
        return None

    try:
        run = await db_client.get_workflow_run_by_id(workflow_run_id)
        if not run or not run.workflow:
            return None

        owner = run.workflow.user
        if not owner:
            return None

        return owner.selected_organization_id
    except Exception as exc:
        logger.error(f"Error getting organization_id from workflow_run: {exc}")
        return None
|
||||
96
api/services/workflow/dto.py
Normal file
96
api/services/workflow/dto.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||||
|
||||
|
||||
class NodeType(str, Enum):
    # Node kinds accepted in the ``type`` field of a workflow node.
    # Note that the start/end members serialize as "startCall"/"endCall",
    # i.e. the value differs from the member name.
    startNode = "startCall"
    endNode = "endCall"
    agentNode = "agentNode"
    globalNode = "globalNode"
|
||||
|
||||
|
||||
class Position(BaseModel):
    # Coordinates of a node (used as ``RFNodeDTO.position``).
    x: float
    y: float
|
||||
|
||||
|
||||
class VariableType(str, Enum):
    # Value types an extraction variable may declare.
    string = "string"
    number = "number"
    boolean = "boolean"
|
||||
|
||||
|
||||
class ExtractionVariableDTO(BaseModel):
    # A single variable to extract: a required, non-empty name, its declared
    # type, and an optional per-variable prompt.
    name: str = Field(..., min_length=1)
    type: VariableType
    prompt: Optional[str] = None
|
||||
|
||||
|
||||
class NodeDataDTO(BaseModel):
    # Payload carried in a workflow node's ``data`` field.

    # Required, non-empty.
    name: str = Field(..., min_length=1)
    prompt: str = Field(..., min_length=1)

    # Node behaviour flags.
    is_static: bool = False
    is_start: bool = False
    is_end: bool = False
    allow_interrupt: bool = False

    # Variable extraction settings.
    extraction_enabled: bool = False
    extraction_prompt: Optional[str] = None
    extraction_variables: Optional[list[ExtractionVariableDTO]] = None

    add_global_prompt: bool = True

    # Waiting for the user; the timeout is optional.
    # NOTE(review): timeout unit is not visible here - presumably seconds.
    wait_for_user_response: bool = False
    wait_for_user_response_timeout: Optional[float] = None

    # On by default, unlike most of the other flags.
    detect_voicemail: bool = True

    # Delayed start; the duration is optional.
    delayed_start: bool = False
    delayed_start_duration: Optional[float] = None
|
||||
|
||||
|
||||
class RFNodeDTO(BaseModel):
    # One node of the workflow graph. ``type`` falls back to an agent node
    # when the payload does not specify it.
    id: str
    type: NodeType = Field(default=NodeType.agentNode)
    position: Position
    data: NodeDataDTO
|
||||
|
||||
|
||||
class EdgeDataDTO(BaseModel):
    # Payload carried in an edge's ``data`` field; both strings are required
    # and must be non-empty.
    label: str = Field(..., min_length=1)
    condition: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class RFEdgeDTO(BaseModel):
    # One directed edge of the workflow graph. ``source`` and ``target`` hold
    # node ids; ReactFlowDTO checks that they refer to existing nodes.
    id: str
    source: str
    target: str
    data: EdgeDataDTO
|
||||
|
||||
|
||||
class ReactFlowDTO(BaseModel):
    # Whole workflow graph: its nodes plus the edges connecting them.
    nodes: List[RFNodeDTO]
    edges: List[RFEdgeDTO]

    @model_validator(mode="after")
    def _referential_integrity(self):
        # Every edge endpoint must be the id of a node in ``nodes``. One
        # error entry is collected per dangling endpoint, so an edge whose
        # source and target are both unknown contributes two entries.
        known_ids = {node.id for node in self.nodes}

        line_errors: list[dict] = [
            dict(
                loc=("edges", position),
                type="missing_node",
                msg="Edge references missing node",
                input=edge.model_dump(mode="python"),
                ctx={"edge_id": edge.id, "endpoint": endpoint},
            )
            for position, edge in enumerate(self.edges)
            for endpoint in (edge.source, edge.target)
            if endpoint not in known_ids
        ]

        if not line_errors:
            return self

        # NOTE(review): from_exception_data appears to accept only built-in
        # pydantic-core error type names or a PydanticCustomError for
        # ``type`` - confirm that the plain string "missing_node" is accepted
        # rather than raising before the ValidationError is built.
        raise ValidationError.from_exception_data(
            title="ReactFlowDTO validation failed",
            line_errors=line_errors,
        )
|
||||
16
api/services/workflow/errors.py
Normal file
16
api/services/workflow/errors.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# api/services/workflow/errors.py
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class ItemKind(str, Enum):
    # What a WorkflowError refers to: a single node, a single edge, or the
    # workflow as a whole.
    node = "node"
    edge = "edge"
    workflow = "workflow"
|
||||
|
||||
|
||||
class WorkflowError(TypedDict):
    # Shape of one workflow validation error.
    kind: ItemKind  # "node" | "edge" | "workflow"
    id: str | None  # nodeId or edgeId
    field: str | None  # e.g. "data.prompt", "position.x" (optional)
    message: str  # human-readable text
|
||||
939
api/services/workflow/pipecat_engine.py
Normal file
939
api/services/workflow/pipecat_engine.py
Normal file
|
|
@ -0,0 +1,939 @@
|
|||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.constants import VOICEMAIL_RECORDING_DURATION
|
||||
from api.services.gender.gender_service import GenderService
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_voicemail_detector import (
|
||||
VoicemailDetector,
|
||||
)
|
||||
from api.services.workflow.workflow import Node, WorkflowGraph
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
|
||||
LLMService = Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService]
|
||||
|
||||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
||||
from pipecat.utils.tracing.context_registry import get_current_turn_context
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
render_template,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.timezone import (
|
||||
convert_time,
|
||||
get_current_time,
|
||||
get_time_tools,
|
||||
)
|
||||
|
||||
|
||||
class PipecatEngine:
|
||||
def __init__(
    self,
    *,
    task: Optional[PipelineTask] = None,
    llm: Optional["LLMService"] = None,
    context: Optional[OpenAILLMContext] = None,
    tts: Optional[Any] = None,
    transport: Optional[BaseTransport] = None,
    workflow: WorkflowGraph,
    call_context_vars: dict,
    audio_buffer: Optional["AudioBuffer"] = None,
    workflow_run_id: Optional[int] = None,
):
    """Store the pipeline collaborators and reset all per-call state.

    All arguments are keyword-only. No node is entered here; that happens
    in ``initialize()``.

    Args:
        task: Pipeline task that frames are queued on.
        llm: LLM service that functions are registered with.
        context: LLM context that is updated on node changes.
        tts: TTS service (stored as-is).
        transport: Transport (stored as-is).
        workflow: The workflow graph to execute.
        call_context_vars: Variables used when rendering prompts.
        audio_buffer: Audio buffer used for voicemail detection.
        workflow_run_id: ID of the workflow run this engine serves.
    """
    # --- Pipeline collaborators -------------------------------------
    self.task = task
    self.llm = llm
    self.context = context
    self.tts = tts
    self.transport = transport
    self.workflow = workflow

    # --- Call inputs -------------------------------------------------
    self._call_context_vars = call_context_vars
    self._audio_buffer = audio_buffer
    self._workflow_run_id = workflow_run_id

    # --- Runtime state ----------------------------------------------
    self._initialized = False
    self._pending_function_calls = 0
    self._current_node: Optional[Node] = None
    self._gathered_context: dict = {}
    self._user_response_timeout_task: Optional[asyncio.Task] = None
    self._call_disposition: Optional[str] = None

    # Stasis connection, used for immediate transfers.
    self._stasis_connection: Optional["StasisRTPConnection"] = None

    # Created in initialize(), once the engine itself can be handed to it.
    self._variable_extraction_manager = None

    self._gender_service = GenderService(confidence_threshold=0.5)

    # --- Voicemail detection state ----------------------------------
    self._detect_voicemail = False
    self._voicemail_detector = None
    self._voicemail_detection_task: Optional[asyncio.Task] = None

    # Transition produced by the LLM as a tool call. It may come together
    # with text that is spoken via TTS; if the bot is interrupted the
    # transition is cancelled (currently when the next generation starts,
    # in the handle_generation_started callback handler).
    self._pending_generated_transition_after_context_push: Optional[
        Callable[[], Awaitable[None]]
    ] = None

    # Programmatic transition that does not go through an LLM tool call.
    # It is not interrupted by the user and runs on context push.
    self._pending_control_transition_after_context_push: Optional[
        Callable[[], Awaitable[None]]
    ] = None

    # True while the current LLM generation contains a text completion.
    self._defer_context_push: bool = False

    # Built-in function schemas, built lazily on first access.
    self._builtin_function_schemas: Optional[list[dict]] = None

    # Whether the next LLM context update should also queue a context frame.
    self._queue_context_frame: bool = True

    # Current LLM reference text, used for TTS aggregation correction.
    self._current_llm_reference_text: str = ""
|
||||
|
||||
@property
def builtin_function_schemas(self) -> list[dict]:
    """Return the schemas of the built-in calculator and timezone tools.

    The list is built on first access and cached on the instance. It is
    assembled in a local variable and published only once complete, so a
    failure while converting a tool never leaves a partially filled list
    cached for later callers.

    Returns:
        Calculator tool schemas followed by timezone tool schemas, each
        produced by ``get_function_schema``.
    """
    if self._builtin_function_schemas is None:
        schemas: list[dict] = []

        # Calculator tools first, then timezone tools; both providers use
        # the same {"function": {...}} layout.
        for provide_tools in (get_calculator_tools, get_time_tools):
            for tool in provide_tools():
                func = tool["function"]
                schemas.append(
                    get_function_schema(
                        func["name"],
                        func["description"],
                        properties=func["parameters"]["properties"],
                        required=func["parameters"]["required"],
                    )
                )

        self._builtin_function_schemas = schemas

    return self._builtin_function_schemas
|
||||
|
||||
async def initialize(self):
    """Prepare the engine and enter the workflow's start node.

    Creates the variable extraction manager, records the current
    America/New_York time under ``time`` in the gathered context,
    registers the built-in functions, derives a salutation from
    ``first_name`` when that call variable exists, and finally enters the
    start node. A second call only logs a warning.

    Raises:
        Exception: Any error raised during setup is logged and re-raised.
            A failure to fetch the current time is the exception: it is
            logged and initialization continues.
    """
    # TODO: May be set_node in a separate task so that we return from initialize immediately
    if self._initialized:
        logger.warning(f"{self.__class__.__name__} already initialized")
        return

    try:
        self._initialized = True

        # Helper that encapsulates the variable extraction logic.
        self._variable_extraction_manager = VariableExtractionManager(self)

        # Best effort: the call proceeds even without the timestamp.
        try:
            eastern_now = get_current_time("America/New_York")
            # get_current_time returns a dict; its 'datetime' entry is
            # stored under the key 'time'.
            self._gathered_context["time"] = eastern_now.get("datetime")
        except Exception as exc:
            logger.error(f"Failed to fetch current EST time: {exc}")

        await self._register_builtin_functions()

        # Salutation predicted from the first name, when one is known.
        if "first_name" in self._call_context_vars:
            first_name = self._call_context_vars["first_name"]
            self._call_context_vars[
                "salutation"
            ] = await self._gender_service.get_salutation(first_name)

        await self.set_node(self.workflow.start_node_id)
        logger.debug(f"{self.__class__.__name__} initialized")
    except Exception as exc:
        logger.error(f"Error initializing {self.__class__.__name__}: {exc}")
        raise
|
||||
|
||||
def _get_function_schema(self, function_name: str, description: str):
    """Build a function schema from a name and description.

    Kept for backwards compatibility; simply forwards to the module-level
    ``get_function_schema`` helper.
    """
    schema = get_function_schema(function_name, description)
    return schema
|
||||
|
||||
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
    """Apply a system message and function list to the engine's LLM context.

    Forwards to the shared ``update_llm_context`` helper using
    ``self.context``.
    """
    llm_context = self.context
    update_llm_context(llm_context, system_message, functions)
|
||||
|
||||
def _format_prompt(self, prompt: str) -> str:
    """Render ``prompt`` as a template against the call context variables.

    Forwards to the shared ``render_template`` helper.
    """
    template_vars = self._call_context_vars
    return render_template(prompt, template_vars)
|
||||
|
||||
async def _create_transition_func(self, name: str, transition_to_node: str):
    """Build the tool handler that moves the workflow to another node.

    Args:
        name: Function name, used only in the error log message.
        transition_to_node: ID of the node to enter once the function call
            result has been written to the context.

    Returns:
        An async callable taking ``FunctionCallParams``.
    """

    async def transition_func(function_call_params: FunctionCallParams) -> None:
        """Handle one invocation of the transition tool."""
        try:
            # Track the pending function call.
            self._pending_function_calls += 1
            logger.debug(
                f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
            )

            # For edge functions, prevent LLM completion until transition (run_llm=False)
            # For node functions, allow immediate completion (run_llm=True)
            async def on_context_updated() -> None:
                """Switch nodes once the call result is in the context.

                The framework runs this after the function call result has
                been written to the context, so the completion started by
                set_node with the new system prompt already sees it.
                """
                self._pending_function_calls -= 1
                # Extract variables for the node being left before moving on.
                await self._perform_variable_extraction_if_needed(
                    self._current_node
                )
                await self.set_node(transition_to_node)

            call_result = {"status": "done"}
            result_properties = FunctionCallResultProperties(
                run_llm=False,
                on_context_updated=on_context_updated,
            )

            async def _invoke_result_callback():
                """Report the result, which lets on_context_updated run.

                Functions execute as soon as the LLM emits them. When the
                same completion also contains text, the function must not
                take effect (nor be added to the context, so the LLM can
                call it again) if the user interrupts the speech. Reporting
                is therefore postponed until the bot has finished speaking
                the generated text.
                """
                await function_call_params.result_callback(
                    call_result, properties=result_properties
                )

            if not self._defer_context_push:
                # The generation held only the function call and no text:
                # report right away so the framework calls
                # on_context_updated and the node switch happens.
                await _invoke_result_callback()
            else:
                # _defer_context_push is set by the handle_llm_generated_text
                # callback handler when the current generation produced text.
                logger.debug(
                    "Deferring transition function result until context push"
                )
                # Only one deferred transition should exist at any time.
                # Overwrite if one is somehow already set (unexpected).
                self._pending_generated_transition_after_context_push = (
                    _invoke_result_callback
                )
        except Exception as exc:
            logger.error(f"Error in transition function {name}: {str(exc)}")
            self._pending_function_calls = 0
            failure = {"status": "error", "error": str(exc)}
            await function_call_params.result_callback(failure)

    return transition_func
|
||||
|
||||
async def _register_transition_function_with_llm(
    self, name: str, transition_to_node: str
):
    """Register a transition tool named ``name`` with the LLM.

    The handler is built by ``_create_transition_func`` and registered with
    ``cancel_on_interruption=True``.

    Args:
        name: Function name the LLM calls.
        transition_to_node: ID of the node the handler transitions to.
    """
    logger.debug(
        f"Registering function {name} to transition to node {transition_to_node} with LLM"
    )

    handler = await self._create_transition_func(name, transition_to_node)
    self.llm.register_function(name, handler, cancel_on_interruption=True)
|
||||
|
||||
async def _register_builtin_functions(self):
    """Register the calculator and timezone tools with the LLM.

    Each handler reports either the tool's result or ``{"error": ...}``
    through the result callback, always with ``run_llm=True``.
    """
    logger.debug("Registering built-in functions with LLM")

    # Shared by every handler below.
    run_llm_properties = FunctionCallResultProperties(run_llm=True)

    async def calculate_func(function_call_params: FunctionCallParams) -> None:
        """Evaluate the 'expression' argument with safe_calculator."""
        try:
            expression = function_call_params.arguments.get("expression", "")
            outcome = safe_calculator(expression)
            await function_call_params.result_callback(
                {"expression": expression, "result": outcome},
                properties=run_llm_properties,
            )
        except Exception as exc:
            await function_call_params.result_callback(
                {"error": str(exc)}, properties=run_llm_properties
            )

    async def get_current_time_func(
        function_call_params: FunctionCallParams,
    ) -> None:
        """Report the current time for the 'timezone' argument (default UTC)."""
        try:
            tz_name = function_call_params.arguments.get("timezone", "UTC")
            outcome = get_current_time(tz_name)
            await function_call_params.result_callback(
                outcome, properties=run_llm_properties
            )
        except Exception as exc:
            await function_call_params.result_callback(
                {"error": str(exc)}, properties=run_llm_properties
            )

    async def convert_time_func(function_call_params: FunctionCallParams) -> None:
        """Convert 'time' from 'source_timezone' to 'target_timezone'."""
        try:
            outcome = convert_time(
                function_call_params.arguments.get("source_timezone"),
                function_call_params.arguments.get("time"),
                function_call_params.arguments.get("target_timezone"),
            )
            await function_call_params.result_callback(
                outcome, properties=run_llm_properties
            )
        except Exception as exc:
            await function_call_params.result_callback(
                {"error": str(exc)}, properties=run_llm_properties
            )

    # Names must match the ones the LLM calls.
    self.llm.register_function("safe_calculator", calculate_func)
    self.llm.register_function("get_current_time", get_current_time_func)
    self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _queue_tts_response(self, text: str) -> None:
    """Queue ``text`` to be spoken as a static response.

    The speak frame is wrapped in LLM full-response start/end frames.
    """
    frames = [
        LLMFullResponseStartFrame(),
        TTSSpeakFrame(text=text),
        LLMFullResponseEndFrame(),
    ]
    await self.task.queue_frames(frames)
|
||||
|
||||
async def _setup_static_start_node_transition(self, node: Node) -> None:
    """Arm the follow-up transition for a static start node.

    Nothing is armed when the node has no outgoing edges or when it waits
    for a user response. Otherwise a control transition to the target of
    the first outgoing edge is stored, to be run after context push.
    """
    if not node.out_edges:
        return

    next_node_id = node.out_edges[0].target

    if node.wait_for_user_response:
        return

    # Normal static start node - transition immediately after context push.
    async def _deferred_static_transition():
        try:
            await self.set_node(next_node_id)
        except Exception as exc:
            logger.error(
                f"Error executing deferred static node transition to {next_node_id}: {exc}"
            )

    self._pending_control_transition_after_context_push = (
        _deferred_static_transition
    )
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
    self, previous_node: Optional[Node]
) -> None:
    """Start background variable extraction for ``previous_node`` if enabled.

    Extraction is scheduled only when the node exists, has extraction
    enabled and declares extraction variables. The extraction runs in a
    background task so the caller is not blocked; results are merged into
    the gathered context and failures are logged.

    Args:
        previous_node: The node whose extraction settings apply, or None.
    """
    if not (
        previous_node
        and previous_node.extraction_enabled
        and previous_node.extraction_variables
    ):
        return

    logger.debug(
        f"Scheduling background variable extraction for node: {previous_node.name}"
    )

    # Capture everything the task needs now, before the engine moves on
    # to another node or turn.
    parent_context = get_current_turn_context()
    extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
    extraction_variables = previous_node.extraction_variables

    async def _background_extraction():
        try:
            extracted_data = (
                await self._variable_extraction_manager._perform_extraction(
                    extraction_variables, parent_context, extraction_prompt
                )
            )
            self._gathered_context.update(extracted_data)
            logger.debug(
                f"Background variable extraction completed. Extracted: {extracted_data}"
            )
        except Exception as e:
            logger.error(
                f"Error during background variable extraction: {str(e)}"
            )

    # Extraction happens in the background without blocking. The event
    # loop keeps only weak references to tasks, so hold a strong reference
    # until the task is done; otherwise it may be garbage-collected before
    # it finishes.
    in_flight = getattr(self, "_background_extraction_tasks", None)
    if in_flight is None:
        in_flight = self._background_extraction_tasks = set()

    extraction_task = asyncio.create_task(_background_extraction())
    in_flight.add(extraction_task)
    extraction_task.add_done_callback(in_flight.discard)
|
||||
|
||||
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
    """Point the LLM context at ``node`` and, unless suppressed, start generation.

    Used for non-static nodes: registers one transition function per
    outgoing edge (skipped for end nodes), installs the node's system
    message and functions, and queues a context frame when
    ``_queue_context_frame`` is set. The flag is restored to True afterwards.
    """
    # Node name is set on the context for tracing.
    try:
        self.context.set_node_name(node.name)
    except AttributeError:
        logger.warning("context has no set_node_name method")

    # End nodes get no transition functions.
    if not node.is_end:
        for edge in node.out_edges:
            await self._register_transition_function_with_llm(
                edge.get_function_name(), edge.target
            )

    composed = await self._compose_system_message_functions_for_node(node)
    system_message, functions = composed
    await self._update_llm_context(system_message, functions)

    if not self._queue_context_frame:
        logger.debug(
            f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
        )
    else:
        await self.task.queue_frame(OpenAILLMContextFrame(self.context))

    # Suppression is one-shot: queueing is the default again next time.
    self._queue_context_frame = True
|
||||
|
||||
async def set_node(self, node_id: str):
    """Make ``node_id`` the current node and run its handler.

    Start nodes take precedence over end nodes; everything else is handled
    as a normal agent node.

    Args:
        node_id: Key of the node in ``self.workflow.nodes``.
    """
    node = self.workflow.nodes[node_id]

    logger.debug(
        f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
    )

    # Recorded for every node, static ones included, so that the STT mute
    # filter works.
    self._current_node = node

    if node.is_start:
        handler = self._handle_start_node
    elif node.is_end:
        handler = self._handle_end_node
    else:
        handler = self._handle_agent_node

    await handler(node)
|
||||
|
||||
async def _handle_start_node(self, node: Node) -> None:
    """Run a start node.

    In order: set up audio-based voicemail detection when the node asks
    for it and an audio buffer exists, wait out the optional start delay,
    then either speak the static prompt (and arm the follow-up transition)
    or start LLM generation.
    """
    # Voicemail detection is set up first, before anything is spoken.
    if node.detect_voicemail:
        if self._audio_buffer:
            logger.debug(
                "Start node has detect_voicemail enabled - setting up audio-based detector"
            )
            self._detect_voicemail = True

            self._voicemail_detector = VoicemailDetector(
                detection_duration=VOICEMAIL_RECORDING_DURATION,
                workflow_run_id=self._workflow_run_id,
            )

            # The handler is attached to the audio buffer's input processor.
            audio_input = self._audio_buffer.input()

            @audio_input.event_handler("on_input_audio_data")
            async def handle_voicemail_audio(
                processor, pcm, sample_rate, num_channels
            ):
                # Audio is forwarded only while detection is in progress.
                detector = self._voicemail_detector
                if detector and detector.is_detecting:
                    await detector.handle_audio_data(
                        processor, pcm, sample_rate, num_channels
                    )

            await self._voicemail_detector.start_detection(self)
        else:
            logger.warning(
                "Voicemail detection enabled but no audio buffer available - skipping detection"
            )

    if node.delayed_start:
        # Configured duration, or 2 seconds when none (or zero) is set.
        delay_duration = node.delayed_start_duration or 2.0
        logger.debug(
            f"Delayed start enabled - waiting {delay_duration} seconds before speaking"
        )
        await asyncio.sleep(delay_duration)

    if not node.is_static:
        # Non-static start node: let the LLM generate.
        await self._setup_llm_context_and_start_generation(node)
        return

    # Static start node: speak the rendered prompt, then arm the transition.
    spoken_text = self._format_prompt(node.prompt)
    await self._queue_tts_response(spoken_text)
    await self._setup_static_start_node_transition(node)
|
||||
|
||||
async def _handle_end_node(self, node: Node) -> None:
    """Run an end node and arm the deferred end of the task.

    Speaks the static prompt or starts LLM generation, triggers variable
    extraction for this node when it is enabled, and stores a control
    transition that sends the end-task frame after context push.
    """
    if not node.is_static:
        # Non-static end node: let the LLM generate.
        await self._setup_llm_context_and_start_generation(node)
    else:
        # Static end node: speak the rendered prompt.
        spoken_text = self._format_prompt(node.prompt)
        await self._queue_tts_response(spoken_text)

    # Extraction for the end node itself is triggered right away.
    if node.extraction_enabled and node.extraction_variables:
        await self._perform_variable_extraction_if_needed(node)

    # TODO: Extract disposition code from extracted variables

    async def _deferred_end_task():
        # call_disposition is the disposition generated by an LLM call from
        # the conversation so far; it decides the end-task reason.
        # TODO: Make this more generic based on configuration or llm prompting
        disposition = self._gathered_context.get("call_disposition")
        reason = (
            EndTaskReason.USER_QUALIFIED.value
            if disposition == "XFER"
            else EndTaskReason.USER_DISQUALIFIED.value
        )
        await self.send_end_task_frame(reason)

    # Ending the task is deferred until after context push.
    self._pending_control_transition_after_context_push = _deferred_end_task
|
||||
|
||||
async def _handle_agent_node(self, node: Node) -> None:
    """Run a normal (neither start nor end) agent node.

    A static node speaks its rendered prompt and arms the follow-up
    transition; any other node gets its LLM context set up and generation
    started.
    """
    if not node.is_static:
        # Non-static agent node: set context and functions, then generate.
        await self._setup_llm_context_and_start_generation(node)
        return

    # Static agent node: speak the prompt, then arm the deferred transition.
    spoken_text = self._format_prompt(node.prompt)
    await self._queue_tts_response(spoken_text)
    await self._setup_agent_node_transition(node)
|
||||
|
||||
async def _setup_agent_node_transition(self, node: Node) -> None:
    """Arm the follow-up transition for a static agent node.

    When the node has outgoing edges, a control transition to the target
    of the first one is stored, to be run after context push. A node
    without outgoing edges arms nothing.
    """
    outgoing = node.out_edges
    if not outgoing:
        return

    next_node_id = outgoing[0].target

    async def _deferred_static_transition():
        # Errors are logged rather than raised into the caller that runs
        # the deferred transition.
        try:
            await self.set_node(next_node_id)
        except Exception as exc:
            logger.error(
                f"Error executing deferred static node transition to {next_node_id}: {exc}"
            )

    self._pending_control_transition_after_context_push = (
        _deferred_static_transition
    )
|
||||
|
||||
async def send_end_task_frame(
    self,
    reason: str,
    additional_metadata: Optional[dict] = None,
    abort_immediately: bool = False,
):
    """Queue the frame that ends the pipeline task, carrying call metadata.

    The frame metadata holds the original ``reason`` plus a
    ``call_transfer_context`` built from the call context variables and the
    gathered context. The disposition (the extracted ``call_disposition`` when
    present, otherwise ``reason``) is passed through
    ``apply_disposition_mapping`` and written back into the gathered context.

    Args:
        reason: End-task reason; compared against ``EndTaskReason`` values.
        additional_metadata: Optional extra entries merged into the frame
            metadata; they win over the default keys on collision.
        abort_immediately: When true a ``CancelFrame`` is queued instead of an
            ``EndFrame`` and the immediate Stasis transfer is skipped.
    """
    frame_to_push = CancelFrame() if abort_immediately else EndFrame()

    def _join_present(separator: str, keys: tuple) -> str:
        # Join only the call-context values that are actually present.
        # Missing, ``None`` and empty entries are dropped so the result has no
        # stray separators ("a, , b") and ``str.join`` never sees a non-string.
        values = (self._call_context_vars.get(key) for key in keys)
        return separator.join(str(value) for value in values if value)

    # Customer disposition code using their mapping
    mapped_disposition = ""

    # Apply disposition mapping - first try call_disposition if it was
    # extracted from the call conversation, then fall back to reason
    call_disposition = self._gathered_context.get("call_disposition", "")
    organization_id = await get_organization_id_from_workflow_run(
        self._workflow_run_id
    )

    if call_disposition:
        # If call_disposition exists, map it
        mapped_disposition = await apply_disposition_mapping(
            call_disposition, organization_id
        )
        # Store the original and mapped values
        self._gathered_context["extracted_call_disposition"] = call_disposition
        self._gathered_context["call_disposition"] = mapped_disposition
    else:
        # Otherwise, map the disconnect reason
        mapped_disposition = await apply_disposition_mapping(
            reason, organization_id
        )
        # Store the mapped disconnect reason
        self._gathered_context["call_disposition"] = mapped_disposition

    # TODO: Generalise this, currently tailored to Kapil's use case
    self._gathered_context["address"] = _join_present(
        ", ",
        (
            "address1",
            "address2",
            "address3",
            "city",
            "state",
            "province",
            "postal_code",
        ),
    )
    self._gathered_context["full_name"] = _join_present(
        " ", ("first_name", "middle_initial", "last_name")
    )
    self._gathered_context["agent_name"] = "Alex"
    self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
        "phone", ""
    )
    # NOTE(review): timezone is populated from "province" - looks deliberate
    # for the current customer mapping, but confirm it is not a copy/paste slip.
    self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
    self._gathered_context["vendor_id"] = self._call_context_vars.get(
        "vendor_lead_code", ""
    )

    decision_maker = self._gathered_context.get("primary_cardholder", False)
    employment_status = self._gathered_context.get("employment_status", "N/A")
    call_transfer_context = {
        "first_name": self._call_context_vars.get("first_name", ""),
        "full_name": self._gathered_context.get("full_name", ""),
        "phone": self._call_context_vars.get("phone", ""),
        "lead_id": self._call_context_vars.get("lead_id"),
        "disposition": mapped_disposition,
        "agent_name": self._gathered_context.get("agent_name", "Alex"),
        "decision_maker": str(decision_maker),
        # str() guards against a non-string extracted value (e.g. a bool),
        # which has no .title().
        "employment": str(employment_status).title() if employment_status else "N/A",
        "debts": self._gathered_context.get("total_debt", "N/A"),
        "number_of_credit_cards": self._gathered_context.get(
            "number_of_credit_cards", "N/A"
        ),
        "time": self._gathered_context.get("time"),
    }

    logger.debug(
        f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
    )

    # Initiate immediate transfer for Stasis connections when user is qualified
    if (
        reason == EndTaskReason.USER_QUALIFIED.value
        and self._stasis_connection is not None
        and not abort_immediately
    ):
        try:
            logger.info(
                f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
            )
            await self._stasis_connection.transfer(call_transfer_context)
            logger.info("Immediate transfer initiated successfully")
        except Exception as e:
            # Best effort: continue with the normal end-of-task flow even if
            # the immediate transfer fails.
            logger.error(f"Failed to initiate immediate transfer: {e}")

    if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
        await self.task.queue_frame(
            TTSSpeakFrame(
                "Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
            )
        )

    metadata = {
        # Keep original reason in metadata, which would be used to decide
        # whether to disconnect or to transfer the call in the transport
        "reason": reason,
        "call_transfer_context": call_transfer_context,
    }

    # Add any additional metadata
    if additional_metadata:
        metadata.update(additional_metadata)

    frame_to_push.metadata = metadata

    # Store the mapped disposition for later retrieval in the event handler
    self._call_disposition = mapped_disposition

    logger.debug(
        f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
    )
    await self.task.queue_frame(frame_to_push)
|
||||
|
||||
async def _compose_system_message_functions_for_node(
    self, node: "Node"
) -> tuple[dict, list[dict]]:
    """Build the system message and function schemas for the given node.

    This performs the same formatting logic used when entering a node but
    does **not** register the functions with the LLM; callers are
    responsible for that.

    Returns:
        A ``(system_message, functions)`` pair. ``system_message`` is a single
        ``{"role": "system", "content": ...}`` dict whose content is the
        formatted global prompt (when the workflow has a global node and the
        node opts in) and the formatted node prompt, joined by a blank line.
        ``functions`` holds the built-in schemas followed by one schema per
        outgoing edge.
    """
    global_prompt = ""
    if self.workflow.global_node_id and node.add_global_prompt:
        global_node = self.workflow.nodes[self.workflow.global_node_id]
        global_prompt = self._format_prompt(global_node.prompt)

    # Built-in function schemas (calculator and timezone tools) come first.
    functions: list[dict] = list(self.builtin_function_schemas)

    # Transition functions (schema only; registration handled elsewhere)
    functions.extend(
        self._get_function_schema(
            outgoing_edge.get_function_name(), outgoing_edge.condition
        )
        for outgoing_edge in node.out_edges
    )

    formatted_node_prompt = self._format_prompt(node.prompt)

    system_message = {
        "role": "system",
        # Empty prompts are skipped so no dangling blank lines are emitted.
        "content": "\n\n".join(
            p for p in (global_prompt, formatted_node_prompt) if p
        ),
    }

    return system_message, functions
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pending transition handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def flush_pending_transitions(self, *, source: str = "context_push"):
    """Run whichever deferred transitions are waiting, clearing each first.

    Args:
        source: Name of the trigger for this flush. Only ``"context_push"``
            (the assistant context aggregator completed a push) is accepted.

    Raises:
        ValueError: If ``source`` is anything other than ``"context_push"``.
    """
    if source != "context_push":
        raise ValueError("Invalid flush source – expected 'context_push'")

    # Count the occupied slots up front; the number is only used for logging.
    waiting = sum(
        slot is not None
        for slot in (
            self._pending_generated_transition_after_context_push,
            self._pending_control_transition_after_context_push,
        )
    )
    if not waiting:
        # Nothing to do
        return

    logger.debug(
        f"Flushing {waiting} pending transition(s) after {source.replace('_', ' ')}"
    )

    # Generated transition. The slot is emptied before the callback runs.
    generated_cb = self._pending_generated_transition_after_context_push
    if generated_cb is not None:
        self._pending_generated_transition_after_context_push = None
        try:
            await generated_cb()
        except Exception as err:  # pragma: no cover
            logger.error(f"Error executing deferred transition: {err}")

    # Control transition (context push). Read only now, i.e. after the
    # generated callback has had the chance to run.
    control_cb = self._pending_control_transition_after_context_push
    if control_cb is not None:
        logger.debug("Executing control transition after context push")
        self._pending_control_transition_after_context_push = None
        try:
            await control_cb()
        except Exception as err:  # pragma: no cover
            logger.error(f"Error executing deferred static node transition: {err}")
|
||||
|
||||
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
    """Build the callback ``STTMuteFilter`` uses to decide whether to mute STT.

    Construction is delegated to ``engine_callbacks`` with this engine as the
    argument.
    """
    build = engine_callbacks.create_should_mute_callback
    return build(self)
|
||||
|
||||
def create_user_idle_callback(self):
    """Build the callback invoked once the user has been idle for a while.

    It is used to either play the static text or end the call. Construction
    is delegated to ``engine_callbacks`` with this engine as the argument.
    """
    build = engine_callbacks.create_user_idle_callback
    return build(self)
|
||||
|
||||
def create_max_duration_callback(self):
    """Build the callback invoked when the call exceeds its max duration.

    It is used to send the EndTaskFrame. Construction is delegated to
    ``engine_callbacks`` with this engine as the argument.
    """
    build = engine_callbacks.create_max_duration_callback
    return build(self)
|
||||
|
||||
def create_llm_generated_text_callback(self):
    """Build the callback invoked when the LLM generates some text.

    It is used to defer the result_callback of the node transition functions
    when set_node is called alongside generated text, so that the context is
    sent in the next generation from the new node. Construction is delegated
    to ``engine_callbacks`` with this engine as the argument.
    """
    build = engine_callbacks.create_llm_generated_text_callback
    return build(self)
|
||||
|
||||
def create_generation_started_callback(self):
    """Build the callback invoked when a new generation starts.

    It is used to reset the flags that control the flow of the engine.
    Construction is delegated to ``engine_callbacks`` with this engine as the
    argument.
    """
    build = engine_callbacks.create_generation_started_callback
    return build(self)
|
||||
|
||||
def create_user_stopped_speaking_callback(self):
    """Build the callback invoked when the user stops speaking.

    It is used to handle transitions when wait_for_user_response is enabled.
    Construction is delegated to ``engine_callbacks`` with this engine as the
    argument.
    """
    build = engine_callbacks.create_user_stopped_speaking_callback
    return build(self)
|
||||
|
||||
def create_user_started_speaking_callback(self):
    """Build the callback invoked when the user starts speaking.

    It is used to handle the wait_for_user_greeting functionality.
    Construction is delegated to ``engine_callbacks`` with this engine as the
    argument.
    """
    build = engine_callbacks.create_user_started_speaking_callback
    return build(self)
|
||||
|
||||
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
    """Build the callback that corrects corrupted aggregation via reference text.

    Construction is delegated to ``engine_callbacks`` with this engine as the
    argument.
    """
    build = engine_callbacks.create_aggregation_correction_callback
    return build(self)
|
||||
|
||||
def get_call_disposition(self) -> Optional[str]:
    """Return the call disposition recorded by the engine.

    This is whatever is currently held in ``_call_disposition``;
    ``send_end_task_frame`` stores the mapped disposition there.
    """
    disposition = self._call_disposition
    return disposition
|
||||
|
||||
def get_gathered_context(self) -> dict:
    """Return a shallow copy of the gathered context (extracted variables included)."""
    snapshot = self._gathered_context.copy()
    return snapshot
|
||||
|
||||
def set_context(self, context: OpenAILLMContext) -> None:
    """Attach the OpenAI LLM context to the engine.

    Exists so the context can be supplied after the engine has been
    constructed, for setups where the context is created later than the
    engine.
    """
    # Plain assignment; no validation is performed.
    self.context = context
|
||||
|
||||
def set_task(self, task: PipelineTask) -> None:
    """Attach the pipeline task to the engine.

    Exists so the task can be supplied after the engine has been constructed,
    for setups where the task is created later than the engine.
    """
    # Plain assignment; no validation is performed.
    self.task = task
|
||||
|
||||
def set_audio_buffer(self, audio_buffer: "AudioBuffer") -> None:
    """Attach the audio buffer to the engine.

    Exists so the buffer can be supplied after the engine has been
    constructed, for setups where the buffer is created later than the engine.
    """
    # Plain assignment; no validation is performed.
    self._audio_buffer = audio_buffer
|
||||
|
||||
def set_stasis_connection(
    self, connection: Optional["StasisRTPConnection"]
) -> None:
    """Record the Stasis RTP connection used for immediate transfers.

    With a connection in place the engine can initiate a transfer as soon as
    an XFER disposition is detected, without waiting for pipeline shutdown.

    Args:
        connection: The StasisRTPConnection instance, or None for non-Stasis
            transports.
    """
    self._stasis_connection = connection
    if not connection:
        return

    # Only log when there is an actual connection to report on.
    logger.debug(
        f"Stasis connection set for immediate transfers: {connection.channel_id}"
    )
|
||||
|
||||
async def handle_llm_text_frame(self, text: str):
    """Append an LLM text chunk to the running reference text."""
    accumulated = self._current_llm_reference_text
    self._current_llm_reference_text = accumulated + text
|
||||
|
||||
async def cleanup(self):
    """Release engine resources when the call disconnects.

    Cancels the user-response timeout task if one exists and has not
    finished, then stops voicemail detection when a detector is present and
    exposes ``stop_detection``.
    """
    # Cancel any pending timeout task
    timeout_task = self._user_response_timeout_task
    if timeout_task and not timeout_task.done():
        timeout_task.cancel()

    # Stop voicemail detection if active
    detector = self._voicemail_detector
    if detector and hasattr(detector, "stop_detection"):
        await detector.stop_detection()
|
||||
305
api/services/workflow/pipecat_engine_callbacks.py
Normal file
305
api/services/workflow/pipecat_engine_callbacks.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
from __future__ import annotations
|
||||
|
||||
"""Callback factory helpers for :class:`~api.services.workflow.pipecat_engine.PipecatEngine`.
|
||||
|
||||
Each helper takes a :class:`PipecatEngine` instance and returns an async
|
||||
callback function suitable for passing to the various pipeline processors.
|
||||
Separating these helpers into their own module keeps
|
||||
``pipecat_engine.py`` focused on high-level engine orchestration logic while
|
||||
encapsulating the callback implementations here for easier maintenance and
|
||||
unit-testing.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# STT mute handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_should_mute_callback(
    engine: "PipecatEngine",
) -> Callable[[STTMuteFilter], Awaitable[bool]]:
    """Build the callback that tells ``STTMuteFilter`` whether to mute STT.

    The callback answers ``True`` exactly when the engine's current node does
    **not** allow interruptions, and ``False`` while no node is active.
    """

    async def callback(_: STTMuteFilter) -> bool:  # noqa: D401
        node = engine._current_node
        if node is None:
            # No active node yet: default to not muting.
            return False

        logger.debug(f"STT mute callback: allow_interrupt={node.allow_interrupt}")
        return not node.allow_interrupt

    return callback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-idle handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_idle_callback(engine: "PipecatEngine"):
    """Build the callback that reacts to user-idle timeouts.

    The callback returns ``True`` after queueing the check-in prompt on the
    first attempt, and ``False`` whenever it ends the call (idle on a start
    node, or any attempt other than the first).
    """

    async def handle_user_idle(
        user_idle: "UserIdleProcessor", retry_count: int
    ) -> bool:
        logger.debug(f"Handling user_idle, attempt: {retry_count}")

        idle_reason = EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value

        # Idle on a StartNode: disconnect directly, no prompt.
        node = engine._current_node
        if node and node.is_start:
            logger.debug("User idle on StartNode - disconnecting immediately")
            await engine.send_end_task_frame(idle_reason)
            return False

        if retry_count != 1:
            # Not the first attempt: say goodbye and terminate due to inactivity.
            await user_idle.push_frame(
                TTSSpeakFrame("It seems like you're busy right now. Have a nice day!")
            )
            await engine.send_end_task_frame(idle_reason)
            return False

        # First attempt: wrap the prompt in start/end response frames to
        # simulate an LLM generation, so the LLM context gets updated with it.
        check_in = [
            LLMFullResponseStartFrame(),
            TTSSpeakFrame("Just checking in to see if you're still there."),
            LLMFullResponseEndFrame(),
        ]
        await engine.task.queue_frames(check_in)
        return True

    return handle_user_idle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Max-duration handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_max_duration_callback(engine: "PipecatEngine"):
    """Build the callback that ends the task once the max call duration is hit."""

    async def handle_max_duration():
        logger.debug("Max call duration exceeded. Terminating call")
        exceeded = EndTaskReason.CALL_DURATION_EXCEEDED.value
        await engine.send_end_task_frame(exceeded)

    return handle_max_duration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-generated-text handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_llm_generated_text_callback(engine: "PipecatEngine"):
    """Build the callback fired when the LLM emits text (not only tool calls).

    The callback raises the engine's ``_defer_context_push`` flag.
    """

    async def handle_llm_generated_text():  # noqa: D401
        note = "Generation has text content in current response - deferring context push from set_node"
        logger.debug(note)
        engine._defer_context_push = True

    return handle_llm_generated_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generation-started handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_generation_started_callback(engine: "PipecatEngine"):
    """Build the callback that resets per-generation state on the engine.

    Each invocation clears the defer flag, the pending function-call counter,
    the pending generated transition and the accumulated reference text.
    """

    async def handle_generation_started():  # noqa: D401
        logger.debug("LLM generation started - resetting defer flags and tool counters")
        # Reference text belongs to the previous generation; start clean.
        engine._current_llm_reference_text = ""
        engine._pending_generated_transition_after_context_push = None
        engine._pending_function_calls = 0
        engine._defer_context_push = False

    return handle_generation_started
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-stopped-speaking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_stopped_speaking_callback(engine: "PipecatEngine"):
    """Build the callback that runs when the user stops speaking.

    The callback only acts on a start node that has
    ``wait_for_user_response`` set and at least one outgoing edge:
        - a still-active timeout task is cancelled and cleared
        - ``_queue_context_frame`` is set to ``False``
        - the engine moves to the target of the first outgoing edge
    """

    async def handle_user_stopped_speaking():
        node = engine._current_node
        waiting_on_start_node = (
            node
            and node.is_start
            and node.wait_for_user_response
            and node.out_edges
        )
        if not waiting_on_start_node:
            return

        # Cancel the timeout task if it is still active
        timeout_task = engine._user_response_timeout_task
        if timeout_task and not timeout_task.done():
            logger.debug("Cancelling user response timeout - user responded")
            timeout_task.cancel()
            engine._user_response_timeout_task = None

        next_node_id = node.out_edges[0].target
        logger.debug(
            f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}"
        )

        # Do not queue a context frame: the user context aggregator pushes
        # it. Here we only install the next node's functions and prompts.
        engine._queue_context_frame = False

        # Transition to next node
        await engine.set_node(next_node_id)

    return handle_user_stopped_speaking
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-started-speaking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_started_speaking_callback(engine: "PipecatEngine"):
    """Build the callback that runs when the user starts speaking.

    On a start node with ``wait_for_user_response`` set, a timeout task that
    has not finished is cancelled. The task reference is intentionally left
    in place (not reset to ``None``).
    """

    async def handle_user_started_speaking():
        node = engine._current_node
        if not (node and node.is_start and node.wait_for_user_response):
            return

        timeout_task = engine._user_response_timeout_task
        if not timeout_task or timeout_task.done():
            return

        logger.debug(
            "User started speaking during wait_for_user_response - cancelling timeout timer"
        )
        # Cancel only; the user_stopped_speaking handler takes care of the
        # transition, so the reference is not cleared here.
        timeout_task.cancel()

    return handle_user_started_speaking
|
||||
|
||||
|
||||
def create_aggregation_correction_callback(engine: "PipecatEngine"):
    """Build a callback that repairs corrupted aggregation text.

    The returned function compares the corrupted text against the engine's
    ``_current_llm_reference_text`` and returns either a repaired string or
    the corrupted input unchanged.
    """

    def correct_corrupted_aggregation(ref: str, corrupted: str) -> str:
        """Align ``corrupted`` with ``ref`` and return the repaired text.

        Pure function: it does not touch the engine instance. The input is
        returned unchanged whenever the repair cannot be trusted.
        """
        # 1) Bail-outs. Nothing to fix when the text already occurs verbatim
        #    in ref; ref cannot be the source when it has fewer alphanumerics
        #    than corrupted; and texts under 10 alphanumerics are returned
        #    as-is since most likely Elevenlabs returned the right alignment.
        corrupted_alnum = "".join(filter(str.isalnum, corrupted))
        ref_alnum = "".join(filter(str.isalnum, ref))
        if corrupted in ref:
            return corrupted
        if len(ref_alnum) < len(corrupted_alnum) or len(corrupted_alnum) < 10:
            return corrupted

        # 2) Choose where in ref to start aligning: the *last* occurrence of
        #    the first 10 characters of corrupted, or the very beginning when
        #    that prefix does not occur at all.
        head = corrupted[:10]
        anchor = 0
        for hit in re.finditer(re.escape(head), ref):
            anchor = hit.start()

        # 3) Two-pointer scan over ref (from the anchor) and corrupted.
        ref_pos, cor_pos = anchor, 0
        rebuilt = []
        while ref_pos < len(ref) and cor_pos < len(corrupted):
            ref_ch = ref[ref_pos]
            cor_ch = corrupted[cor_pos]

            if cor_ch != ref_ch and cor_ch == " ":
                # extra space in corrupted -> skip it
                cor_pos += 1
                continue

            # Every remaining case emits the reference character.
            rebuilt.append(ref_ch)
            ref_pos += 1

            # A structural char (space/punctuation) missing from corrupted
            # consumes nothing from corrupted; an exact match or a letter
            # mismatch (best-effort copy from ref) consumes one character.
            structural = ref_ch == " " or ref_ch in ".,;:!?"
            if ref_ch == cor_ch or not structural:
                cor_pos += 1

        # 4) Final check: the rebuilt text must carry exactly the same
        #    alphanumeric characters as the corrupted sentence.
        rebuilt_alnum = "".join(filter(str.isalnum, rebuilt))
        if rebuilt_alnum != corrupted_alnum:
            return corrupted

        # 5) Join and return exactly what we built
        return "".join(rebuilt)

    def correct_aggregation(corrupted: str) -> str:
        reference = engine._current_llm_reference_text
        if reference:
            # Apply the correction algorithm
            return correct_corrupted_aggregation(reference, corrupted)

        logger.warning("No reference text available for aggregation correction")
        return corrupted

    return correct_aggregation
|
||||
90
api/services/workflow/pipecat_engine_utils.py
Normal file
90
api/services/workflow/pipecat_engine_utils.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from google.genai.types import (
|
||||
Content,
|
||||
Part,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.services.google.llm import GoogleLLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.utils.template_renderer import render_template
|
||||
|
||||
__all__ = [
|
||||
"get_function_schema",
|
||||
"update_llm_context",
|
||||
"render_template",
|
||||
]
|
||||
|
||||
|
||||
def get_function_schema(
    function_name: str,
    description: str,
    *,
    properties: Dict[str, Any] | None = None,
    required: List[str] | None = None,
) -> FunctionSchema:
    """Build a FunctionSchema that can later be converted into the
    provider-specific format (OpenAI, Gemini, etc.).

    ``properties`` and ``required`` are optional keyword-only arguments, so
    callers that pass only ``function_name`` and ``description`` keep working
    and get a parameter-less function.
    """
    # Falsy values (None or empty) collapse to fresh empty containers.
    schema_properties = properties if properties else {}
    required_fields = required if required else []

    return FunctionSchema(
        name=function_name,
        description=description,
        properties=schema_properties,
        required=required_fields,
    )
|
||||
|
||||
|
||||
def update_llm_context(
    context: OpenAILLMContext,
    system_message: Dict[str, Any],
    functions: List[FunctionSchema],
) -> None:
    """Update *context* with an up-to-date system message and tool list.

    For OpenAI-style contexts, previous system messages are removed and the
    new *system_message* is inserted at the top of the conversation history.
    For ``GoogleLLMContext`` the system message content is assigned to
    ``context.system_message`` and a placeholder user message is appended
    when the history is empty or does not end with a user message. In both
    cases the available *functions* (a.k.a. tools) are registered, but only
    when the list is non-empty.

    Args:
        context: The LLM context to mutate in place.
        system_message: Dict with at least ``"role"`` and ``"content"`` keys.
        functions: Function schemas to expose to the LLM; may be empty.
    """

    # Wrap the provided function schemas in a ToolsSchema so that the adapter
    # associated with the current LLM service can convert them to the correct
    # provider-specific representation when required.
    tools_schema = ToolsSchema(standard_tools=functions)

    if isinstance(context, GoogleLLMContext):
        context.system_message = system_message["content"]

        if functions:
            # Only call set_tools when we have functions, else Gemini will
            # throw an exception
            context.set_tools(tools_schema)

        history = context.messages
        # Google expects the conversation to end with a user message. The
        # emptiness check avoids an IndexError on a fresh context with no
        # messages yet.
        if not history or history[-1].role != "user":
            context.add_message(Content(role="user", parts=[Part(text="...")]))
        return

    # In case of OpenAILLMContext, replace the system message with the
    # incoming system message: drop old system messages but keep
    # user/assistant/function content.
    messages: List[Dict[str, Any]] = [system_message]
    messages.extend(
        interaction
        for interaction in context.messages
        if interaction["role"] != "system"
    )

    context.set_messages(messages)

    if functions:
        context.set_tools(tools_schema)
|
||||
192
api/services/workflow/pipecat_engine_variable_extractor.py
Normal file
192
api/services/workflow/pipecat_engine_variable_extractor.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from opentelemetry import trace
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
|
||||
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from api.services.workflow.dto import ExtractionVariableDTO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
class VariableExtractionManager:
    """Run the LLM call behind the "extract_variables" tool.

    The manager keeps a reference to the owning engine and its LLM context,
    renders the conversation held in that context as a plain-text transcript,
    asks an OpenAI chat model to return the requested variables as a JSON
    object and, when tracing is enabled, records the call in an OpenTelemetry
    span.

    NOTE(review): the original docstring also mentioned registering the tool
    with the LLM service and background-task bookkeeping. Neither is
    implemented in this class body; presumably the caller does that - confirm.
    """

    def __init__(self, engine: "PipecatEngine") -> None:  # noqa: F821
        # Keep the engine so its context can be reused for every extraction.
        self._engine = engine
        self._context = engine.context
        # Model used for the extraction chat completion.
        self._model = "gpt-4o"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
        """Return ``(role, content)`` for a single conversation message.

        Two message shapes are understood:

        * OpenAI-style ``dict`` messages with ``role`` and ``content`` keys,
          where ``content`` is a string or a list of segments; only segments
          whose ``type`` is ``"text"`` are kept.
        * Objects exposing ``role`` and ``parts`` attributes (Google Gemini
          ``Content``); only parts with a truthy ``text`` attribute are kept
          and the ``"model"`` role is normalised to ``"assistant"``.

        Anything else yields ``(None, None)``. ``content`` is ``None`` when
        the message carries no text.
        """
        if isinstance(msg, dict):
            role = msg.get("role")
            content_field = msg.get("content")

            if isinstance(content_field, str):
                return role, content_field

            if isinstance(content_field, list):
                # Collapse all text segments into a single string.
                texts = [
                    segment.get("text", "")
                    for segment in content_field
                    if isinstance(segment, dict) and segment.get("type") == "text"
                ]
                return role, (" ".join(texts) if texts else None)

            return role, None

        role_attr = getattr(msg, "role", None)
        parts_attr = getattr(msg, "parts", None)
        if role_attr is None or parts_attr is None:
            return None, None  # Unrecognised message format

        role = "assistant" if role_attr == "model" else role_attr

        # Collect textual parts only (images, function calls, etc. have no text).
        texts = []
        for part in parts_attr:
            text_val = getattr(part, "text", None)
            if text_val:
                texts.append(text_val)

        return role, (" ".join(texts) if texts else None)

    @staticmethod
    def _parse_json_object(raw: str | None) -> dict | None:
        """Parse *raw* as a JSON object.

        Returns the decoded ``dict``, or ``None`` when *raw* is ``None``, is
        not valid JSON, or decodes to something other than an object.
        """
        if raw is None:
            return None
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return None
        return parsed if isinstance(parsed, dict) else None

    async def _perform_extraction(
        self,
        extraction_variables: List[ExtractionVariableDTO],
        parent_ctx: Any,
        extraction_prompt: str = "",
    ) -> dict:
        """Run the extraction chat completion and post-process the result.

        Args:
            extraction_variables: Variables to extract; each item's ``name``,
                ``type`` and ``prompt`` are rendered into the user prompt.
            parent_ctx: Context passed to the tracing span as its parent.
            extraction_prompt: Optional extra instructions appended to the
                default system prompt.

        Returns:
            The decoded JSON object produced by the model. When the model
            output is missing or is not a JSON object, ``{"raw": <output>}``
            is returned instead.
        """
        # Describe the variables the model has to extract.
        vars_description = "\n".join(
            f"- {v.name} ({v.type}): {v.prompt}" for v in extraction_variables
        )

        # Render the user/assistant turns of the conversation as plain text.
        conversation_lines: list[str] = []
        for msg in self._context.messages:
            role, content = self._get_role_and_content(msg)
            if role in ("assistant", "user") and content:
                conversation_lines.append(f"{role}: {content}")
        conversation_history = "\n".join(conversation_lines)

        system_prompt = (
            "You are an assistant tasked with extracting structured data from the conversation. "
            "Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown."  # noqa: E501
        )
        # A caller-supplied prompt is appended to (not substituted for) the default.
        if extraction_prompt:
            system_prompt = system_prompt + "\n\n" + extraction_prompt

        user_prompt = (
            "\n\nVariables to extract:\n"
            f"{vars_description}"
            "\n\nConversation history:\n"
            f"{conversation_history}"
        )

        extraction_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # Independent OpenAI client: a direct API call with no pipeline involvement.
        client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        response = await client.chat.completions.create(
            model=self._model,
            messages=extraction_messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )

        # May be None (e.g. when the model returns no textual content).
        llm_response = response.choices[0].message.content

        if is_tracing_enabled():
            tracer = trace.get_tracer("pipecat")
            with tracer.start_as_current_span(
                "variable_extraction", context=parent_ctx
            ) as span:
                add_llm_span_attributes(
                    span,
                    service_name="OpenAILLMService",
                    model=self._model,
                    operation_name="variable_extraction",
                    messages=json.dumps(extraction_messages),
                    output=llm_response,
                    stream=False,
                    parameters={"temperature": 0.0, "response_format": "json_object"},
                )

        # Fall back to the raw output when it is not a usable JSON object.
        extracted = self._parse_json_object(llm_response)
        if extracted is None:
            logger.warning(
                "Extractor returned invalid JSON; storing raw content instead."
            )
            extracted = {"raw": llm_response}

        logger.debug(f"Extracted variables: {extracted}")
        return extracted
|
||||
448
api/services/workflow/pipecat_engine_voicemail_detector.py
Normal file
448
api/services/workflow/pipecat_engine_voicemail_detector.py
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import wave
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langfuse import get_client
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from opentelemetry import context as otel_context
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pipecat.utils.tracing.context_registry import get_current_turn_context
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
# Fallback prompt for voicemail analysis. The literal "{transcript}" marker is
# filled in with str.replace() rather than str.format(), because the JSON
# example below contains bare braces.
DEFAULT_VOICEMAIL_PROMPT = """
You are analyzing the beginning of a phone call to determine if it's a voicemail greeting.

Common voicemail indicators:
- "You've reached the voicemail of..."
- "Please leave a message after the beep"
- "I'm not available right now"
- "Press 1 to leave a message"
- Robotic or pre-recorded voice quality mentioned
- Background music or hold music references

Transcript: {transcript}

Respond with a JSON object:
{
"is_voicemail": true/false,
"confidence": 0.0-1.0,
"reasoning": "Brief explanation"
}
"""
|
||||
|
||||
|
||||
class VoicemailDetector:
    """Voicemail detection that runs alongside the main pipeline.

    Incoming PCM audio is buffered until ``detection_duration`` seconds have
    been collected (or a timeout expires). The audio is then transcribed, the
    transcript is classified by a chat model, and, when a voicemail greeting
    is detected, the workflow run is tagged and the engine is asked to end
    the task.
    """

    # Samples are treated as 16-bit PCM throughout this class.
    _BYTES_PER_SAMPLE = 2
    _ANALYSIS_MODEL = "gpt-4o"
    _TRANSCRIPTION_MODEL = "whisper-1"

    def __init__(
        self, detection_duration: float = 15.0, workflow_run_id: Optional[int] = None
    ):
        """Create a detector.

        Args:
            detection_duration: Seconds of audio to collect before analysing.
            workflow_run_id: Workflow run to tag and to name the stored
                recording after; when falsy, nothing is stored or tagged.
        """
        self.detection_duration = detection_duration
        self.audio_buffer = bytearray()
        self.is_detecting = False
        self.workflow_run_id = workflow_run_id
        self._langfuse_client = get_client()

        # Audio format is learned from the first audio packet.
        self._sample_rate = None
        self._num_channels = 1

        # Task management
        self._detection_task: Optional[asyncio.Task] = None
        self._is_cancelled = False
        self._engine: Optional[PipecatEngine] = None

        # Set once enough audio was collected (or detection was stopped).
        self._audio_collected_event = asyncio.Event()

    # ------------------------------------------------------------------
    # Utility helpers
    # ------------------------------------------------------------------

    def _current_duration_seconds(self) -> float:
        """Return the duration (in seconds) of the audio currently buffered.

        Returns ``0.0`` until the sample rate is known.
        """
        if not self._sample_rate:
            return 0.0
        bytes_per_second = (
            self._sample_rate * self._BYTES_PER_SAMPLE * self._num_channels
        )
        return len(self.audio_buffer) / bytes_per_second

    @staticmethod
    def _parse_json_object(raw: Optional[str]) -> Optional[dict]:
        """Parse *raw* as a JSON object.

        Returns the decoded ``dict``, or ``None`` when *raw* is ``None``, is
        not valid JSON, or decodes to something other than an object.
        """
        if raw is None:
            return None
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _normalize_result(result: object) -> tuple[bool, float, str]:
        """Coerce a model verdict into ``(is_voicemail, confidence, reasoning)``.

        Missing or ill-typed fields fall back to ``False``, ``0.0`` and
        ``"No reasoning provided"`` respectively, so callers can rely on the
        types regardless of what the model returned.
        """
        data = result if isinstance(result, dict) else {}

        flag = data.get("is_voicemail", False)
        if isinstance(flag, str):
            # bool("false") would be True; interpret the text instead.
            flag = flag.strip().lower() == "true"

        try:
            confidence = float(data.get("confidence", 0.0))
        except (TypeError, ValueError):
            confidence = 0.0

        reasoning = data.get("reasoning") or "No reasoning provided"
        return bool(flag), confidence, str(reasoning)

    async def handle_audio_data(
        self, processor, pcm: bytes, sample_rate: int, num_channels: int
    ):
        """Buffer incoming audio while detection is active.

        The sample rate and channel count of the first packet are remembered
        and used for duration bookkeeping and for the WAV header; the audio
        itself is stored as received, without resampling.
        """
        if not self.is_detecting or self._is_cancelled:
            return

        if self._sample_rate is None:
            self._sample_rate = sample_rate
            self._num_channels = max(int(num_channels or 1), 1)
            logger.debug(f"Voicemail detector using sample rate: {sample_rate}")

        self.audio_buffer.extend(pcm)

        # Wake the detection task once enough audio has been collected.
        if self._current_duration_seconds() >= self.detection_duration:
            self._audio_collected_event.set()

    async def start_detection(self, engine: PipecatEngine):
        """Start voicemail detection in a background task."""
        logger.info("Starting voicemail detection")
        self.is_detecting = True
        self._is_cancelled = False
        self._engine = engine
        self._audio_collected_event.clear()

        self._detection_task = asyncio.create_task(self._run_detection_with_timeout())

    async def stop_detection(self):
        """Stop detection immediately (called on disconnect)."""
        logger.info("Stopping voicemail detection due to disconnect")
        self._is_cancelled = True
        self.is_detecting = False

        # Unblock anything waiting for audio collection.
        self._audio_collected_event.set()

        task = self._detection_task
        if task and not task.done():
            task.cancel()

        self.audio_buffer.clear()

        # Wait for the task to finish its cancellation.
        if task:
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def _run_detection_with_timeout(self):
        """Collect audio, then run the analysis unless detection was cancelled."""
        try:
            await self._wait_for_audio_collection()

            if self._is_cancelled:
                logger.info("Detection cancelled during audio collection")
                return

            await self._process_detection()

        except asyncio.CancelledError:
            logger.info("Voicemail detection task cancelled")
        except Exception as e:
            logger.error(f"Error in voicemail detection: {e}")
        finally:
            self.is_detecting = False

    async def _wait_for_audio_collection(self):
        """Wait until enough audio is buffered, allowing 2 s beyond the target."""
        try:
            await asyncio.wait_for(
                self._audio_collected_event.wait(),
                timeout=self.detection_duration + 2.0,
            )

            if not self._is_cancelled:
                current_duration = self._current_duration_seconds()
                logger.info(
                    f"Collected {current_duration:.1f}s of audio for voicemail detection (sample rate: {self._sample_rate}Hz)"
                )
        except asyncio.TimeoutError:
            if not self._is_cancelled:
                current_duration = self._current_duration_seconds()
                logger.warning("Audio collection timeout exceeded")
                logger.info(
                    f"Proceeding with {current_duration:.1f}s of audio (sample rate: {self._sample_rate}Hz)"
                )

    async def _process_detection(self):
        """Transcribe and classify the buffered audio, then act on the verdict.

        Errors are logged rather than raised so that a failed detection never
        interrupts the call.
        """
        if not self.audio_buffer or not self._engine:
            logger.warning("No audio buffer or engine available for detection")
            return

        try:
            # Convert PCM to WAV once for both transcription and storage.
            wav_data = self._create_wav_from_pcm(bytes(self.audio_buffer))

            logger.info("Transcribing audio for voicemail detection")
            transcript = await self._transcribe_audio(wav_data)

            if not transcript:
                logger.warning("No transcript obtained from audio")

                # Still store the raw recording so the data pipeline has it.
                if self.workflow_run_id:
                    await self._save_voicemail_audio(wav_data, 0.0, False)

                return

            logger.info(
                f"Voicemail detection transcript obtained: {transcript[:100]}..."
            )

            result = await self._analyze_transcript(transcript)
            is_voicemail, confidence, reasoning = self._normalize_result(result)

            # Store the recording once; the key encodes the verdict.
            s3_path = None
            if self.workflow_run_id:
                s3_path = await self._save_voicemail_audio(
                    wav_data, confidence, is_voicemail
                )

            if not is_voicemail:
                logger.info("No voicemail detected, continuing normal conversation")
                return

            logger.info(
                f"Voicemail detected with confidence {confidence}: {reasoning}"
            )

            # Tag the workflow run.
            if self.workflow_run_id:
                workflow_run = await db_client.get_workflow_run_by_id(
                    self.workflow_run_id
                )
                if workflow_run:
                    # gathered_context may be unset; copy the tag list so the
                    # loaded object is not mutated in place.
                    existing_context = workflow_run.gathered_context or {}
                    call_tags = list(existing_context.get("call_tags") or [])
                    call_tags.extend(["voicemail_detected", "not_connected"])

                    await db_client.update_workflow_run(
                        run_id=workflow_run.id,
                        gathered_context={
                            "call_tags": call_tags,
                            "voicemail_transcript": transcript,
                            "voicemail_confidence": confidence,
                        },
                    )

            # End the task, attaching the detection details as metadata.
            await self._engine.send_end_task_frame(
                reason=EndTaskReason.VOICEMAIL_DETECTED.value,
                additional_metadata={
                    "voicemail_transcript": transcript,
                    "voicemail_confidence": confidence,
                    "voicemail_reasoning": reasoning,
                    "voicemail_detection_duration": self.detection_duration,
                    "voicemail_audio_s3_path": s3_path,
                },
                abort_immediately=True,
            )

        except Exception as e:
            logger.error(f"Error processing voicemail detection: {e}")

    async def _transcribe_audio(self, wav_data: bytes) -> str:
        """Transcribe audio using the OpenAI API directly.

        Args:
            wav_data: WAV formatted audio data

        Returns:
            The transcript with surrounding whitespace stripped.
        """
        client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

        # Direct API call - no pipeline involvement
        response = await client.audio.transcriptions.create(
            file=("audio.wav", wav_data, "audio/wav"),
            model=self._TRANSCRIPTION_MODEL,
            language="en",
            temperature=0.0,
        )

        return response.text.strip()

    def _create_wav_from_pcm(self, pcm_data: bytes) -> bytes:
        """Wrap raw 16-bit PCM data in a WAV container.

        The header uses the sample rate and channel count learned from the
        first audio packet.
        """
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wav_file:
            wav_file.setnchannels(self._num_channels)
            wav_file.setsampwidth(self._BYTES_PER_SAMPLE)
            wav_file.setframerate(self._sample_rate)
            wav_file.writeframes(pcm_data)

        return wav_buffer.getvalue()

    async def _analyze_transcript(self, transcript: str) -> dict:
        """Classify *transcript* with an independent OpenAI client.

        The prompt is fetched from Langfuse, falling back to
        ``DEFAULT_VOICEMAIL_PROMPT``. When a parent turn context exists and
        tracing is enabled, the call is also recorded as a Langfuse
        generation under that context.

        Returns:
            The model's JSON object, or a negative verdict
            (``is_voicemail`` False, ``confidence`` 0.0) when the output is
            missing or is not a JSON object.
        """
        # Capture the current turn context for proper span nesting.
        parent_context = get_current_turn_context()

        client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

        langfuse_prompt = None
        try:
            langfuse_prompt = self._langfuse_client.get_prompt(
                "production/voicemail_detection"
            )
            prompt = langfuse_prompt.compile(transcript=transcript)
        except Exception as e:
            logger.warning(f"Error getting prompt from Langfuse: {e}")
            # Do not link the generation to a prompt that was not used.
            langfuse_prompt = None
            prompt = DEFAULT_VOICEMAIL_PROMPT.replace("{transcript}", transcript)

        messages = [
            {
                "role": "system",
                "content": prompt,
            }
        ]

        traced = bool(parent_context and is_tracing_enabled())
        token = None
        if traced:
            # Activate the parent context so Langfuse's OTEL tracer picks it up.
            token = otel_context.attach(parent_context)

        langfuse_generation = None
        try:
            if traced:
                try:
                    langfuse_generation = self._langfuse_client.start_generation(
                        name="voicemail_detection",
                        model=self._ANALYSIS_MODEL,
                        input=messages,
                        metadata={
                            "temperature": 0.0,
                            "detection_duration": self.detection_duration,
                            "transcript_length": len(transcript),
                        },
                        prompt=langfuse_prompt,
                    )
                except Exception as e:
                    logger.warning(f"Error starting Langfuse generation: {e}")

            response = await client.chat.completions.create(
                model=self._ANALYSIS_MODEL,
                messages=messages,
                temperature=0.0,
                response_format={"type": "json_object"},
            )
            llm_response = response.choices[0].message.content

            if langfuse_generation is not None:
                try:
                    usage = response.usage
                    langfuse_generation.update(
                        output=llm_response,
                        usage_details={
                            "prompt_tokens": usage.prompt_tokens if usage else 0,
                            "completion_tokens": usage.completion_tokens
                            if usage
                            else 0,
                            "total_tokens": usage.total_tokens if usage else 0,
                        },
                    )
                except Exception as e:
                    logger.warning(f"Error updating Langfuse generation: {e}")
        finally:
            # Close the generation even when the API call failed, then
            # restore the previous OTEL context.
            if langfuse_generation is not None:
                try:
                    langfuse_generation.end()
                except Exception as e:
                    logger.warning(f"Error ending Langfuse generation: {e}")
            if traced:
                otel_context.detach(token)

        parsed = self._parse_json_object(llm_response)
        if parsed is None:
            logger.warning("Invalid JSON response from voicemail detection")
            return {
                "is_voicemail": False,
                "confidence": 0.0,
                "reasoning": "Invalid response",
            }
        return parsed

    async def _save_voicemail_audio(
        self, wav_data: bytes, confidence: float, is_voicemail: bool
    ) -> Optional[str]:
        """Write the audio to a temp file and enqueue its upload to S3.

        Args:
            wav_data: WAV formatted audio data
            confidence: Detection confidence score
            is_voicemail: Whether it was detected as voicemail

        Returns:
            The expected S3 object key (the upload itself happens
            asynchronously), or ``None`` when saving or enqueueing failed.
        """
        temp_file_path = None
        try:
            # The key encodes run id, prediction, confidence (%) and duration (s).
            duration_seconds = self._current_duration_seconds()
            prediction = "voicemail" if is_voicemail else "not_voicemail"
            confidence_int = int(confidence * 100)
            duration_int = int(duration_seconds)
            s3_key = f"voicemail_detections/{self.workflow_run_id}_{prediction}_{confidence_int}_{duration_int}.wav"

            # delete=False: the file must outlive this function, because the
            # enqueued upload task reads it later and handles cleanup.
            with tempfile.NamedTemporaryFile(
                suffix=".wav",
                delete=False,
                prefix=f"voicemail_{self.workflow_run_id}_",
            ) as tmp_file:
                temp_file_path = tmp_file.name
                tmp_file.write(wav_data)
                tmp_file.flush()

            logger.info(f"Saved voicemail audio to temp file: {temp_file_path}")

            await enqueue_job(
                FunctionNames.UPLOAD_VOICEMAIL_AUDIO_TO_S3,
                self.workflow_run_id,
                temp_file_path,
                s3_key,
            )

            logger.info(f"Enqueued voicemail audio upload task for: {s3_key}")
            return s3_key

        except Exception as e:
            logger.error(f"Failed to save voicemail audio: {e}")
            # No upload task will clean up, so remove the temp file here.
            if temp_file_path is not None and os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                except Exception as cleanup_error:
                    logger.warning(
                        f"Failed to cleanup temp file after error: {cleanup_error}"
                    )
            return None
|
||||
0
api/services/workflow/test/__init__.py
Normal file
0
api/services/workflow/test/__init__.py
Normal file
164
api/services/workflow/test/definitions/rf-1.json
Normal file
164
api/services/workflow/test/definitions/rf-1.json
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "915",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 633,
|
||||
"y": 324
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "7598",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 460.1247806640531,
|
||||
"y": 610.3714977079578
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6919",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 914.666735413607,
|
||||
"y": 642.9800281289787
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6581",
|
||||
"type": "startCall",
|
||||
"position": {
|
||||
"x": 648,
|
||||
"y": 35
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": true,
|
||||
"name": "Start Call",
|
||||
"is_start": true
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "1802",
|
||||
"type": "endCall",
|
||||
"position": {
|
||||
"x": 666.7733431033548,
|
||||
"y": 987.4345801025363
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Thank you for calling Dograh. Have a great day!",
|
||||
"is_static": true,
|
||||
"name": "End Call"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "7598",
|
||||
"id": "xy-edge__915-7598",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "The customer wants to talk to a customer service agent",
|
||||
"label": "customer service agent"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "6919",
|
||||
"id": "xy-edge__915-6919",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "customer wants to talk to a sales representative",
|
||||
"label": "sales representative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6581",
|
||||
"target": "915",
|
||||
"id": "xy-edge__6581-915",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "7598",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__7598-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6919",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
}
|
||||
],
|
||||
"viewport": {
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"zoom": 1
|
||||
}
|
||||
}
|
||||
192
api/services/workflow/test/test_aggregation_fix.py
Normal file
192
api/services/workflow/test/test_aggregation_fix.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_aggregation_correction_callback,
|
||||
)
|
||||
|
||||
|
||||
def test_aggregation_fixer():
    """Check the aggregation correction against a table of corrupted inputs.

    The production callback reads the reference text from the engine's
    ``_current_llm_reference_text`` attribute, so every case builds a fresh
    mock engine carrying the reference and a fresh callback for it.
    """
    # Building blocks shared by the cases below. The "split" variants carry
    # the stray spaces that the correction is expected to remove.
    greeting = "Good Morning Mr NARGES , "
    greeting_split = "Good Morning Mr NAR GES , "
    intro = "My name is Alex and I am calling you from Consumer Services"
    intro_split = "My name is Alex and I am calling you from Cons umer Services"
    intro_split_twice = "My name is Alex and I am calling you from Cons umer Servi ces"
    tail = " The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
    reference = greeting + intro + "." + tail

    def fix(reference_text: str, corrupted: str) -> str:
        engine = Mock()
        engine._current_llm_reference_text = reference_text
        return create_aggregation_correction_callback(engine)(corrupted)

    # (label, corrupted input, expected output), all against the full reference.
    full_reference_cases = [
        # Same inputs as the later "leading_whole_sentence" case.
        ("leading_whole_sentence", intro_split + "." + tail, intro + "." + tail),
        # ---- corrupted text starts at the beginning of the reference ----
        ("whole_sentences", greeting_split + intro_split + "." + tail, reference),
        ("period_end", greeting_split + intro_split + ".", greeting + intro + "."),
        ("without_period_end", greeting_split + intro_split, greeting + intro),
        ("extra_space", greeting_split + intro_split + " ", greeting + intro),
        ("multiple_space", greeting_split + intro_split_twice + " ", greeting + intro),
        (
            "multiple_space_end_ws",
            greeting_split + intro_split_twice + ". ",
            greeting + intro + ". ",
        ),
        # ---- corrupted text starts part-way into the reference ----
        ("leading_whole_sentence", intro_split + "." + tail, intro + "." + tail),
        ("leading_period_end", intro_split + ".", intro + "."),
        ("leading_without_period_end", intro_split, intro),
        ("leading_extra_space", intro_split + " ", intro),
        ("leading_multiple_space", intro_split_twice + " ", intro),
        ("leading_multiple_space_end_ws", intro_split_twice + ". ", intro + ". "),
    ]
    for label, corrupted, expected in full_reference_cases:
        assert fix(reference, corrupted) == expected, label

    # Empty reference and empty input.
    assert fix("", "") == ""

    # When the reference is missing, shorter than the input, or unrelated,
    # the corrupted text is expected back unchanged.
    untouched = intro_split_twice + "."
    unusable_references = [
        ("missing_reference", ""),
        ("smaller_reference", "My name is Alex"),
        ("unrelated_reference", "Hello Hello"),
    ]
    for label, reference_text in unusable_references:
        assert fix(reference_text, untouched) == untouched, label
|
||||
|
||||
|
||||
def test_create_aggregation_correction_callback():
    """Exercise the callback produced by ``create_aggregation_correction_callback``.

    Two situations are checked with the same callback instance:

    * the engine holds reference text, and a copy of it with stray spaces
      inside words is expected to come back identical to the reference;
    * the engine's reference text is empty, and the input is expected to be
      returned unchanged.
    """
    reference = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
    garbled = "Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."

    # Stand-in engine exposing only the attribute the test sets.
    engine = Mock()
    engine._current_llm_reference_text = reference

    fix = create_aggregation_correction_callback(engine)

    # With a reference available the split words are rejoined.
    assert fix(garbled) == reference

    # The callback is reused after the reference is cleared on the engine;
    # without a reference the text must pass through as-is.
    engine._current_llm_reference_text = ""
    assert fix("Some corrupted text") == "Some corrupted text"
|
||||
128
api/services/workflow/test/test_aggregation_integration.py
Normal file
128
api/services/workflow/test/test_aggregation_integration.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_generation_started_callback,
|
||||
)
|
||||
|
||||
|
||||
class TestAggregationIntegration:
    """Integration checks for the parts involved in TTS aggregation correction.

    Covers the engine's reference-text bookkeeping, the correction callback
    the engine hands out, the aggregator params that carry such a callback,
    and the pipeline processor that forwards LLM text to a callback.
    """

    @pytest.mark.asyncio
    async def test_engine_reference_text_tracking(self):
        """Reference text starts empty, grows with each LLM chunk, and is cleared by the generation-started callback."""
        # Workflow double with a single node flagged as the start node.
        start_node = Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
        workflow = Mock()
        workflow.start_node_id = "start"
        workflow.nodes = {"start": start_node}

        engine = PipecatEngine(
            task=Mock(),
            llm=Mock(),
            context=Mock(spec=OpenAILLMContext),
            tts=Mock(),
            workflow=workflow,
            call_context_vars={},
            workflow_run_id=1,
        )

        # Nothing has been generated yet.
        assert engine._current_llm_reference_text == ""

        # Each text frame is appended to what was collected before it.
        expected = ""
        for chunk in ("Hello ", "world!"):
            await engine.handle_llm_text_frame(chunk)
            expected += chunk
            assert engine._current_llm_reference_text == expected

        # Starting a new generation wipes the collected reference text.
        on_generation_started = create_generation_started_callback(engine)
        await on_generation_started()
        assert engine._current_llm_reference_text == ""

    @pytest.mark.asyncio
    async def test_aggregation_correction_callback_creation(self):
        """The engine-provided callback restores punctuation found in the reference text."""
        engine = PipecatEngine(
            task=Mock(),
            llm=Mock(),
            context=Mock(spec=OpenAILLMContext),
            workflow=Mock(),
            call_context_vars={},
            workflow_run_id=1,
        )
        engine._current_llm_reference_text = "Hello, world! How are you?"

        fix = engine.create_aggregation_correction_callback()

        # The expected value has no trailing "?": punctuation past the end of
        # the corrupted text is not carried over from the reference.
        assert fix("Hello world How are you") == "Hello, world! How are you"

    def test_llm_assistant_aggregator_params_with_callback(self):
        """``LLMAssistantAggregatorParams`` stores and exposes a correction callback."""

        def shout(text: str) -> str:
            return text.upper()

        params = LLMAssistantAggregatorParams(
            expect_stripped_words=True,
            correct_aggregation_callback=shout,
        )

        assert params.expect_stripped_words is True
        assert params.correct_aggregation_callback is not None
        assert params.correct_aggregation_callback("hello") == "HELLO"

    @pytest.mark.asyncio
    async def test_pipeline_callbacks_processor_llm_text_frame(self):
        """``PipelineEngineCallbacksProcessor`` passes an ``LLMTextFrame``'s text to its callback."""
        from pipecat.frames.frames import LLMTextFrame
        from pipecat.processors.frame_processor import FrameDirection

        from api.services.pipecat.pipeline_engine_callbacks_processor import (
            PipelineEngineCallbacksProcessor,
        )

        # Every text the callback receives is recorded here, in order.
        received = []

        async def record_text(text: str):
            received.append(text)

        processor = PipelineEngineCallbacksProcessor(
            llm_text_frame_callback=record_text
        )

        await processor.process_frame(
            LLMTextFrame(text="Hello world"), FrameDirection.DOWNSTREAM
        )

        # The callback ran, and the most recent text it saw is the frame's text.
        assert len(received) > 0
        assert received[-1] == "Hello world"
|
||||
31
api/services/workflow/test/test_cost_calculator.py
Normal file
31
api/services/workflow/test/test_cost_calculator.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
    """Check each cost component and the total for one fixed usage sample.

    The expected per-service costs are recomputed inline from the counts in
    the sample. They are compared with a tolerance rather than ``==``: the
    values are derived floats, so an arithmetically equivalent but
    differently ordered computation inside the calculator may differ in the
    last bits without being wrong.
    """
    import math

    sample_usage = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 45380,
                "completion_tokens": 496,
                "total_tokens": 45876,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
        "stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
        "call_duration_seconds": 179,
    }

    result = cost_calculator.calculate_total_cost(sample_usage)

    # NOTE(review): the rates below look like price per 1M prompt/completion
    # tokens, per 1K TTS characters and per STT minute — confirm against the
    # pricing table used by cost_calculator.
    expected_llm = 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
    expected_tts = 2399 * 0.0256 / 1_000
    expected_stt = 177.21536946296692 / 60 * 0.0077

    assert math.isclose(result["llm_cost"], expected_llm, rel_tol=1e-9, abs_tol=1e-12)
    assert math.isclose(result["tts_cost"], expected_tts, rel_tol=1e-9, abs_tol=1e-12)
    assert math.isclose(result["stt_cost"], expected_stt, rel_tol=1e-9, abs_tol=1e-12)

    # The total must be the sum of the three components.
    components = result["llm_cost"] + result["tts_cost"] + result["stt_cost"]
    assert abs(result["total"] - components) < 1e-10
|
||||
11
api/services/workflow/test/test_dto.py
Normal file
11
api/services/workflow/test/test_dto.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dto():
    """Validate that the ``rf-1.json`` fixture parses into a ``ReactFlowDTO``.

    The test passes when ``model_validate_json`` raises nothing and returns
    an object. The fixture is located relative to this test module instead
    of the process working directory, so the test does not depend on where
    pytest is launched from, and it is read explicitly as UTF-8 so the
    result does not vary with the platform's default encoding.
    """
    from pathlib import Path

    # definitions/ sits next to this test file.
    fixture = Path(__file__).resolve().parent / "definitions" / "rf-1.json"

    dto = ReactFlowDTO.model_validate_json(fixture.read_text(encoding="utf-8"))
    assert dto is not None
|
||||
159
api/services/workflow/test/test_interruption_correction.py
Normal file
159
api/services/workflow/test/test_interruption_correction.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import StartInterruptionFrame
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMContext,
|
||||
)
|
||||
|
||||
|
||||
class TestInterruptionCorrection:
    """Checks that assistant aggregators run the correction callback when interrupted.

    Every test follows the same outline: build a context double, create an
    aggregator whose params carry a correction callback, put the aggregator
    into a mid-response state, deliver a ``StartInterruptionFrame`` and
    inspect the single message that was added to the context.
    """

    @staticmethod
    def _new_context():
        """Return a context double with no prior messages and a recording ``add_message``."""
        context = Mock(spec=OpenAILLMContext)
        context.get_messages.return_value = []
        context.add_message = Mock()
        return context

    @staticmethod
    def _put_mid_response(aggregator, partial_text):
        """Set the aggregator's state as if ``partial_text`` had been aggregated so far."""
        aggregator._aggregation = partial_text
        aggregator._current_llm_response_id = "test-id"
        aggregator._response_function_messages = {}
        aggregator._function_calls_in_progress = {}
        aggregator._started = 1

        # Replaced with async doubles so the test does not run the real
        # frame-pushing and reset logic.
        aggregator.push_context_frame = AsyncMock()
        aggregator.reset = AsyncMock()

    @pytest.mark.asyncio
    async def test_openai_interruption_with_correction(self):
        """OpenAI aggregator stores the corrected text plus the interruption marker."""
        context = self._new_context()

        # Pretend correction: one known corrupted sentence gets its comma back,
        # anything else is returned untouched.
        known_fixes = {"Hello world how are you": "Hello world, how are you"}

        def correction_callback(text: str) -> str:
            return known_fixes.get(text, text)

        aggregator = OpenAIAssistantContextAggregator(
            context=context,
            params=LLMAssistantAggregatorParams(
                expect_stripped_words=True,
                correct_aggregation_callback=correction_callback,
            ),
        )
        self._put_mid_response(aggregator, "Hello world how are you")

        await aggregator._handle_interruptions(StartInterruptionFrame())

        # Exactly one message is recorded and it carries the corrected text.
        context.add_message.assert_called_once()
        stored = context.add_message.call_args.args[0]
        assert stored["role"] == "assistant"
        assert stored["content"] == "Hello world, how are you <<interrupted_by_user>>"

    @pytest.mark.asyncio
    async def test_google_interruption_with_correction(self):
        """Google aggregator stores a ``Content`` object with the text plus the interruption marker."""
        from pipecat.services.google.llm import (
            Content,
            GoogleAssistantContextAggregator,
        )

        context = self._new_context()

        # NOTE(review): as shown, the corrupted and corrected strings are
        # identical, so this callback cannot distinguish corrected from
        # uncorrected output — confirm against the upstream file (repeated
        # whitespace may have been collapsed when this text was captured).
        known_fixes = {"I am here to help": "I am here to help"}

        def correction_callback(text: str) -> str:
            return known_fixes.get(text, text)

        aggregator = GoogleAssistantContextAggregator(
            context=context,
            params=LLMAssistantAggregatorParams(
                expect_stripped_words=True,
                correct_aggregation_callback=correction_callback,
            ),
        )
        self._put_mid_response(aggregator, "I am here to help")

        await aggregator._handle_interruptions(StartInterruptionFrame())

        context.add_message.assert_called_once()
        stored = context.add_message.call_args.args[0]

        # The Google aggregator records a Content object with role "model"
        # and a single text part, not a plain dict.
        assert isinstance(stored, Content)
        assert stored.role == "model"
        assert len(stored.parts) == 1
        assert stored.parts[0].text == "I am here to help <<interrupted_by_user>>"

    @pytest.mark.asyncio
    async def test_interruption_correction_error_handling(self):
        """A raising correction callback must not break interruption handling."""
        context = self._new_context()

        def failing_callback(text: str) -> str:
            raise ValueError("Correction failed")

        aggregator = OpenAIAssistantContextAggregator(
            context=context,
            params=LLMAssistantAggregatorParams(
                expect_stripped_words=True,
                correct_aggregation_callback=failing_callback,
            ),
        )
        self._put_mid_response(aggregator, "Some text")

        # Must complete without propagating the ValueError.
        await aggregator._handle_interruptions(StartInterruptionFrame())

        # Fallback: the uncorrected text is still recorded.
        context.add_message.assert_called_once()
        stored = context.add_message.call_args.args[0]
        assert stored["role"] == "assistant"
        assert stored["content"] == "Some text <<interrupted_by_user>>"
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue