mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
115 lines
4.5 KiB
Python
115 lines
4.5 KiB
Python
"""Audio configuration for pipeline components.
|
|
|
|
This module provides centralized audio configuration to ensure consistent
|
|
sample rates across all pipeline components and proper coordination between
|
|
transport serializers, VAD, and audio buffers.
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from loguru import logger
|
|
|
|
|
|
@dataclass
|
|
class AudioConfig:
|
|
"""Centralized audio configuration for the pipeline.
|
|
|
|
Note: Pipeline is limited to 16kHz maximum to support VAD.
|
|
Transports handle resampling from/to higher rates (24kHz, 48kHz).
|
|
|
|
Attributes:
|
|
transport_in_sample_rate: Sample rate of incoming audio from transport (after resampling)
|
|
transport_out_sample_rate: Sample rate of outgoing audio to transport (before resampling)
|
|
vad_sample_rate: Sample rate for VAD processing (8000 or 16000)
|
|
pipeline_sample_rate: Internal pipeline processing sample rate (max 16000)
|
|
buffer_size_seconds: Audio buffer size in seconds
|
|
"""
|
|
|
|
transport_in_sample_rate: int
|
|
transport_out_sample_rate: int
|
|
vad_sample_rate: int = 16000 # VAD typically resamples internally
|
|
pipeline_sample_rate: Optional[int] = None # If None, uses transport rates
|
|
buffer_size_seconds: float = 5.0 # This is how frequenly we will call merge_auido
|
|
max_recording_duration_seconds: float = 300.0 # 5 minutes max recording duration
|
|
|
|
def __post_init__(self):
|
|
# Validate VAD sample rate
|
|
if self.vad_sample_rate not in [8000, 16000]:
|
|
raise ValueError(
|
|
f"VAD sample rate must be 8000 or 16000, got {self.vad_sample_rate}"
|
|
)
|
|
|
|
# Set pipeline sample rate to transport out rate if not specified
|
|
if self.pipeline_sample_rate is None:
|
|
self.pipeline_sample_rate = min(self.transport_out_sample_rate, 16000)
|
|
|
|
# Ensure pipeline sample rate doesn't exceed 16kHz (VAD limitation)
|
|
if self.pipeline_sample_rate > 16000:
|
|
logger.warning(
|
|
f"Pipeline sample rate {self.pipeline_sample_rate} exceeds 16kHz limit, "
|
|
f"capping at 16kHz. Transport will handle resampling."
|
|
)
|
|
self.pipeline_sample_rate = 16000
|
|
|
|
# Log configuration for auditing
|
|
logger.info(
|
|
f"AudioConfig initialized: "
|
|
f"transport_in={self.transport_in_sample_rate}Hz, "
|
|
f"transport_out={self.transport_out_sample_rate}Hz, "
|
|
f"vad={self.vad_sample_rate}Hz, "
|
|
f"pipeline={self.pipeline_sample_rate}Hz, "
|
|
f"buffer={self.buffer_size_seconds}s"
|
|
)
|
|
|
|
@property
|
|
def buffer_size_bytes(self) -> int:
|
|
"""Calculate buffer size in bytes based on pipeline sample rate."""
|
|
# 2 bytes per sample (16-bit PCM)
|
|
return int(self.pipeline_sample_rate * 2 * self.buffer_size_seconds)
|
|
|
|
@property
|
|
def buffer_size_samples(self) -> int:
|
|
"""Calculate buffer size in samples based on pipeline sample rate."""
|
|
return int(self.pipeline_sample_rate * self.buffer_size_seconds)
|
|
|
|
@property
|
|
def max_recording_bytes(self) -> int:
|
|
"""Calculate max recording size in bytes based on pipeline sample rate and duration."""
|
|
# 2 bytes per sample (16-bit PCM)
|
|
return int(self.pipeline_sample_rate * 2 * self.max_recording_duration_seconds)
|
|
|
|
|
|
def create_audio_config(transport_type: str) -> AudioConfig:
|
|
"""Create audio configuration for a given transport.
|
|
|
|
Telephony providers contribute their wire-format sample rate through the
|
|
provider registry (``ProviderSpec.transport_sample_rate``); WebRTC modes
|
|
use 16 kHz (transports handle resampling from/to 24 kHz). The remaining
|
|
AudioConfig fields are derived from the chosen rate.
|
|
"""
|
|
# Defer registry import to avoid an import cycle: the registry is imported
|
|
# by every telephony provider package at startup.
|
|
from api.enums import WorkflowRunMode
|
|
from api.services.telephony import registry
|
|
|
|
telephony_spec = registry.get_optional(transport_type)
|
|
if telephony_spec is not None:
|
|
rate = telephony_spec.transport_sample_rate
|
|
elif transport_type in (
|
|
WorkflowRunMode.WEBRTC.value,
|
|
WorkflowRunMode.SMALLWEBRTC.value,
|
|
):
|
|
rate = 16000
|
|
else:
|
|
logger.warning(
|
|
f"Unknown transport type: {transport_type}, using default config"
|
|
)
|
|
rate = 16000
|
|
|
|
return AudioConfig(
|
|
transport_in_sample_rate=rate,
|
|
transport_out_sample_rate=rate,
|
|
vad_sample_rate=rate,
|
|
pipeline_sample_rate=rate,
|
|
)
|