mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
fix: migrate from custom audio recorder to native AudioBuffer (#115)
* fix: update to pipecat VM Detector * fix: refactor to remove audio synchronizer * feat: add speechmatics as STT
This commit is contained in:
parent
31521008cf
commit
edf0fa4fbc
12 changed files with 193 additions and 591 deletions
|
|
@ -16,10 +16,9 @@ from api.services.workflow.disposition_mapper import (
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
|
||||
|
||||
def register_transport_event_handlers(
|
||||
|
|
@ -27,8 +26,7 @@ def register_transport_event_handlers(
|
|||
transport,
|
||||
workflow_run_id,
|
||||
engine: PipecatEngine,
|
||||
audio_buffer: AudioBuffer,
|
||||
audio_synchronizer: AudioSynchronizer,
|
||||
audio_buffer: AudioBufferProcessor,
|
||||
audio_config=AudioConfig,
|
||||
):
|
||||
"""Register event handlers for transport events"""
|
||||
|
|
@ -53,8 +51,6 @@ def register_transport_event_handlers(
|
|||
async def on_client_connected(transport, participant):
|
||||
logger.debug("In on_client_connected callback handler - initializing workflow")
|
||||
await audio_buffer.start_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.start_recording()
|
||||
await engine.initialize()
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
|
|
@ -68,8 +64,6 @@ def register_transport_event_handlers(
|
|||
|
||||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
# Only cancel the task if the call is not already disposed by the engine
|
||||
if not call_disposed:
|
||||
|
|
@ -84,12 +78,19 @@ def register_task_event_handler(
|
|||
engine: PipecatEngine,
|
||||
task: PipelineTask,
|
||||
transport,
|
||||
audio_buffer: AudioBuffer,
|
||||
audio_synchronizer: AudioSynchronizer,
|
||||
audio_buffer: AudioBufferProcessor,
|
||||
in_memory_audio_buffer: InMemoryAudioBuffer,
|
||||
in_memory_transcript_buffer: InMemoryTranscriptBuffer,
|
||||
pipeline_metrics_aggregator: PipelineMetricsAggregator,
|
||||
):
|
||||
@task.event_handler("on_pipeline_started")
|
||||
async def on_pipeline_started(task: PipelineTask, frame: Frame):
|
||||
logger.debug(
|
||||
"In on_pipeline_started callback handler - triggering initial LLM generation"
|
||||
)
|
||||
# Trigger initial LLM generation after pipeline has started
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(
|
||||
task: PipelineTask,
|
||||
|
|
@ -101,8 +102,6 @@ def register_task_event_handler(
|
|||
|
||||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
call_disposition = await engine.get_call_disposition()
|
||||
logger.debug(f"call disposition in on_pipeline_finished: {call_disposition}")
|
||||
|
|
@ -224,19 +223,21 @@ def register_task_event_handler(
|
|||
|
||||
|
||||
def register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id, in_memory_buffer: InMemoryAudioBuffer
|
||||
audio_buffer: AudioBufferProcessor,
|
||||
workflow_run_id,
|
||||
in_memory_buffer: InMemoryAudioBuffer,
|
||||
):
|
||||
"""Register event handler for audio data"""
|
||||
logger.info(f"Registering audio data handler for workflow run {workflow_run_id}")
|
||||
|
||||
@audio_synchronizer.event_handler("on_merged_audio")
|
||||
async def on_merged_audio(_, pcm, sample_rate, num_channels):
|
||||
if not pcm:
|
||||
@audio_buffer.event_handler("on_audio_data")
|
||||
async def on_audio_data(buffer, audio, sample_rate, num_channels):
|
||||
if not audio:
|
||||
return
|
||||
|
||||
# Use in-memory buffer
|
||||
try:
|
||||
await in_memory_buffer.append(pcm)
|
||||
await in_memory_buffer.append(audio)
|
||||
except MemoryError as e:
|
||||
logger.error(f"Memory buffer full: {e}")
|
||||
# Could implement overflow to disk here if needed
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ from api.services.pipecat.audio_config import AudioConfig
|
|||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.utils.context import turn_var
|
||||
|
||||
|
|
@ -23,15 +22,8 @@ def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine
|
|||
"""Create and return the main pipeline components with proper audio configuration"""
|
||||
logger.info(f"Creating pipeline components with audio config: {audio_config}")
|
||||
|
||||
# Use new split audio buffer for better performance
|
||||
audio_buffer = AudioBuffer(
|
||||
sample_rate=audio_config.pipeline_sample_rate,
|
||||
buffer_size=audio_config.buffer_size_bytes,
|
||||
max_recording_bytes=audio_config.max_recording_bytes,
|
||||
)
|
||||
|
||||
# Create synchronizer for merged audio (outside pipeline)
|
||||
audio_synchronizer = AudioSynchronizer(
|
||||
# Use native AudioBufferProcessor for merged audio recording
|
||||
audio_buffer = AudioBufferProcessor(
|
||||
sample_rate=audio_config.pipeline_sample_rate,
|
||||
buffer_size=audio_config.buffer_size_bytes,
|
||||
)
|
||||
|
|
@ -42,7 +34,7 @@ def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine
|
|||
|
||||
context = LLMContext()
|
||||
|
||||
return audio_buffer, audio_synchronizer, transcript, context
|
||||
return audio_buffer, transcript, context
|
||||
|
||||
|
||||
def build_pipeline(
|
||||
|
|
@ -50,7 +42,6 @@ def build_pipeline(
|
|||
stt,
|
||||
transcript,
|
||||
audio_buffer,
|
||||
audio_synchronizer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
|
|
@ -59,30 +50,41 @@ def build_pipeline(
|
|||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
voicemail_detector=None,
|
||||
):
|
||||
"""Build the main pipeline with all components"""
|
||||
# Register processors with synchronizer for merged audio
|
||||
logger.info("Registering audio buffer processors with synchronizer")
|
||||
audio_synchronizer.register_processors(audio_buffer.input(), audio_buffer.output())
|
||||
"""Build the main pipeline with all components.
|
||||
|
||||
# Build processors list with optional context controller
|
||||
Args:
|
||||
audio_buffer: AudioBufferProcessor that handles both input and output audio recording.
|
||||
voicemail_detector: Optional native pipecat VoicemailDetector. When provided,
|
||||
inserts voicemail detection after STT. Note: We don't use the TTS gate
|
||||
to avoid blocking TTS frames during classification.
|
||||
"""
|
||||
# Build processors list with optional voicemail detection
|
||||
processors = [
|
||||
transport.input(), # Transport user input
|
||||
audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame)
|
||||
stt, # STT can now have audio_passthrough=False
|
||||
stt_mute_filter, # STTMuteFilters don't let VAD related events pass through if muted
|
||||
user_idle_disconnect,
|
||||
transcript.user(),
|
||||
stt, # STT (audio_passthrough=True by default, passes InputAudioRawFrame)
|
||||
]
|
||||
|
||||
# Insert voicemail detector after STT if enabled
|
||||
# Note: We intentionally do NOT use voicemail_detector.gate() to allow TTS
|
||||
# frames to continue flowing during classification (non-blocking detection)
|
||||
if voicemail_detector:
|
||||
logger.info("Adding native voicemail detector to pipeline")
|
||||
processors.append(voicemail_detector.detector())
|
||||
|
||||
# Continue with the rest of the pipeline
|
||||
processors.extend(
|
||||
[
|
||||
stt_mute_filter, # STTMuteFilters don't let VAD related events pass through if muted
|
||||
user_idle_disconnect,
|
||||
transcript.user(),
|
||||
user_context_aggregator,
|
||||
llm, # LLM
|
||||
pipeline_engine_callback_processor,
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
audio_buffer.output(), # Record output audio (only processes OutputAudioRawFrame)
|
||||
audio_buffer, # AudioBufferProcessor - records both input and output audio
|
||||
transcript.assistant(),
|
||||
assistant_context_aggregator, # Assistant spoken responses
|
||||
pipeline_metrics_aggregator,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from api.services.pipecat.service_factory import (
|
|||
create_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
create_voicemail_classification_llm,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import setup_pipeline_tracing
|
||||
from api.services.pipecat.transport_setup import (
|
||||
|
|
@ -41,8 +42,12 @@ from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
|||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.extensions.voicemail.voicemail_detector import VoicemailDetector
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
|
|
@ -54,6 +59,7 @@ from pipecat.processors.filters.stt_mute_filter import (
|
|||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pipecat.utils.tracing.context_registry import ContextProviderRegistry
|
||||
|
||||
# Setup tracing if enabled
|
||||
|
|
@ -468,9 +474,7 @@ async def _run_pipeline(
|
|||
)
|
||||
|
||||
# Create pipeline components with audio configuration and engine
|
||||
audio_buffer, audio_synchronizer, transcript, context = create_pipeline_components(
|
||||
audio_config, engine
|
||||
)
|
||||
audio_buffer, transcript, context = create_pipeline_components(audio_config, engine)
|
||||
|
||||
# Set the context and audio_buffer after creation
|
||||
engine.set_context(context)
|
||||
|
|
@ -484,8 +488,9 @@ async def _run_pipeline(
|
|||
expect_stripped_words=True,
|
||||
correct_aggregation_callback=engine.create_aggregation_correction_callback(),
|
||||
)
|
||||
user_params = LLMUserAggregatorParams(enable_emulated_vad_interruptions=True)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params
|
||||
context, assistant_params=assistant_params, user_params=user_params
|
||||
)
|
||||
|
||||
# Create usage metrics aggregator with engine's callback
|
||||
|
|
@ -517,13 +522,35 @@ async def _run_pipeline(
|
|||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Create voicemail detector if enabled in the workflow's start node
|
||||
voicemail_detector = None
|
||||
start_node = workflow_graph.nodes.get(workflow_graph.start_node_id)
|
||||
if start_node and start_node.detect_voicemail:
|
||||
classification_llm = create_voicemail_classification_llm()
|
||||
if classification_llm:
|
||||
logger.info(
|
||||
f"Voicemail detection enabled for workflow run {workflow_run_id}"
|
||||
)
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=classification_llm,
|
||||
voicemail_response_delay=2.0,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.send_end_task_frame(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
stt,
|
||||
transcript,
|
||||
audio_buffer,
|
||||
audio_synchronizer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
|
|
@ -532,6 +559,7 @@ async def _run_pipeline(
|
|||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
voicemail_detector=voicemail_detector,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
|
|
@ -548,7 +576,6 @@ async def _run_pipeline(
|
|||
workflow_run_id,
|
||||
engine=engine,
|
||||
audio_buffer=audio_buffer,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
)
|
||||
|
|
@ -559,15 +586,12 @@ async def _run_pipeline(
|
|||
task,
|
||||
transport,
|
||||
audio_buffer,
|
||||
audio_synchronizer,
|
||||
in_memory_audio_buffer,
|
||||
in_memory_transcript_buffer,
|
||||
pipeline_metrics_aggregator,
|
||||
)
|
||||
|
||||
register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id, in_memory_audio_buffer
|
||||
)
|
||||
register_audio_data_handler(audio_buffer, workflow_run_id, in_memory_audio_buffer)
|
||||
register_transcript_handler(
|
||||
transcript, workflow_run_id, in_memory_transcript_buffer
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
|
@ -20,6 +21,7 @@ from pipecat.services.openai.stt import OpenAISTTService
|
|||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService
|
||||
from pipecat.services.speechmatics.stt import SpeechmaticsSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
|
||||
|
||||
|
|
@ -40,28 +42,20 @@ def create_stt_service(user_config):
|
|||
)
|
||||
logger.debug(f"Using DeepGram Model - {user_config.stt.model}")
|
||||
return DeepgramSTTService(
|
||||
live_options=live_options,
|
||||
api_key=user_config.stt.api_key,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
live_options=live_options, api_key=user_config.stt.api_key
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAISTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
api_key=user_config.stt.api_key, model=user_config.stt.model
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
return CartesiaSTTService(api_key=user_config.stt.api_key)
|
||||
elif user_config.stt.provider == ServiceProviders.DOGRAH.value:
|
||||
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
|
||||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SARVAM.value:
|
||||
# Map Sarvam language code to pipecat Language enum
|
||||
|
|
@ -85,7 +79,23 @@ def create_stt_service(user_config):
|
|||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
params=SarvamSTTService.InputParams(language=pipecat_language),
|
||||
audio_passthrough=False,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
|
||||
from pipecat.services.speechmatics.stt import OperatingPoint
|
||||
|
||||
language = getattr(user_config.stt, "language", None) or "en"
|
||||
# Map model field to operating point (standard or enhanced)
|
||||
operating_point = (
|
||||
OperatingPoint.ENHANCED
|
||||
if user_config.stt.model == "enhanced"
|
||||
else OperatingPoint.STANDARD
|
||||
)
|
||||
return SpeechmaticsSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
params=SpeechmaticsSTTService.InputParams(
|
||||
language=language,
|
||||
operating_point=operating_point,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -138,6 +148,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
params=DograhTTSService.InputParams(speed=user_config.tts.speed),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
|
||||
|
|
@ -222,3 +233,24 @@ def create_llm_service(user_config):
|
|||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
|
||||
|
||||
def create_voicemail_classification_llm():
|
||||
"""Create a fast, lightweight LLM service for voicemail classification.
|
||||
|
||||
Uses gpt-4o-mini which is fast and cost-effective for simple classification tasks.
|
||||
The model only needs to output "CONVERSATION" or "VOICEMAIL" based on transcriptions.
|
||||
|
||||
Returns:
|
||||
OpenAILLMService instance, or None if OPENAI_API_KEY is not set.
|
||||
"""
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set - voicemail detection will be disabled")
|
||||
return None
|
||||
|
||||
return OpenAILLMService(
|
||||
api_key=api_key,
|
||||
model="gpt-4o-mini",
|
||||
params=OpenAILLMService.InputParams(temperature=0.0),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue