mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
0
api/services/pipecat/__init__.py
Normal file
0
api/services/pipecat/__init__.py
Normal file
120
api/services/pipecat/audio_config.py
Normal file
120
api/services/pipecat/audio_config.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
Audio configuration for pipeline components.
|
||||
|
||||
This module provides centralized audio configuration to ensure consistent
|
||||
sample rates across all pipeline components and proper coordination between
|
||||
transport serializers, VAD, and audio buffers.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioConfig:
|
||||
"""Centralized audio configuration for the pipeline.
|
||||
|
||||
Note: Pipeline is limited to 16kHz maximum to support VAD.
|
||||
Transports handle resampling from/to higher rates (24kHz, 48kHz).
|
||||
|
||||
Attributes:
|
||||
transport_in_sample_rate: Sample rate of incoming audio from transport (after resampling)
|
||||
transport_out_sample_rate: Sample rate of outgoing audio to transport (before resampling)
|
||||
vad_sample_rate: Sample rate for VAD processing (8000 or 16000)
|
||||
pipeline_sample_rate: Internal pipeline processing sample rate (max 16000)
|
||||
buffer_size_seconds: Audio buffer size in seconds
|
||||
"""
|
||||
|
||||
transport_in_sample_rate: int
|
||||
transport_out_sample_rate: int
|
||||
vad_sample_rate: int = 16000 # VAD typically resamples internally
|
||||
pipeline_sample_rate: Optional[int] = None # If None, uses transport rates
|
||||
buffer_size_seconds: float = 1.0 # This is how frequenly we will call merge_auido
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate VAD sample rate
|
||||
if self.vad_sample_rate not in [8000, 16000]:
|
||||
raise ValueError(
|
||||
f"VAD sample rate must be 8000 or 16000, got {self.vad_sample_rate}"
|
||||
)
|
||||
|
||||
# Set pipeline sample rate to transport out rate if not specified
|
||||
if self.pipeline_sample_rate is None:
|
||||
self.pipeline_sample_rate = min(self.transport_out_sample_rate, 16000)
|
||||
|
||||
# Ensure pipeline sample rate doesn't exceed 16kHz (VAD limitation)
|
||||
if self.pipeline_sample_rate > 16000:
|
||||
logger.warning(
|
||||
f"Pipeline sample rate {self.pipeline_sample_rate} exceeds 16kHz limit, "
|
||||
f"capping at 16kHz. Transport will handle resampling."
|
||||
)
|
||||
self.pipeline_sample_rate = 16000
|
||||
|
||||
# Log configuration for auditing
|
||||
logger.info(
|
||||
f"AudioConfig initialized: "
|
||||
f"transport_in={self.transport_in_sample_rate}Hz, "
|
||||
f"transport_out={self.transport_out_sample_rate}Hz, "
|
||||
f"vad={self.vad_sample_rate}Hz, "
|
||||
f"pipeline={self.pipeline_sample_rate}Hz, "
|
||||
f"buffer={self.buffer_size_seconds}s"
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_size_bytes(self) -> int:
|
||||
"""Calculate buffer size in bytes based on pipeline sample rate."""
|
||||
# 2 bytes per sample (16-bit PCM)
|
||||
return int(self.pipeline_sample_rate * 2 * self.buffer_size_seconds)
|
||||
|
||||
@property
|
||||
def buffer_size_samples(self) -> int:
|
||||
"""Calculate buffer size in samples based on pipeline sample rate."""
|
||||
return int(self.pipeline_sample_rate * self.buffer_size_seconds)
|
||||
|
||||
|
||||
def create_audio_config(transport_type: str) -> AudioConfig:
|
||||
"""Create audio configuration based on transport type.
|
||||
|
||||
Args:
|
||||
transport_type: Type of transport ("webrtc", "twilio", "stasis")
|
||||
|
||||
Returns:
|
||||
AudioConfig instance with appropriate settings
|
||||
"""
|
||||
if transport_type in (WorkflowRunMode.STASIS.value, WorkflowRunMode.TWILIO.value):
|
||||
return AudioConfig(
|
||||
transport_in_sample_rate=8000,
|
||||
transport_out_sample_rate=8000,
|
||||
vad_sample_rate=8000, # Use matching VAD rate
|
||||
pipeline_sample_rate=8000, # Keep at 8kHz to avoid resampling
|
||||
buffer_size_seconds=1.0,
|
||||
)
|
||||
elif transport_type in [
|
||||
WorkflowRunMode.WEBRTC.value,
|
||||
WorkflowRunMode.SMALLWEBRTC.value,
|
||||
]:
|
||||
# WebRTC typically uses 24kHz or 48kHz, but we limit pipeline to 16kHz
|
||||
# The transport will handle resampling between 24kHz and 16kHz
|
||||
return AudioConfig(
|
||||
transport_in_sample_rate=16000, # Transport will resample from 24kHz
|
||||
transport_out_sample_rate=16000, # Transport will resample to 24kHz
|
||||
vad_sample_rate=16000, # VAD native rate
|
||||
pipeline_sample_rate=16000, # Keep pipeline at 16kHz
|
||||
buffer_size_seconds=1.0,
|
||||
)
|
||||
else:
|
||||
# Default configuration
|
||||
logger.warning(
|
||||
f"Unknown transport type: {transport_type}, using default config"
|
||||
)
|
||||
return AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
vad_sample_rate=16000,
|
||||
pipeline_sample_rate=16000,
|
||||
buffer_size_seconds=1.0,
|
||||
)
|
||||
122
api/services/pipecat/audio_transcript_buffers.py
Normal file
122
api/services/pipecat/audio_transcript_buffers.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import asyncio
|
||||
import re
|
||||
import tempfile
|
||||
import wave
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class InMemoryAudioBuffer:
|
||||
"""Buffer audio data in memory during a call, then write to temp file on disconnect."""
|
||||
|
||||
def __init__(self, workflow_run_id: int, sample_rate: int, num_channels: int = 1):
|
||||
self._workflow_run_id = workflow_run_id
|
||||
self._sample_rate = sample_rate
|
||||
self._num_channels = num_channels
|
||||
self._chunks: List[bytes] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._total_size = 0
|
||||
self._max_size = 100 * 1024 * 1024 # 100MB limit
|
||||
|
||||
async def append(self, pcm_data: bytes):
|
||||
"""Append PCM audio data to the buffer."""
|
||||
async with self._lock:
|
||||
if self._total_size + len(pcm_data) > self._max_size:
|
||||
logger.error(
|
||||
f"Audio buffer size limit exceeded for workflow {self._workflow_run_id}. "
|
||||
f"Current: {self._total_size}, Attempted to add: {len(pcm_data)}"
|
||||
)
|
||||
raise MemoryError("Audio buffer size limit exceeded")
|
||||
self._chunks.append(pcm_data)
|
||||
self._total_size += len(pcm_data)
|
||||
logger.trace(
|
||||
f"Appended {len(pcm_data)} bytes to audio buffer. Total size: {self._total_size}"
|
||||
)
|
||||
|
||||
async def write_to_temp_file(self) -> str:
|
||||
"""Write audio data to a temporary WAV file and return the path."""
|
||||
async with self._lock:
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
logger.debug(
|
||||
f"Writing audio buffer to temp file {temp_file.name} for workflow {self._workflow_run_id}"
|
||||
)
|
||||
|
||||
# Write WAV header and PCM data
|
||||
with wave.open(temp_file.name, "wb") as wf:
|
||||
wf.setnchannels(self._num_channels)
|
||||
wf.setsampwidth(2) # 16-bit audio
|
||||
wf.setframerate(self._sample_rate)
|
||||
|
||||
# Concatenate all chunks
|
||||
for chunk in self._chunks:
|
||||
wf.writeframes(chunk)
|
||||
|
||||
logger.info(
|
||||
f"Successfully wrote {self._total_size} bytes of audio to {temp_file.name}"
|
||||
)
|
||||
return temp_file.name
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the buffer is empty."""
|
||||
return len(self._chunks) == 0
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get the total size of buffered data."""
|
||||
return self._total_size
|
||||
|
||||
|
||||
class InMemoryTranscriptBuffer:
|
||||
"""Buffer transcript data in memory during a call, then write to temp file on disconnect."""
|
||||
|
||||
# Compiled regex to identify user speech lines, e.g.
|
||||
# [2025-06-29T12:34:56.789+00:00] user: hello
|
||||
_USER_SPEECH_RE: re.Pattern[str] = re.compile(
|
||||
r"^\[\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}\+\d{2}:\d{2}\] user: .+"
|
||||
)
|
||||
|
||||
def __init__(self, workflow_run_id: int):
|
||||
self._workflow_run_id = workflow_run_id
|
||||
self._lines: List[str] = []
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def append(self, transcript: str):
|
||||
"""Append transcript text to the buffer."""
|
||||
async with self._lock:
|
||||
self._lines.append(transcript)
|
||||
logger.trace(
|
||||
f"Appended transcript line to buffer for workflow {self._workflow_run_id}"
|
||||
)
|
||||
|
||||
async def write_to_temp_file(self) -> str:
|
||||
"""Write transcript to a temporary text file and return the path."""
|
||||
async with self._lock:
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".txt", delete=False
|
||||
)
|
||||
logger.debug(
|
||||
f"Writing transcript buffer to temp file {temp_file.name} for workflow {self._workflow_run_id}"
|
||||
)
|
||||
|
||||
content = "".join(self._lines)
|
||||
temp_file.write(content)
|
||||
temp_file.close()
|
||||
|
||||
logger.info(
|
||||
f"Successfully wrote {len(content)} chars of transcript to {temp_file.name}"
|
||||
)
|
||||
return temp_file.name
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the buffer is empty."""
|
||||
return len(self._lines) == 0
|
||||
|
||||
def contains_user_speech(self) -> bool:
|
||||
"""Return True if any buffered transcript line matches the user speech pattern."""
|
||||
for line in self._lines:
|
||||
if self._USER_SPEECH_RE.match(line):
|
||||
return True
|
||||
return False
|
||||
69
api/services/pipecat/engine_pre_aggregator_processor.py
Normal file
69
api/services/pipecat/engine_pre_aggregator_processor.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Engine Pre-Aggregator Processor
|
||||
|
||||
This processor sits before the user context aggregator in the pipeline and handles
|
||||
engine-specific callbacks for frames that need to be processed before aggregation.
|
||||
This ensures the engine can update context before the aggregator generates LLM frames.
|
||||
"""
|
||||
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.pipecat.exceptions import VoicemailDetectedException
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class EnginePreAggregatorProcessor(FrameProcessor):
|
||||
"""
|
||||
Processor that handles engine callbacks before user context aggregation.
|
||||
|
||||
This processor is positioned before the user context aggregator to ensure
|
||||
the engine can update LLM context before aggregation occurs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._user_started_speaking_callback = user_started_speaking_callback
|
||||
self._user_stopped_speaking_callback = user_stopped_speaking_callback
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Handle frames that need engine processing before aggregation
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
try:
|
||||
await self._handle_user_stopped_speaking()
|
||||
except VoicemailDetectedException:
|
||||
# We have detected voicemail, lets not
|
||||
# forward the UserStoppedSpeakingFrame, so that
|
||||
# we don't issue an llm call from user context
|
||||
# aggregator
|
||||
logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame")
|
||||
return
|
||||
|
||||
# Always push the frame downstream
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_user_started_speaking(self):
|
||||
"""Handle UserStartedSpeakingFrame before aggregation."""
|
||||
if self._user_started_speaking_callback:
|
||||
# logger.debug("Engine pre-aggregator: User started speaking")
|
||||
await self._user_started_speaking_callback()
|
||||
|
||||
async def _handle_user_stopped_speaking(self):
|
||||
"""Handle UserStoppedSpeakingFrame before aggregation."""
|
||||
if self._user_stopped_speaking_callback:
|
||||
# logger.debug("Engine pre-aggregator: User stopped speaking")
|
||||
await self._user_stopped_speaking_callback()
|
||||
249
api/services/pipecat/event_handlers.py
Normal file
249
api/services/pipecat/event_handlers.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.pipecat.audio_transcript_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryTranscriptBuffer,
|
||||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
|
||||
def register_transport_event_handlers(
|
||||
transport,
|
||||
workflow_run_id,
|
||||
audio_buffer,
|
||||
task: PipelineTask,
|
||||
engine: PipecatEngine,
|
||||
usage_metrics_aggregator: PipelineMetricsAggregator,
|
||||
audio_synchronizer=None,
|
||||
audio_config=None,
|
||||
):
|
||||
"""Register event handlers for transport events"""
|
||||
|
||||
# Initialize in-memory buffers with proper audio configuration
|
||||
sample_rate = audio_config.pipeline_sample_rate if audio_config else 16000
|
||||
num_channels = 1 # Pipeline audio is always mono
|
||||
|
||||
logger.debug(
|
||||
f"Initializing audio buffer for workflow {workflow_run_id} "
|
||||
f"with sample_rate={sample_rate}Hz, channels={num_channels}"
|
||||
)
|
||||
|
||||
in_memory_audio_buffer = InMemoryAudioBuffer(
|
||||
workflow_run_id=workflow_run_id,
|
||||
sample_rate=sample_rate,
|
||||
num_channels=num_channels,
|
||||
)
|
||||
in_memory_transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, participant):
|
||||
logger.debug("In on_client_connected callback handler - initializing workflow")
|
||||
await audio_buffer.start_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.start_recording()
|
||||
await engine.initialize()
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(
|
||||
transport: BaseTransport,
|
||||
participant,
|
||||
transport_disconnect_reason: Optional[str] = None,
|
||||
):
|
||||
logger.debug(
|
||||
f"In on_client_disconnected callback handler, disconnect_reason: {transport_disconnect_reason}"
|
||||
)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
|
||||
# First priority: Check if engine has a disconnect reason (local disconnect)
|
||||
engine_call_disposition = engine.get_call_disposition()
|
||||
gathered_context = engine.get_gathered_context()
|
||||
|
||||
# also consider existing gathered context in workflow_run
|
||||
gathered_context = {**gathered_context, **workflow_run.gathered_context}
|
||||
|
||||
if engine_call_disposition:
|
||||
# Engine has set a disconnect reason - this takes priority
|
||||
call_disposition = engine_call_disposition
|
||||
logger.debug(f"Engine disposition detected, code: {call_disposition}")
|
||||
elif transport_disconnect_reason:
|
||||
# TODO: Make this more generic using some DSL or equivalent. This is currently
|
||||
# configured to work for Kapil's bot
|
||||
call_duration = usage_metrics_aggregator.get_call_duration()
|
||||
if transport_disconnect_reason == EndTaskReason.USER_HANGUP.value:
|
||||
if call_duration < 10:
|
||||
call_disposition = "HU"
|
||||
else:
|
||||
call_disposition = "NIBP"
|
||||
else:
|
||||
# Transport provided a disconnect reason (remote hangup)
|
||||
call_disposition = transport_disconnect_reason
|
||||
logger.debug(
|
||||
f"Remote disconnect detected, reason: {call_disposition} duration: {call_duration}"
|
||||
)
|
||||
else:
|
||||
# No reason provided - assume user hangup
|
||||
call_disposition = EndTaskReason.UNKNOWN.value
|
||||
logger.debug("No disposition found from either engine or transport")
|
||||
|
||||
# Cancel task only when no engine disconnect reason (remote disconnect)
|
||||
if not engine_call_disposition:
|
||||
await task.cancel()
|
||||
|
||||
organization_id = await get_organization_id_from_workflow_run(workflow_run_id)
|
||||
mapped_call_disposition = await apply_disposition_mapping(
|
||||
call_disposition, organization_id
|
||||
)
|
||||
|
||||
gathered_context.update({"mapped_call_disposition": mapped_call_disposition})
|
||||
|
||||
if in_memory_transcript_buffer:
|
||||
call_tags = gathered_context.get("call_tags", [])
|
||||
|
||||
try:
|
||||
has_user_speech = in_memory_transcript_buffer.contains_user_speech()
|
||||
except Exception:
|
||||
has_user_speech = False
|
||||
|
||||
if has_user_speech and "user_speech" not in call_tags:
|
||||
call_tags.append("user_speech")
|
||||
|
||||
# Append any keys from gathered_context that start with 'tag_' to call_tags
|
||||
for key in gathered_context:
|
||||
if key.startswith("tag_") and key not in call_tags:
|
||||
call_tags.append(gathered_context[key])
|
||||
|
||||
gathered_context["call_tags"] = call_tags
|
||||
|
||||
# Clean up engine resources (including voicemail detector)
|
||||
await engine.cleanup()
|
||||
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Close Smart-Turn WebSocket if the transport's analyzer supports it
|
||||
# ------------------------------------------------------------------
|
||||
try:
|
||||
turn_analyzer = None
|
||||
|
||||
# Most transports store their params (with turn_analyzer) directly.
|
||||
if hasattr(transport, "_params") and transport._params:
|
||||
turn_analyzer = getattr(transport._params, "turn_analyzer", None)
|
||||
|
||||
# Fallback: some transports expose params through input() instance.
|
||||
if turn_analyzer is None and hasattr(transport, "input"):
|
||||
try:
|
||||
input_transport = transport.input()
|
||||
if input_transport and hasattr(input_transport, "_params"):
|
||||
turn_analyzer = getattr(
|
||||
input_transport._params, "turn_analyzer", None
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if turn_analyzer and hasattr(turn_analyzer, "close"):
|
||||
await turn_analyzer.close()
|
||||
logger.debug("Closed turn analyzer websocket")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to close Smart-Turn analyzer gracefully: {exc}")
|
||||
|
||||
usage_info = usage_metrics_aggregator.get_all_usage_metrics_serialized()
|
||||
|
||||
logger.debug(f"Usage metrics: {usage_info}")
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
usage_info=usage_info,
|
||||
gathered_context=gathered_context,
|
||||
is_completed=True,
|
||||
)
|
||||
|
||||
# Release concurrent slot for campaign calls
|
||||
if workflow_run and workflow_run.campaign_id:
|
||||
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
|
||||
|
||||
# Write buffers to temp files and enqueue S3 upload
|
||||
try:
|
||||
# Only upload if buffers have content
|
||||
if not in_memory_audio_buffer.is_empty:
|
||||
audio_temp_path = await in_memory_audio_buffer.write_to_temp_file()
|
||||
await enqueue_job(
|
||||
FunctionNames.UPLOAD_AUDIO_TO_S3, workflow_run_id, audio_temp_path
|
||||
)
|
||||
else:
|
||||
logger.debug("Audio buffer is empty, skipping upload")
|
||||
|
||||
if not in_memory_transcript_buffer.is_empty:
|
||||
transcript_temp_path = (
|
||||
await in_memory_transcript_buffer.write_to_temp_file()
|
||||
)
|
||||
await enqueue_job(
|
||||
FunctionNames.UPLOAD_TRANSCRIPT_TO_S3,
|
||||
workflow_run_id,
|
||||
transcript_temp_path,
|
||||
)
|
||||
else:
|
||||
logger.debug("Transcript buffer is empty, skipping upload")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing buffers for S3 upload: {e}", exc_info=True)
|
||||
|
||||
await enqueue_job(FunctionNames.CALCULATE_WORKFLOW_RUN_COST, workflow_run_id)
|
||||
await enqueue_job(
|
||||
FunctionNames.RUN_INTEGRATIONS_POST_WORKFLOW_RUN, workflow_run_id
|
||||
)
|
||||
|
||||
# Return the buffers so they can be passed to other handlers
|
||||
return in_memory_audio_buffer, in_memory_transcript_buffer
|
||||
|
||||
|
||||
def register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id, in_memory_buffer: InMemoryAudioBuffer
|
||||
):
|
||||
"""Register event handler for audio data"""
|
||||
logger.info(f"Registering audio data handler for workflow run {workflow_run_id}")
|
||||
|
||||
@audio_synchronizer.event_handler("on_merged_audio")
|
||||
async def on_merged_audio(_, pcm, sample_rate, num_channels):
|
||||
if not pcm:
|
||||
return
|
||||
|
||||
# Use in-memory buffer
|
||||
try:
|
||||
await in_memory_buffer.append(pcm)
|
||||
except MemoryError as e:
|
||||
logger.error(f"Memory buffer full: {e}")
|
||||
# Could implement overflow to disk here if needed
|
||||
|
||||
|
||||
def register_transcript_handler(
|
||||
transcript, workflow_run_id, in_memory_buffer: InMemoryTranscriptBuffer
|
||||
):
|
||||
"""Register event handler for transcript updates"""
|
||||
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
transcript_text = ""
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
line = f"{timestamp}{msg.role}: {msg.content}\n"
|
||||
transcript_text += line
|
||||
|
||||
# Use in-memory buffer
|
||||
await in_memory_buffer.append(transcript_text)
|
||||
6
api/services/pipecat/exceptions.py
Normal file
6
api/services/pipecat/exceptions.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
class VoicemailDetectedException(Exception):
|
||||
"""
|
||||
Exception raised when voicemail is detected.
|
||||
"""
|
||||
|
||||
pass
|
||||
147
api/services/pipecat/pipeline_builder.py
Normal file
147
api/services/pipecat/pipeline_builder.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import (
|
||||
ENABLE_TRACING,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.utils.context import turn_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine"):
|
||||
"""Create and return the main pipeline components with proper audio configuration"""
|
||||
logger.info(f"Creating pipeline components with audio config: {audio_config}")
|
||||
|
||||
# Use new split audio buffer for better performance
|
||||
audio_buffer = AudioBuffer(
|
||||
sample_rate=audio_config.pipeline_sample_rate,
|
||||
buffer_size=audio_config.buffer_size_bytes,
|
||||
)
|
||||
|
||||
# Create synchronizer for merged audio (outside pipeline)
|
||||
audio_synchronizer = AudioSynchronizer(
|
||||
sample_rate=audio_config.pipeline_sample_rate,
|
||||
buffer_size=audio_config.buffer_size_bytes,
|
||||
)
|
||||
|
||||
transcript = TranscriptProcessor(
|
||||
assistant_correct_aggregation_callback=engine.create_aggregation_correction_callback()
|
||||
)
|
||||
|
||||
context = OpenAILLMContext()
|
||||
|
||||
return audio_buffer, audio_synchronizer, transcript, context
|
||||
|
||||
|
||||
def build_pipeline(
|
||||
transport,
|
||||
stt,
|
||||
transcript,
|
||||
audio_buffer,
|
||||
audio_synchronizer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
engine_pre_aggregator_processor=None,
|
||||
):
|
||||
"""Build the main pipeline with all components"""
|
||||
# Register processors with synchronizer for merged audio
|
||||
logger.info("Registering audio buffer processors with synchronizer")
|
||||
audio_synchronizer.register_processors(audio_buffer.input(), audio_buffer.output())
|
||||
|
||||
# Build processors list with optional context controller
|
||||
processors = [
|
||||
transport.input(), # Transport user input
|
||||
audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame)
|
||||
stt_mute_filter,
|
||||
stt, # STT can now have audio_passthrough=False
|
||||
user_idle_disconnect,
|
||||
transcript.user(),
|
||||
]
|
||||
|
||||
# Insert engine pre-aggregator processor if provided (before user aggregator)
|
||||
if engine_pre_aggregator_processor:
|
||||
processors.append(engine_pre_aggregator_processor)
|
||||
|
||||
processors.extend(
|
||||
[
|
||||
user_context_aggregator,
|
||||
llm, # LLM
|
||||
pipeline_engine_callback_processor,
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
audio_buffer.output(), # Record output audio (only processes OutputAudioRawFrame)
|
||||
transcript.assistant(),
|
||||
assistant_context_aggregator, # Assistant spoken responses
|
||||
pipeline_metrics_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
return Pipeline(processors)
|
||||
|
||||
|
||||
def create_pipeline_task(pipeline, workflow_run_id, audio_config: AudioConfig = None):
|
||||
"""Create a pipeline task with appropriate parameters"""
|
||||
# Set up pipeline params with audio configuration if provided
|
||||
pipeline_params = PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
send_initial_empty_metrics=False,
|
||||
enable_heartbeats=True,
|
||||
)
|
||||
|
||||
# If audio_config is provided, set the audio sample rates
|
||||
if audio_config:
|
||||
pipeline_params.audio_in_sample_rate = audio_config.transport_in_sample_rate
|
||||
pipeline_params.audio_out_sample_rate = audio_config.transport_out_sample_rate
|
||||
logger.debug(
|
||||
f"Setting pipeline audio params - in: {audio_config.transport_in_sample_rate}Hz, "
|
||||
f"out: {audio_config.transport_out_sample_rate}Hz"
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=pipeline_params,
|
||||
enable_tracing=ENABLE_TRACING,
|
||||
conversation_id=f"{workflow_run_id}",
|
||||
)
|
||||
|
||||
# Check if turn logging is enabled
|
||||
enable_turn_logging = os.getenv("ENABLE_TURN_LOGGING", "false").lower() == "true"
|
||||
|
||||
if enable_turn_logging:
|
||||
# Attach event handlers to propagate turn information into the logging context
|
||||
turn_observer = task.turn_tracking_observer
|
||||
|
||||
if turn_observer is not None:
|
||||
# Import turn context manager only if needed
|
||||
from api.services.pipecat.turn_context import get_turn_context_manager
|
||||
|
||||
async def _on_turn_started(observer, turn_number: int):
|
||||
"""Set the current turn number into the context variable."""
|
||||
# Set in both contextvar and turn context manager
|
||||
turn_var.set(turn_number)
|
||||
turn_manager = get_turn_context_manager()
|
||||
turn_manager.set_turn(turn_number)
|
||||
|
||||
# Register the handlers with the observer
|
||||
turn_observer.add_event_handler("on_turn_started", _on_turn_started)
|
||||
|
||||
return task
|
||||
84
api/services/pipecat/pipeline_engine_callbacks_processor.py
Normal file
84
api/services/pipecat/pipeline_engine_callbacks_processor.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import time
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMGeneratedTextFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class PipelineEngineCallbacksProcessor(FrameProcessor):
|
||||
"""
|
||||
Custom PipelineEngineCallbacksProcessor that accepts callbacks for various
|
||||
use cases, like ending tasks when max call duration is exceeded, or informing
|
||||
the engine that the bot is done speaking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_call_duration_seconds: int = 300,
|
||||
max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._start_time = None
|
||||
self._max_call_duration_seconds = max_call_duration_seconds
|
||||
self._max_duration_end_task_callback = max_duration_end_task_callback
|
||||
self._llm_generated_text_callback = llm_generated_text_callback
|
||||
self._generation_started_callback = generation_started_callback
|
||||
self._llm_text_frame_callback = llm_text_frame_callback
|
||||
self._end_task_frame_pushed = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
await self._check_call_duration()
|
||||
elif isinstance(frame, LLMGeneratedTextFrame):
|
||||
await self._generated_text_frame(frame)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._generation_started()
|
||||
elif (
|
||||
isinstance(frame, (LLMTextFrame, TTSSpeakFrame))
|
||||
and self._llm_text_frame_callback
|
||||
):
|
||||
# Include TTSSpeakFrame here since for static nodes, we send TTSSpeakFrame
|
||||
# which can act as reference while fixing the aggregated trascript
|
||||
await self._llm_text_frame_callback(frame.text)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, _: StartFrame):
|
||||
self._start_time = time.time()
|
||||
|
||||
async def _check_call_duration(self):
|
||||
if self._start_time is not None:
|
||||
if time.time() - self._start_time > self._max_call_duration_seconds:
|
||||
if not self._end_task_frame_pushed:
|
||||
if self._max_duration_end_task_callback:
|
||||
await self._max_duration_end_task_callback()
|
||||
self._end_task_frame_pushed = True
|
||||
else:
|
||||
logger.debug(
|
||||
"Max call duration exceeded. Skipping EndTaskFrame since already sent"
|
||||
)
|
||||
|
||||
async def _generated_text_frame(self, _: LLMGeneratedTextFrame):
|
||||
"""Handle LLMGeneratedTextFrame."""
|
||||
if self._llm_generated_text_callback is not None:
|
||||
await self._llm_generated_text_callback()
|
||||
|
||||
async def _generation_started(self):
|
||||
if self._generation_started_callback:
|
||||
await self._generation_started_callback()
|
||||
162
api/services/pipecat/pipeline_metrics_aggregator.py
Normal file
162
api/services/pipecat/pipeline_metrics_aggregator.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import (
|
||||
LLMTokenUsage,
|
||||
LLMUsageMetricsData,
|
||||
STTUsageMetricsData,
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class PipelineMetricsAggregator(FrameProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Structure: {f"{processor}|||{model}": aggregated_metrics}
|
||||
# For LLM: aggregated_metrics is LLMTokenUsage
|
||||
# For TTS: aggregated_metrics is int (total characters)
|
||||
# For STT: aggregated_metrics is float (total seconds)
|
||||
|
||||
self._start_time: Optional[float] = None
|
||||
self._stop_time: Optional[float] = None
|
||||
self._llm_usage_metrics: Dict[str, LLMTokenUsage] = {}
|
||||
self._tts_usage_metrics: Dict[str, int] = defaultdict(int)
|
||||
self._stt_usage_metrics: Dict[str, float] = defaultdict(float)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._stop(frame)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
for data in frame.data:
|
||||
if isinstance(data, LLMUsageMetricsData):
|
||||
await self._handle_llm_usage_metrics(data)
|
||||
elif isinstance(data, TTSUsageMetricsData):
|
||||
await self._handle_tts_usage_metrics(data)
|
||||
elif isinstance(data, STTUsageMetricsData):
|
||||
await self._handle_stt_usage_metrics(data)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, _: StartFrame):
|
||||
"""Start tracking call duration."""
|
||||
self._start_time = time.time()
|
||||
self._stop_time = None
|
||||
|
||||
async def _stop(self, _: EndFrame):
|
||||
"""Stop tracking call duration."""
|
||||
if self._start_time is not None and self._stop_time is None:
|
||||
self._stop_time = time.time()
|
||||
|
||||
async def _cancel(self, _: CancelFrame):
|
||||
"""Handle call cancellation - also stop tracking duration."""
|
||||
if self._start_time is not None and self._stop_time is None:
|
||||
self._stop_time = time.time()
|
||||
|
||||
async def _handle_llm_usage_metrics(self, data: LLMUsageMetricsData):
|
||||
key = f"{data.processor}|||{data.model}"
|
||||
new_usage = data.value
|
||||
|
||||
if key in self._llm_usage_metrics:
|
||||
# Aggregate with existing metrics
|
||||
existing = self._llm_usage_metrics[key]
|
||||
aggregated = LLMTokenUsage(
|
||||
prompt_tokens=existing.prompt_tokens + new_usage.prompt_tokens,
|
||||
completion_tokens=existing.completion_tokens
|
||||
+ new_usage.completion_tokens,
|
||||
total_tokens=existing.total_tokens + new_usage.total_tokens,
|
||||
cache_read_input_tokens=(existing.cache_read_input_tokens or 0)
|
||||
+ (new_usage.cache_read_input_tokens or 0),
|
||||
cache_creation_input_tokens=(existing.cache_creation_input_tokens or 0)
|
||||
+ (new_usage.cache_creation_input_tokens or 0),
|
||||
)
|
||||
self._llm_usage_metrics[key] = aggregated
|
||||
else:
|
||||
# First occurrence for this processor+model combination
|
||||
self._llm_usage_metrics[key] = LLMTokenUsage(
|
||||
prompt_tokens=new_usage.prompt_tokens,
|
||||
completion_tokens=new_usage.completion_tokens,
|
||||
total_tokens=new_usage.total_tokens,
|
||||
cache_read_input_tokens=new_usage.cache_read_input_tokens,
|
||||
cache_creation_input_tokens=new_usage.cache_creation_input_tokens,
|
||||
)
|
||||
|
||||
logger.debug(f"LLM usage metrics: {self._llm_usage_metrics}")
|
||||
|
||||
async def _handle_tts_usage_metrics(self, data: TTSUsageMetricsData):
|
||||
key = f"{data.processor}|||{data.model}"
|
||||
self._tts_usage_metrics[key] += data.value
|
||||
# logger.debug(f"TTS usage metrics: {self._tts_usage_metrics}")
|
||||
|
||||
async def _handle_stt_usage_metrics(self, data: STTUsageMetricsData):
|
||||
key = f"{data.processor}|||{data.model}"
|
||||
self._stt_usage_metrics[key] += data.value
|
||||
logger.debug(f"STT usage metrics: {self._stt_usage_metrics}")
|
||||
|
||||
def get_llm_usage_metrics(self) -> Dict[str, LLMTokenUsage]:
|
||||
"""Get the aggregated LLM usage metrics grouped by processor|||model."""
|
||||
return self._llm_usage_metrics
|
||||
|
||||
def get_tts_usage_metrics(self) -> Dict[str, int]:
|
||||
"""Get the aggregated TTS usage metrics grouped by processor|||model."""
|
||||
return self._tts_usage_metrics
|
||||
|
||||
def get_stt_usage_metrics(self) -> Dict[str, float]:
|
||||
"""Get the aggregated STT usage metrics grouped by processor|||model."""
|
||||
return self._stt_usage_metrics
|
||||
|
||||
def get_call_duration(self) -> float:
|
||||
"""Get call duration"""
|
||||
if self._start_time is None:
|
||||
return 0.0
|
||||
|
||||
if self._stop_time is None:
|
||||
call_duration = time.time() - self._start_time
|
||||
else:
|
||||
call_duration = self._stop_time - self._start_time
|
||||
|
||||
# Lets return a rounded integer
|
||||
return int(round(call_duration))
|
||||
|
||||
def get_all_usage_metrics_serialized(self) -> Dict[str, Dict[str, any]]:
|
||||
"""Get all aggregated usage metrics in JSON-serializable format."""
|
||||
serialized_llm = {}
|
||||
for key, usage in self._llm_usage_metrics.items():
|
||||
serialized_llm[key] = {
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"total_tokens": usage.total_tokens,
|
||||
"cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
}
|
||||
|
||||
return {
|
||||
"llm": serialized_llm,
|
||||
"tts": dict(self._tts_usage_metrics),
|
||||
"stt": dict(self._stt_usage_metrics),
|
||||
"call_duration_seconds": self.get_call_duration(),
|
||||
}
|
||||
|
||||
def reset_metrics(self):
|
||||
"""Reset all aggregated metrics."""
|
||||
self._llm_usage_metrics.clear()
|
||||
self._tts_usage_metrics.clear()
|
||||
self._stt_usage_metrics.clear()
|
||||
self._start_time = None
|
||||
self._stop_time = None
|
||||
388
api/services/pipecat/run_pipeline.py
Normal file
388
api/services/pipecat/run_pipeline.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, WebSocket
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
|
||||
from api.services.pipecat.engine_pre_aggregator_processor import (
|
||||
EnginePreAggregatorProcessor,
|
||||
)
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_transcript_handler,
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
from api.services.pipecat.pipeline_builder import (
|
||||
build_pipeline,
|
||||
create_pipeline_components,
|
||||
create_pipeline_task,
|
||||
)
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import setup_pipeline_tracing
|
||||
from api.services.pipecat.transport_setup import (
|
||||
create_stasis_transport,
|
||||
create_twilio_transport,
|
||||
create_webrtc_transport,
|
||||
)
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.filters.stt_mute_filter import (
|
||||
STTMuteConfig,
|
||||
STTMuteFilter,
|
||||
STTMuteStrategy,
|
||||
)
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
from pipecat.utils.tracing.context_registry import ContextProviderRegistry
|
||||
|
||||
# Setup tracing if enabled
|
||||
setup_pipeline_tracing()
|
||||
|
||||
|
||||
async def run_pipeline_twilio(
|
||||
websocket_client: WebSocket,
|
||||
stream_sid: str,
|
||||
call_sid: str,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
) -> None:
|
||||
"""Run pipeline for Twilio connections"""
|
||||
logger.debug(
|
||||
f"Running pipeline for Twilio connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Store Twilio call SID in cost_info for later cost calculation
|
||||
cost_info = {"twilio_call_sid": call_sid}
|
||||
await db_client.update_workflow_run(workflow_run_id, cost_info=cost_info)
|
||||
|
||||
# Get workflow to extract all pipeline configurations
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
vad_config = None
|
||||
ambient_noise_config = None
|
||||
if workflow and workflow.workflow_configurations:
|
||||
if "vad_configuration" in workflow.workflow_configurations:
|
||||
vad_config = workflow.workflow_configurations["vad_configuration"]
|
||||
if "ambient_noise_configuration" in workflow.workflow_configurations:
|
||||
ambient_noise_config = workflow.workflow_configurations[
|
||||
"ambient_noise_configuration"
|
||||
]
|
||||
|
||||
# Create audio configuration for Twilio
|
||||
audio_config = create_audio_config(WorkflowRunMode.TWILIO.value)
|
||||
|
||||
transport = create_twilio_transport(
|
||||
websocket_client,
|
||||
stream_sid,
|
||||
call_sid,
|
||||
workflow_run_id,
|
||||
audio_config,
|
||||
vad_config,
|
||||
ambient_noise_config,
|
||||
)
|
||||
await _run_pipeline(
|
||||
transport,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
user_id,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
|
||||
async def run_pipeline_smallwebrtc(
|
||||
webrtc_connection: SmallWebRTCConnection,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
call_context_vars: dict = {},
|
||||
) -> None:
|
||||
"""Run pipeline for WebRTC connections"""
|
||||
logger.debug(
|
||||
f"Running pipeline for WebRTC connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Get workflow to extract all pipeline configurations
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
vad_config = None
|
||||
ambient_noise_config = None
|
||||
if workflow and workflow.workflow_configurations:
|
||||
if "vad_configuration" in workflow.workflow_configurations:
|
||||
vad_config = workflow.workflow_configurations["vad_configuration"]
|
||||
if "ambient_noise_configuration" in workflow.workflow_configurations:
|
||||
ambient_noise_config = workflow.workflow_configurations[
|
||||
"ambient_noise_configuration"
|
||||
]
|
||||
|
||||
# Create audio configuration for WebRTC
|
||||
audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
|
||||
|
||||
transport = create_webrtc_transport(
|
||||
webrtc_connection,
|
||||
workflow_run_id,
|
||||
audio_config,
|
||||
vad_config,
|
||||
ambient_noise_config,
|
||||
)
|
||||
await _run_pipeline(
|
||||
transport,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
user_id,
|
||||
call_context_vars=call_context_vars,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
|
||||
async def run_pipeline_ari_stasis(
|
||||
stasis_connection: StasisRTPConnection,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
call_context_vars: dict,
|
||||
) -> None:
|
||||
"""Run pipeline for ARI connections"""
|
||||
logger.debug(
|
||||
f"Running pipeline for ARI connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Get workflow to extract all pipeline configurations
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
vad_config = None
|
||||
ambient_noise_config = None
|
||||
if workflow and workflow.workflow_configurations:
|
||||
if "vad_configuration" in workflow.workflow_configurations:
|
||||
vad_config = workflow.workflow_configurations["vad_configuration"]
|
||||
if "ambient_noise_configuration" in workflow.workflow_configurations:
|
||||
ambient_noise_config = workflow.workflow_configurations[
|
||||
"ambient_noise_configuration"
|
||||
]
|
||||
|
||||
# Create audio configuration for Stasis
|
||||
audio_config = create_audio_config(WorkflowRunMode.STASIS.value)
|
||||
|
||||
transport = create_stasis_transport(
|
||||
stasis_connection,
|
||||
workflow_run_id,
|
||||
audio_config,
|
||||
vad_config,
|
||||
ambient_noise_config,
|
||||
)
|
||||
await _run_pipeline(
|
||||
transport,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
user_id,
|
||||
call_context_vars=call_context_vars,
|
||||
audio_config=audio_config,
|
||||
stasis_connection=stasis_connection, # Pass connection for immediate transfers
|
||||
)
|
||||
|
||||
|
||||
async def _run_pipeline(
|
||||
transport,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
call_context_vars: dict = {},
|
||||
audio_config: AudioConfig = None,
|
||||
stasis_connection: Optional[StasisRTPConnection] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the pipeline with the given transport and configuration
|
||||
|
||||
Args:
|
||||
transport: The transport to use for the pipeline
|
||||
workflow_id: The ID of the workflow
|
||||
workflow_run_id: The ID of the workflow run
|
||||
user_id: The ID of the user
|
||||
mode: The mode of the pipeline (twilio or smallwebrtc)
|
||||
"""
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
|
||||
|
||||
# If the workflow run is already completed, we don't need to run it again
|
||||
if workflow_run.is_completed:
|
||||
raise HTTPException(status_code=400, detail="Workflow run already completed")
|
||||
|
||||
merged_call_context_vars = workflow_run.initial_context
|
||||
# If there is some extra call_context_vars, update them
|
||||
if call_context_vars:
|
||||
merged_call_context_vars = {**merged_call_context_vars, **call_context_vars}
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id, initial_context=merged_call_context_vars
|
||||
)
|
||||
|
||||
# Get user configuration
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
|
||||
# Create services based on user configuration
|
||||
stt = create_stt_service(user_config)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
|
||||
# Get workflow first so we can create engine before pipeline components
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Extract configurations from workflow configurations
|
||||
max_call_duration_seconds = 300 # Default 5 minutes
|
||||
max_user_idle_timeout = 10.0 # Default 10 seconds
|
||||
|
||||
if workflow.workflow_configurations:
|
||||
# Use workflow-specific max call duration if provided
|
||||
if "max_call_duration" in workflow.workflow_configurations:
|
||||
max_call_duration_seconds = workflow.workflow_configurations[
|
||||
"max_call_duration"
|
||||
]
|
||||
|
||||
# Use workflow-specific max user idle timeout if provided
|
||||
if "max_user_idle_timeout" in workflow.workflow_configurations:
|
||||
max_user_idle_timeout = workflow.workflow_configurations[
|
||||
"max_user_idle_timeout"
|
||||
]
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
|
||||
)
|
||||
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
tts=tts,
|
||||
workflow=workflow_graph,
|
||||
call_context_vars=merged_call_context_vars,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
# Create pipeline components with audio configuration and engine
|
||||
audio_buffer, audio_synchronizer, transcript, context = create_pipeline_components(
|
||||
audio_config, engine
|
||||
)
|
||||
|
||||
# Set the context and audio_buffer after creation
|
||||
engine.set_context(context)
|
||||
engine.set_audio_buffer(audio_buffer)
|
||||
|
||||
# Set Stasis connection for immediate transfers (if available)
|
||||
if stasis_connection:
|
||||
engine.set_stasis_connection(stasis_connection)
|
||||
|
||||
assistant_params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True,
|
||||
correct_aggregation_callback=engine.create_aggregation_correction_callback(),
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(
|
||||
context, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
# Create engine pre-aggregator processor for speaking events
|
||||
engine_pre_aggregator_processor = EnginePreAggregatorProcessor(
|
||||
user_started_speaking_callback=engine.create_user_started_speaking_callback(),
|
||||
user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(),
|
||||
)
|
||||
|
||||
# Create usage metrics aggregator with engine's callback
|
||||
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
|
||||
max_call_duration_seconds=max_call_duration_seconds,
|
||||
max_duration_end_task_callback=engine.create_max_duration_callback(),
|
||||
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
|
||||
generation_started_callback=engine.create_generation_started_callback(),
|
||||
llm_text_frame_callback=engine.handle_llm_text_frame,
|
||||
# Note: speaking event callbacks are now handled by pre-aggregator processor
|
||||
)
|
||||
|
||||
pipeline_metrics_aggregator = PipelineMetricsAggregator()
|
||||
|
||||
# Create STT mute filter using the selected strategies and the engine's callback
|
||||
stt_mute_filter = STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={
|
||||
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
|
||||
STTMuteStrategy.CUSTOM,
|
||||
},
|
||||
should_mute_callback=engine.create_should_mute_callback(),
|
||||
)
|
||||
)
|
||||
|
||||
# Use engine's user idle callback with configured timeout
|
||||
user_idle_disconnect = UserIdleProcessor(
|
||||
callback=engine.create_user_idle_callback(), timeout=max_user_idle_timeout
|
||||
)
|
||||
|
||||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
@assistant_context_aggregator.event_handler("on_push_aggregation")
|
||||
async def on_assistant_aggregator_push_context(_aggregator):
|
||||
logger.debug("Assistant aggregator push context – flushing pending transitions")
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
stt,
|
||||
transcript,
|
||||
audio_buffer,
|
||||
audio_synchronizer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
engine_pre_aggregator_processor=engine_pre_aggregator_processor,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
|
||||
|
||||
# Now set the task on the engine
|
||||
engine.set_task(task)
|
||||
|
||||
# Register event handlers
|
||||
in_memory_audio_buffer, in_memory_transcript_buffer = (
|
||||
register_transport_event_handlers(
|
||||
transport,
|
||||
workflow_run_id,
|
||||
audio_buffer,
|
||||
task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=pipeline_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
)
|
||||
register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id, in_memory_audio_buffer
|
||||
)
|
||||
register_transcript_handler(
|
||||
transcript, workflow_run_id, in_memory_transcript_buffer
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the pipeline
|
||||
runner = PipelineRunner()
|
||||
await runner.run(task)
|
||||
logger.info(f"Pipeline runner completed for run {workflow_run_id}")
|
||||
finally:
|
||||
ContextProviderRegistry.remove_providers(str(workflow_run_id))
|
||||
logger.debug(f"Cleaned up context providers for workflow run {workflow_run_id}")
|
||||
150
api/services/pipecat/service_factory.py
Normal file
150
api/services/pipecat/service_factory.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from pipecat.services.azure.llm import AzureLLMService
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.groq.llm import GroqLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
||||
|
||||
def create_stt_service(user_config):
|
||||
"""Create and return appropriate STT service based on user configuration"""
|
||||
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAISTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model.value,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.DOGRAH.value:
|
||||
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
|
||||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model.value,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}"
|
||||
)
|
||||
|
||||
|
||||
def create_tts_service(user_config, audio_config: "AudioConfig"):
|
||||
"""Create and return appropriate TTS service based on user configuration
|
||||
|
||||
Args:
|
||||
user_config: User configuration containing TTS settings
|
||||
transport_type: Type of transport (e.g., 'stasis', 'twilio', 'webrtc')
|
||||
"""
|
||||
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
voice=user_config.tts.voice.value,
|
||||
sample_rate=24000,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key, model=user_config.tts.model.value
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
voice_id = user_config.tts.voice.split(" - ")[1]
|
||||
return ElevenLabsTTSService(
|
||||
reconnect_on_error=False,
|
||||
api_key=user_config.tts.api_key,
|
||||
voice_id=voice_id,
|
||||
model=user_config.tts.model.value,
|
||||
params=ElevenLabsTTSService.InputParams(
|
||||
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
|
||||
),
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
# Convert HTTP URL to WebSocket URL for TTS
|
||||
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
|
||||
# Handle both enum and string values for model and voice
|
||||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
voice=user_config.tts.voice.value,
|
||||
sample_rate=24000,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}"
|
||||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
"""Create and return appropriate LLM service based on user configuration"""
|
||||
if user_config.llm.provider == ServiceProviders.OPENAI.value:
|
||||
if "gpt-5" in user_config.llm.model.value:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=user_config.llm.model.value,
|
||||
params=OpenAILLMService.InputParams(
|
||||
reasoning_effort="minimal", verbosity="low"
|
||||
),
|
||||
)
|
||||
else:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=user_config.llm.model.value,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GROQ.value:
|
||||
print(
|
||||
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {user_config.llm.model.value}"
|
||||
)
|
||||
return GroqLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=user_config.llm.model.value,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
|
||||
# Use the correct InputParams class for Google to avoid propagating OpenAI-specific
|
||||
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
|
||||
return GoogleLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=user_config.llm.model.value,
|
||||
params=GoogleLLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.AZURE.value:
|
||||
return AzureLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
endpoint=user_config.llm.endpoint,
|
||||
model=user_config.llm.model.value, # Azure uses deployment name as model
|
||||
params=AzureLLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
|
||||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=user_config.llm.api_key,
|
||||
model=user_config.llm.model.value,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
44
api/services/pipecat/tracing_config.py
Normal file
44
api/services/pipecat/tracing_config.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import base64
|
||||
import os
|
||||
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
|
||||
from api.constants import ENABLE_TRACING
|
||||
from pipecat.utils.tracing.setup import setup_tracing
|
||||
|
||||
|
||||
def is_tracing_enabled():
|
||||
"""Check if tracing should be enabled based on ENABLE_TRACING flag."""
|
||||
# Tracing is only enabled when ENABLE_TRACING is explicitly set to true
|
||||
# This makes the system OSS-friendly by default (no external dependencies required)
|
||||
return ENABLE_TRACING
|
||||
|
||||
|
||||
def setup_pipeline_tracing():
|
||||
"""Setup tracing for the pipeline if enabled"""
|
||||
if is_tracing_enabled():
|
||||
# Only set up Langfuse if credentials are provided
|
||||
langfuse_host = os.environ.get("LANGFUSE_HOST")
|
||||
langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
|
||||
langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
|
||||
|
||||
if not all([langfuse_host, langfuse_public_key, langfuse_secret_key]):
|
||||
print(
|
||||
"Warning: ENABLE_TRACING is true but Langfuse credentials are not configured. Tracing disabled."
|
||||
)
|
||||
return
|
||||
|
||||
LANGFUSE_AUTH = base64.b64encode(
|
||||
f"{langfuse_public_key}:{langfuse_secret_key}".encode()
|
||||
).decode()
|
||||
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{langfuse_host}/api/public/otel"
|
||||
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = (
|
||||
f"Authorization=Basic {LANGFUSE_AUTH}"
|
||||
)
|
||||
|
||||
otlp_exporter = OTLPSpanExporter()
|
||||
setup_tracing(service_name="dograh-pipeline", exporter=otlp_exporter)
|
||||
print("Langfuse tracing enabled")
|
||||
else:
|
||||
print("Tracing disabled (ENABLE_TRACING=false)")
|
||||
299
api/services/pipecat/transport_setup.py
Normal file
299
api/services/pipecat/transport_setup.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
import os
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from api.constants import APP_ROOT_DIR, ENABLE_RNNOISE, ENABLE_SMART_TURN
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.smart_turn.websocket_smart_turn import (
|
||||
WebSocketSmartTurnAnalyzer,
|
||||
)
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
from api.services.telephony.stasis_rtp_serializer import StasisRTPFrameSerializer
|
||||
from api.services.telephony.stasis_rtp_transport import (
|
||||
StasisRTPTransport,
|
||||
StasisRTPTransportParams,
|
||||
)
|
||||
from pipecat.audio.filters.rnnoise_filter import RNNoiseFilter
|
||||
from pipecat.audio.mixers.silence_audio_mixer import SilenceAudioMixer
|
||||
from pipecat.audio.mixers.soundfile_mixer import SoundfileMixer
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer, VADParams
|
||||
from pipecat.serializers.twilio import TwilioFrameSerializer
|
||||
from pipecat.transports import InternalTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import (
|
||||
FastAPIWebsocketParams,
|
||||
FastAPIWebsocketTransport,
|
||||
)
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
librnnoise_path = os.path.normpath(
|
||||
str(APP_ROOT_DIR / "native" / "rnnoise" / "librnnoise.so")
|
||||
)
|
||||
|
||||
|
||||
def create_turn_analyzer(workflow_run_id: int, audio_config: AudioConfig):
|
||||
"""Create a turn analyzer backed by the local Smart Turn HTTP service.
|
||||
|
||||
Args:
|
||||
workflow_run_id: ID of the workflow run for turn analyzer context
|
||||
audio_config: Audio configuration containing pipeline sample rate
|
||||
"""
|
||||
if ENABLE_SMART_TURN:
|
||||
service_url = os.getenv(
|
||||
"SMART_TURN_WS_SERVICE_ENDPOINT", "ws://localhost:8010/ws"
|
||||
)
|
||||
|
||||
# Prepare optional authentication headers for Smart Turn service
|
||||
secret_key = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
|
||||
headers = {"X-API-Key": secret_key} if secret_key else None
|
||||
|
||||
return WebSocketSmartTurnAnalyzer(
|
||||
url=service_url,
|
||||
headers=headers,
|
||||
sample_rate=audio_config.pipeline_sample_rate,
|
||||
params=SmartTurnParams(
|
||||
stop_secs=1.5, # send turn complete if silent for stop_secs seconds
|
||||
pre_speech_ms=0, # send speech segments before speech was detected by VAD
|
||||
max_duration_secs=5, # max duration of speech to be sent to the end of turn analyzer
|
||||
# we don't want to _clear except when we have end of turn prediction as 1 from last run
|
||||
# else if we have speaking -> queit -> trigger end of turn -> clear() and then
|
||||
# we have speak -> queit, we may end up sending a very small segment of speech
|
||||
# to end of turn model, which is not good
|
||||
use_only_last_vad_segment=False,
|
||||
),
|
||||
service_context=workflow_run_id,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def create_twilio_transport(
|
||||
websocket_client: WebSocket,
|
||||
stream_sid: str,
|
||||
call_sid: str,
|
||||
workflow_run_id: int,
|
||||
audio_config: AudioConfig,
|
||||
vad_config: dict | None = None,
|
||||
ambient_noise_config: dict | None = None,
|
||||
):
|
||||
"""Create a transport for Twilio connections"""
|
||||
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
|
||||
|
||||
serializer = TwilioFrameSerializer(
|
||||
stream_sid=stream_sid,
|
||||
call_sid=call_sid,
|
||||
account_sid=os.environ["TWILIO_ACCOUNT_SID"],
|
||||
auth_token=os.environ["TWILIO_AUTH_TOKEN"],
|
||||
)
|
||||
|
||||
return FastAPIWebsocketTransport(
|
||||
websocket=websocket_client,
|
||||
params=FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=audio_config.transport_in_sample_rate,
|
||||
audio_out_sample_rate=audio_config.transport_out_sample_rate,
|
||||
vad_analyzer=(
|
||||
SileroVADAnalyzer(
|
||||
params=VADParams(
|
||||
confidence=vad_config.get("confidence", 0.7),
|
||||
start_secs=vad_config.get("start_seconds", 0.4),
|
||||
stop_secs=vad_config.get("stop_seconds", 0.8),
|
||||
min_volume=vad_config.get("minimum_volume", 0.6),
|
||||
)
|
||||
)
|
||||
if vad_config
|
||||
else SileroVADAnalyzer()
|
||||
), # Sample rate will be set by transport
|
||||
audio_out_mixer=(
|
||||
SoundfileMixer(
|
||||
sound_files={
|
||||
"office": APP_ROOT_DIR
|
||||
/ "assets"
|
||||
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
|
||||
},
|
||||
default_sound="office",
|
||||
volume=ambient_noise_config.get("volume", 0.3),
|
||||
)
|
||||
if ambient_noise_config and ambient_noise_config.get("enabled", False)
|
||||
else SilenceAudioMixer()
|
||||
),
|
||||
turn_analyzer=turn_analyzer,
|
||||
serializer=serializer,
|
||||
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
|
||||
if ENABLE_RNNOISE
|
||||
else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_webrtc_transport(
|
||||
webrtc_connection: SmallWebRTCConnection,
|
||||
workflow_run_id: int,
|
||||
audio_config: AudioConfig,
|
||||
vad_config: dict | None = None,
|
||||
ambient_noise_config: dict | None = None,
|
||||
):
|
||||
"""Create a transport for WebRTC connections"""
|
||||
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
|
||||
|
||||
return SmallWebRTCTransport(
|
||||
webrtc_connection=webrtc_connection,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=audio_config.transport_in_sample_rate,
|
||||
audio_out_sample_rate=audio_config.transport_out_sample_rate,
|
||||
vad_analyzer=(
|
||||
SileroVADAnalyzer(
|
||||
params=VADParams(
|
||||
confidence=vad_config.get("confidence", 0.7),
|
||||
start_secs=vad_config.get("start_seconds", 0.4),
|
||||
stop_secs=vad_config.get("stop_seconds", 0.8),
|
||||
min_volume=vad_config.get("minimum_volume", 0.6),
|
||||
)
|
||||
)
|
||||
if vad_config
|
||||
else SileroVADAnalyzer()
|
||||
), # Sample rate will be set by transport
|
||||
audio_out_mixer=(
|
||||
SoundfileMixer(
|
||||
sound_files={
|
||||
"office": APP_ROOT_DIR
|
||||
/ "assets"
|
||||
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
|
||||
},
|
||||
default_sound="office",
|
||||
volume=ambient_noise_config.get("volume", 0.3),
|
||||
)
|
||||
if ambient_noise_config and ambient_noise_config.get("enabled", False)
|
||||
else SilenceAudioMixer()
|
||||
),
|
||||
turn_analyzer=turn_analyzer,
|
||||
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
|
||||
if ENABLE_RNNOISE
|
||||
else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_stasis_transport(
|
||||
stasis_connection: StasisRTPConnection,
|
||||
workflow_run_id: int,
|
||||
audio_config: AudioConfig,
|
||||
vad_config: dict | None = None,
|
||||
ambient_noise_config: dict | None = None,
|
||||
):
|
||||
"""Create a transport for ARI connections"""
|
||||
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
|
||||
|
||||
serializer = StasisRTPFrameSerializer(
|
||||
StasisRTPFrameSerializer.InputParams(
|
||||
sample_rate=audio_config.transport_in_sample_rate
|
||||
)
|
||||
)
|
||||
|
||||
return StasisRTPTransport(
|
||||
stasis_connection,
|
||||
params=StasisRTPTransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_out_sample_rate=audio_config.transport_out_sample_rate,
|
||||
audio_in_sample_rate=audio_config.transport_in_sample_rate,
|
||||
audio_out_10ms_chunks=2, # Send 20ms packets
|
||||
vad_analyzer=(
|
||||
SileroVADAnalyzer(
|
||||
params=VADParams(
|
||||
confidence=vad_config.get("confidence", 0.7),
|
||||
start_secs=vad_config.get("start_seconds", 0.4),
|
||||
stop_secs=vad_config.get("stop_seconds", 0.8),
|
||||
min_volume=vad_config.get("minimum_volume", 0.6),
|
||||
)
|
||||
)
|
||||
if vad_config
|
||||
else SileroVADAnalyzer()
|
||||
), # Sample rate will be set by transport
|
||||
audio_out_mixer=(
|
||||
SoundfileMixer(
|
||||
sound_files={
|
||||
"office": APP_ROOT_DIR
|
||||
/ "assets"
|
||||
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
|
||||
},
|
||||
default_sound="office",
|
||||
volume=ambient_noise_config.get("volume", 0.3),
|
||||
)
|
||||
if ambient_noise_config and ambient_noise_config.get("enabled", False)
|
||||
else SilenceAudioMixer()
|
||||
),
|
||||
turn_analyzer=turn_analyzer,
|
||||
serializer=serializer,
|
||||
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
|
||||
if ENABLE_RNNOISE
|
||||
else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_internal_transport(
|
||||
workflow_run_id: int,
|
||||
audio_config: AudioConfig,
|
||||
latency_seconds: float = 0.0,
|
||||
vad_config: dict | None = None,
|
||||
ambient_noise_config: dict | None = None,
|
||||
):
|
||||
"""Create an internal transport for agent-to-agent connections (LoopTalk).
|
||||
|
||||
Args:
|
||||
workflow_run_id: ID of the workflow run for turn analyzer context
|
||||
audio_config: Audio configuration for the transport
|
||||
latency_seconds: Network latency to simulate
|
||||
|
||||
Returns:
|
||||
InternalTransport instance configured with turn analyzer
|
||||
"""
|
||||
turn_analyzer = create_turn_analyzer(workflow_run_id, audio_config)
|
||||
|
||||
# Create and return the internal transport with latency
|
||||
return InternalTransport(
|
||||
params=TransportParams(
|
||||
audio_out_enabled=True,
|
||||
audio_out_sample_rate=audio_config.transport_out_sample_rate,
|
||||
audio_out_channels=1,
|
||||
audio_in_enabled=True,
|
||||
audio_in_sample_rate=audio_config.transport_in_sample_rate,
|
||||
audio_in_channels=1,
|
||||
vad_analyzer=(
|
||||
SileroVADAnalyzer(
|
||||
params=VADParams(
|
||||
confidence=vad_config.get("confidence", 0.7),
|
||||
start_secs=vad_config.get("start_seconds", 0.4),
|
||||
stop_secs=vad_config.get("stop_seconds", 0.8),
|
||||
min_volume=vad_config.get("minimum_volume", 0.6),
|
||||
)
|
||||
)
|
||||
if vad_config
|
||||
else SileroVADAnalyzer()
|
||||
),
|
||||
audio_out_mixer=(
|
||||
SoundfileMixer(
|
||||
sound_files={
|
||||
"office": APP_ROOT_DIR
|
||||
/ "assets"
|
||||
/ f"office-ambience-{audio_config.transport_out_sample_rate}-mono.wav"
|
||||
},
|
||||
default_sound="office",
|
||||
volume=ambient_noise_config.get("volume", 0.3),
|
||||
)
|
||||
if ambient_noise_config and ambient_noise_config.get("enabled", False)
|
||||
else SilenceAudioMixer()
|
||||
),
|
||||
turn_analyzer=turn_analyzer,
|
||||
audio_in_filter=RNNoiseFilter(library_path=librnnoise_path)
|
||||
if ENABLE_RNNOISE
|
||||
else None,
|
||||
),
|
||||
latency_seconds=latency_seconds,
|
||||
)
|
||||
76
api/services/pipecat/turn_context.py
Normal file
76
api/services/pipecat/turn_context.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Turn context management for logging across async boundaries.
|
||||
|
||||
This module provides a mechanism to track turn numbers across different
|
||||
async contexts, working around the limitation that contextvars don't
|
||||
propagate through asyncio.create_task() calls.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pipecat.utils.context import turn_var
|
||||
|
||||
|
||||
class TurnContextManager:
|
||||
"""Manages turn context across async task boundaries.
|
||||
|
||||
This class provides a workaround for the issue where contextvars
|
||||
don't propagate through asyncio.create_task() calls in the pipecat
|
||||
library's event system.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Map from task to turn number
|
||||
self._task_turns: Dict[asyncio.Task, int] = {}
|
||||
# Store the pipeline task reference
|
||||
self._pipeline_task: Optional[asyncio.Task] = None
|
||||
self._current_turn: int = 0
|
||||
|
||||
def set_pipeline_task(self, task: asyncio.Task):
|
||||
"""Set the main pipeline task reference."""
|
||||
self._pipeline_task = task
|
||||
|
||||
def set_turn(self, turn_number: int):
|
||||
"""Set the turn number for the current context."""
|
||||
self._current_turn = turn_number
|
||||
# Set in contextvar for direct access
|
||||
turn_var.set(turn_number)
|
||||
|
||||
# Also store for the current task
|
||||
try:
|
||||
current_task = asyncio.current_task()
|
||||
if current_task:
|
||||
self._task_turns[current_task] = turn_number
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def get_turn(self) -> int:
|
||||
"""Get the turn number, trying multiple sources."""
|
||||
# First try contextvar
|
||||
turn = turn_var.get()
|
||||
if turn > 0:
|
||||
return turn
|
||||
|
||||
# Try current task mapping
|
||||
try:
|
||||
current_task = asyncio.current_task()
|
||||
if current_task and current_task in self._task_turns:
|
||||
return self._task_turns[current_task]
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Fall back to stored current turn
|
||||
return self._current_turn
|
||||
|
||||
def cleanup_task(self, task: asyncio.Task):
|
||||
"""Clean up turn mapping for completed tasks."""
|
||||
self._task_turns.pop(task, None)
|
||||
|
||||
|
||||
# Global instance
|
||||
_turn_context_manager = TurnContextManager()
|
||||
|
||||
|
||||
def get_turn_context_manager() -> TurnContextManager:
|
||||
"""Get the global turn context manager instance."""
|
||||
return _turn_context_manager
|
||||
Loading…
Add table
Add a link
Reference in a new issue