feat: add recording audio option in tool and node transitions (#232)

* feat: allow uploading recording as part of node transition

* feat: allow recordings in tool transitions

* chore: fix tests
This commit is contained in:
Abhishek 2026-04-10 17:53:42 +05:30 committed by GitHub
parent 3f19a16e7f
commit 7c245051d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 3575 additions and 640 deletions

View file

@ -0,0 +1,188 @@
"""Utilities for playing audio through the pipeline transport.
Provides one-shot and looping playback of raw PCM audio. All playback
should be routed through ``transport.output().queue_frame`` so the audio
reaches the caller without passing through STT (which would otherwise
generate phantom transcriptions).
"""
import asyncio
import uuid
from typing import Awaitable, Callable, Dict, Optional, Tuple
import numpy as np
from loguru import logger
from pipecat.frames.frames import (
Frame,
OutputAudioRawFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
)
try:
import soundfile as sf
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use audio playback, you need to `pip install soundfile`.")
raise Exception(f"Missing module: {e}")
# ---------------------------------------------------------------------------
# Audio file loading / caching
# ---------------------------------------------------------------------------
_audio_cache: Dict[Tuple[str, int], bytes] = {}
def load_audio_file(file_path: str, sample_rate: int) -> Optional[bytes]:
"""Load an audio file as PCM-16 bytes, caching the result.
Args:
file_path: Path to a WAV audio file.
sample_rate: Target sample rate (used as cache key; no resampling
is performed here).
Returns:
Raw PCM-16 bytes, or *None* on failure.
"""
cache_key = (file_path, sample_rate)
if cache_key in _audio_cache:
logger.debug(f"Using cached audio for {file_path} at {sample_rate}Hz")
return _audio_cache[cache_key]
try:
logger.info(f"Loading audio from {file_path} at {sample_rate}Hz")
sound, file_sample_rate = sf.read(file_path, dtype="int16")
logger.info(
f"Audio file loaded - file sample_rate: {file_sample_rate}, target: {sample_rate}"
)
# Ensure mono (take first channel if stereo)
if len(sound.shape) > 1:
sound = sound[:, 0]
if file_sample_rate != sample_rate:
logger.warning(
f"Audio file has sample rate {file_sample_rate}, expected {sample_rate}"
)
audio_bytes = sound.astype(np.int16).tobytes()
_audio_cache[cache_key] = audio_bytes
logger.info(f"Audio loaded: {len(sound)} samples at {sample_rate}Hz")
return audio_bytes
except Exception as e:
logger.error(f"Failed to load audio file {file_path}: {e}")
return None
def clear_audio_cache() -> None:
"""Clear the audio file cache to free memory."""
_audio_cache.clear()
logger.info("Audio cache cleared")
# ---------------------------------------------------------------------------
# Playback helpers
# ---------------------------------------------------------------------------
async def play_audio(
audio_data: bytes,
*,
sample_rate: int,
queue_frame: Callable[[Frame], Awaitable[None]],
transcript: Optional[str] = None,
append_to_context: bool = False,
) -> None:
"""Play raw PCM-16 audio once.
Pushes ``TTSStarted -> TTSAudioRaw -> TTSStopped`` so downstream
processors (audio buffer, context aggregators) handle the audio
correctly.
When *transcript* is provided a ``TTSTextFrame`` is also pushed so
that observers (e.g. ``RealtimeFeedbackObserver``) can relay the
spoken text to the UI.
Args:
audio_data: Raw 16-bit mono PCM bytes.
sample_rate: Pipeline sample rate (e.g. 16000).
queue_frame: Frame sink -- typically ``transport.output().queue_frame``.
transcript: Optional transcript of the recording.
append_to_context: Whether the transcript should be appended to
the LLM assistant context. Defaults to False.
"""
context_id = str(uuid.uuid4())
await queue_frame(TTSStartedFrame(context_id=context_id))
if transcript:
tts_text = TTSTextFrame(
text=transcript, aggregated_by="recording", context_id=context_id
)
tts_text.append_to_context = append_to_context
await queue_frame(tts_text)
await queue_frame(
TTSAudioRawFrame(
audio=audio_data,
sample_rate=sample_rate,
num_channels=1,
context_id=context_id,
)
)
await queue_frame(TTSStoppedFrame(context_id=context_id))
async def play_audio_loop(
*,
stop_event: asyncio.Event,
sample_rate: int,
queue_frame: Callable[[Frame], Awaitable[None]],
audio_file: Optional[str] = None,
) -> None:
"""Play audio in a loop until *stop_event* is set.
Used for hold music during call transfers and ringers during
pre-call data fetches.
Args:
stop_event: Set this event to terminate the loop.
sample_rate: Target sample rate for audio playback.
queue_frame: Frame sink -- typically ``transport.output().queue_frame``.
audio_file: Path to a WAV file. When *None* the default
``transfer_hold_ring_{sample_rate}.wav`` asset is used.
"""
if audio_file is None:
from api.constants import APP_ROOT_DIR
audio_file = str(
APP_ROOT_DIR / "assets" / f"transfer_hold_ring_{sample_rate}.wav"
)
audio_data = load_audio_file(audio_file, sample_rate)
if not audio_data:
logger.warning(f"Audio loop: failed to load {audio_file}, skipping")
return
num_samples = len(audio_data) // 2 # 16-bit PCM = 2 bytes per sample
duration = num_samples / sample_rate
logger.debug(f"Audio loop: playing at {sample_rate}Hz")
try:
while not stop_event.is_set():
frame = OutputAudioRawFrame(
audio=audio_data,
sample_rate=sample_rate,
num_channels=1,
)
await queue_frame(frame)
try:
await asyncio.wait_for(stop_event.wait(), timeout=duration + 1.5)
break
except asyncio.TimeoutError:
pass
except Exception as e:
logger.error(f"Audio loop error: {e}")
logger.debug("Audio loop: stopped")

View file

@ -6,6 +6,7 @@ from api.db import db_client
from api.enums import PostHogEvent, WorkflowRunState
from api.services.campaign.circuit_breaker import circuit_breaker
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.audio_playback import play_audio, play_audio_loop
from api.services.pipecat.in_memory_buffers import (
InMemoryAudioBuffer,
InMemoryLogsBuffer,
@ -16,8 +17,11 @@ from api.services.posthog_client import capture_event
from api.services.workflow.pipecat_engine import PipecatEngine
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
from api.utils.hold_audio import play_hold_audio_loop
from pipecat.frames.frames import Frame, LLMContextFrame, TTSSpeakFrame
from pipecat.frames.frames import (
Frame,
LLMContextFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.utils.enums import EndTaskReason
@ -64,6 +68,7 @@ def register_event_handlers(
pipeline_metrics_aggregator: PipelineMetricsAggregator,
audio_config=AudioConfig,
pre_call_fetch_task: asyncio.Task | None = None,
fetch_recording_audio=None,
user_provider_id: str | None = None,
):
"""Register all event handlers for transport and task events.
@ -123,7 +128,11 @@ def register_event_handlers(
stop_ringer = asyncio.Event()
sample_rate = audio_config.pipeline_sample_rate or 16000
ringer_task = asyncio.create_task(
play_hold_audio_loop(task, stop_ringer, sample_rate)
play_audio_loop(
stop_event=stop_ringer,
sample_rate=sample_rate,
queue_frame=transport.output().queue_frame,
)
)
try:
fetch_result = await pre_call_fetch_task
@ -151,12 +160,35 @@ def register_event_handlers(
# so that render_template() has the complete _call_context_vars.
await engine.set_node(engine.workflow.start_node_id)
greeting = engine.get_start_greeting()
if greeting:
logger.debug(
"Both pipeline_started and client_connected received - playing greeting via TTS"
)
await task.queue_frame(TTSSpeakFrame(greeting))
greeting_info = engine.get_start_greeting()
if greeting_info:
greeting_type, greeting_value = greeting_info
if (
greeting_type == "audio"
and greeting_value
and fetch_recording_audio
):
logger.debug(f"Playing audio greeting recording: {greeting_value}")
result = await fetch_recording_audio(
recording_pk=int(greeting_value)
)
if result:
await play_audio(
result.audio,
sample_rate=audio_config.pipeline_sample_rate or 16000,
queue_frame=transport.output().queue_frame,
transcript=result.transcript,
append_to_context=True,
)
else:
logger.warning(
f"Failed to fetch audio greeting {greeting_value}, "
"falling back to LLM generation"
)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
else:
logger.debug("Playing text greeting via TTS")
await task.queue_frame(TTSSpeakFrame(greeting_value))
else:
logger.debug(
"Both pipeline_started and client_connected received - triggering initial LLM generation"

View file

@ -170,7 +170,10 @@ class RealtimeFeedbackObserver(BaseObserver):
frame_direction = data.direction
# Skip already processed frames (frames can be observed multiple times)
if frame.id in self._frames_seen:
if (
frame.id in self._frames_seen
or frame_direction != FrameDirection.DOWNSTREAM
):
return
self._frames_seen.add(frame.id)

View file

@ -7,7 +7,7 @@ subsequent plays (even from other workers) are instantaneous.
"""
import os
from typing import Awaitable, Callable, Optional
from typing import Awaitable, Callable, NamedTuple, Optional
import numpy as np
from loguru import logger
@ -22,14 +22,24 @@ from .audio_file_cache import (
write_cache_file,
)
class RecordingAudio(NamedTuple):
"""Audio bytes paired with the recording's transcript (when available)."""
audio: bytes
transcript: Optional[str] = None
# ---------------------------------------------------------------------------
# Cache path helper
# ---------------------------------------------------------------------------
def _cache_path(recording_id: str, sample_rate: int) -> str:
def _cache_path(organization_id: int, recording_id: str, sample_rate: int) -> str:
"""Return the on-disk path for a cached PCM file."""
return os.path.join(CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
return os.path.join(
CACHE_DIR, f"{organization_id}_{recording_id}_{sample_rate}.pcm"
)
# ---------------------------------------------------------------------------
@ -40,54 +50,95 @@ def _cache_path(recording_id: str, sample_rate: int) -> str:
def create_recording_audio_fetcher(
organization_id: int,
pipeline_sample_rate: int,
) -> Callable[[str], Awaitable[Optional[bytes]]]:
"""Create an async callback that returns raw PCM bytes for a recording_id.
) -> Callable[..., Awaitable[Optional[bytes]]]:
"""Create an async callback that returns raw PCM bytes for a recording.
The returned callable:
1. Checks the filesystem cache (keyed by ``recording_id`` + sample rate).
2. On miss, looks up the recording in the DB, downloads the audio file
from S3/MinIO, converts it to 16-bit mono PCM at *pipeline_sample_rate*,
trims leading/trailing silence, caches the result on disk, and returns it.
The returned callable accepts **one** of two keyword arguments:
- ``recording_pk`` the immutable integer primary key (used by
dropdown-based selections: greeting, edges, tool configs).
- ``recording_id`` the human-readable string ID (used by
prompt-based ``RECORDING_ID: xxx`` references).
Flow:
1. Checks the filesystem cache (keyed by org + pk + sample rate).
2. On miss, looks up the recording in the DB, downloads the audio
from S3/MinIO, converts to 16-bit mono PCM, trims silence, and
caches the result on disk.
Args:
organization_id: Organization owning the recordings.
pipeline_sample_rate: Target PCM sample rate for the pipeline.
Returns:
``async (recording_id: str) -> Optional[bytes]``
"""
from api.db import db_client
from api.services.storage import get_storage_for_backend
# Resolve storage instances once per backend at creation time, not per fetch.
_storage_cache: dict[str, object] = {}
_transcript_cache: dict[str, Optional[str]] = {}
def _get_storage(backend: str):
if backend not in _storage_cache:
_storage_cache[backend] = get_storage_for_backend(backend)
return _storage_cache[backend]
async def fetch(recording_id: str) -> Optional[bytes]:
cached = _cache_path(recording_id, pipeline_sample_rate)
async def _lookup_recording(
cache_key: str,
recording_pk: Optional[int],
recording_id: Optional[str],
):
"""DB lookup with transcript caching."""
if recording_pk is not None:
recording = await db_client.get_recording_by_id(
recording_pk, organization_id
)
else:
recording = await db_client.get_recording_by_recording_id(
recording_id, organization_id
)
if recording:
_transcript_cache[cache_key] = recording.transcript or None
return recording
async def fetch(
*,
recording_pk: Optional[int] = None,
recording_id: Optional[str] = None,
) -> Optional[RecordingAudio]:
if recording_pk is None and recording_id is None:
logger.warning("fetch called with neither recording_pk nor recording_id")
return None
# Use pk for cache key when available, otherwise recording_id
cache_key = str(recording_pk) if recording_pk is not None else recording_id
cached = _cache_path(organization_id, cache_key, pipeline_sample_rate)
# 1. Serve from filesystem cache
if os.path.exists(cached):
logger.debug(f"Recording {recording_id} served from disk cache")
return read_cached_file(cached)
logger.debug(f"Recording {cache_key} served from disk cache")
audio = read_cached_file(cached)
# Transcript may already be in memory from a prior fetch;
# if not, do a lightweight DB lookup.
if cache_key not in _transcript_cache:
await _lookup_recording(cache_key, recording_pk, recording_id)
return RecordingAudio(
audio=audio, transcript=_transcript_cache.get(cache_key)
)
# 2. DB lookup
recording = await db_client.get_recording_by_recording_id(
recording_id, organization_id
)
recording = await _lookup_recording(cache_key, recording_pk, recording_id)
if not recording:
logger.warning(f"Recording {recording_id} not found in database")
logger.warning(f"Recording {cache_key} not found in database")
return None
# 3. Download, convert, trim, and cache
pcm_data = await _download_and_convert(
recording, pipeline_sample_rate, _get_storage
)
return pcm_data
if pcm_data is None:
return None
return RecordingAudio(
audio=pcm_data, transcript=_transcript_cache.get(cache_key)
)
return fetch
@ -98,11 +149,10 @@ def create_recording_audio_fetcher(
async def warm_recording_cache(
workflow_id: int,
organization_id: int,
pipeline_sample_rate: int,
) -> None:
"""Pre-fetch all active recordings for a workflow into the disk cache.
"""Pre-fetch all active recordings for an organization into the disk cache.
Launched as a background ``asyncio.Task`` at pipeline startup so that
recordings are ready before the first playback request. Errors are logged
@ -112,9 +162,7 @@ async def warm_recording_cache(
from api.services.storage import get_storage_for_backend
try:
recordings = await db_client.get_recordings_for_workflow(
workflow_id, organization_id
)
recordings = await db_client.get_recordings(organization_id=organization_id)
if not recordings:
return
@ -122,15 +170,20 @@ async def warm_recording_cache(
uncached = [
r
for r in recordings
if not os.path.exists(_cache_path(r.recording_id, pipeline_sample_rate))
if not os.path.exists(
_cache_path(organization_id, str(r.id), pipeline_sample_rate)
)
and not os.path.exists(
_cache_path(organization_id, r.recording_id, pipeline_sample_rate)
)
]
if not uncached:
logger.debug(f"Recording cache already warm for workflow {workflow_id}")
logger.debug(f"Recording cache already warm for org {organization_id}")
return
logger.info(
f"Warming recording cache: {len(uncached)}/{len(recordings)} "
f"recording(s) for workflow {workflow_id}"
f"recording(s) for org {organization_id}"
)
# Resolve storage instances once per backend, not per recording
@ -156,7 +209,7 @@ async def warm_recording_cache(
f"Cache warm: error processing {recording.recording_id}"
)
logger.info(f"Recording cache warm complete for workflow {workflow_id}")
logger.info(f"Recording cache warm complete for org {organization_id}")
except Exception:
logger.exception("Recording cache warm failed")
@ -187,7 +240,11 @@ async def _download_and_convert(
pcm_data = _trim_silence(pcm_data, sample_rate)
# Write to disk cache
cached = _cache_path(recording.recording_id, sample_rate)
cached = _cache_path(
recording.organization_id,
recording.recording_id,
sample_rate,
)
write_cache_file(cached, pcm_data)
return pcm_data

View file

@ -17,6 +17,7 @@ from typing import Awaitable, Callable, Optional
from loguru import logger
from api.services.pipecat.recording_audio_cache import RecordingAudio
from api.services.workflow.pipecat_engine_context_composer import (
RECORDING_MARKER,
TTS_MARKER,
@ -48,14 +49,14 @@ class RecordingRouterProcessor(FrameProcessor):
Args:
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
fetch_recording_audio: Async callback that takes a recording_id and
returns raw 16-bit mono PCM bytes, or None on failure.
returns a RecordingAudio (audio + transcript), or None on failure.
"""
def __init__(
self,
*,
audio_sample_rate: int,
fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
fetch_recording_audio: Callable[..., Awaitable[Optional[RecordingAudio]]],
**kwargs,
):
super().__init__(**kwargs)
@ -245,8 +246,8 @@ class RecordingRouterProcessor(FrameProcessor):
"""
logger.info(f"Playing pre-recorded audio: {recording_id}")
audio_data = await self._fetch_recording_audio(recording_id)
if not audio_data:
result = await self._fetch_recording_audio(recording_id=recording_id)
if not result:
logger.warning(
f"Failed to fetch recording {recording_id}, no audio will play"
)
@ -256,7 +257,7 @@ class RecordingRouterProcessor(FrameProcessor):
await self.push_frame(TTSStartedFrame(context_id=context_id))
await self.push_frame(
TTSAudioRawFrame(
audio=audio_data,
audio=result.audio,
sample_rate=self._audio_sample_rate,
num_channels=1,
context_id=context_id,
@ -264,10 +265,10 @@ class RecordingRouterProcessor(FrameProcessor):
)
await self.push_frame(TTSStoppedFrame(context_id=context_id))
duration_secs = len(audio_data) / (self._audio_sample_rate * 2)
duration_secs = len(result.audio) / (self._audio_sample_rate * 2)
logger.debug(
f"Finished pushing recording {recording_id} "
f"({len(audio_data)} bytes, {duration_secs:.1f}s)"
f"({len(result.audio)} bytes, {duration_secs:.1f}s)"
)
# ------------------------------------------------------------------

View file

@ -698,9 +698,7 @@ async def _run_pipeline(
# Check if the workflow has any active recordings so the engine can
# include recording response mode instructions in all node prompts.
has_recordings = await db_client.has_active_recordings(
workflow_id, workflow.organization_id
)
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
"context_compaction_enabled", False
@ -831,6 +829,14 @@ async def _run_pipeline(
voicemail_detector = None
recording_router = None
# Create recording audio fetcher (used by recording router, audio greetings,
# and audio transition speech)
fetch_audio = create_recording_audio_fetcher(
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
engine.set_fetch_recording_audio(fetch_audio)
if not is_realtime:
# Create voicemail detector if enabled in workflow configurations
voicemail_config = (workflow.workflow_configurations or {}).get(
@ -871,10 +877,6 @@ async def _run_pipeline(
# Create recording router if workflow has active recordings
if has_recordings:
fetch_audio = create_recording_audio_fetcher(
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
recording_router = RecordingRouterProcessor(
audio_sample_rate=audio_config.pipeline_sample_rate,
fetch_recording_audio=fetch_audio,
@ -883,7 +885,6 @@ async def _run_pipeline(
# before the first playback request.
asyncio.create_task(
warm_recording_cache(
workflow_id=workflow_id,
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
@ -918,8 +919,9 @@ async def _run_pipeline(
# Create pipeline task with audio configuration
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
# Now set the task on the engine
# Now set the task and transport output on the engine
engine.set_task(task)
engine.set_transport_output(transport.output())
# Initialize the engine to set the initial context with
# System Prompt and Tools
@ -979,6 +981,7 @@ async def _run_pipeline(
pipeline_metrics_aggregator=pipeline_metrics_aggregator,
audio_config=audio_config,
pre_call_fetch_task=pre_call_fetch_task,
fetch_recording_audio=fetch_audio,
user_provider_id=user_provider_id,
)

View file

@ -230,7 +230,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
api_key=user_config.tts.api_key,
settings=DeepgramTTSSettings(voice=user_config.tts.voice),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
@ -238,7 +238,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
api_key=user_config.tts.api_key,
settings=OpenAITTSSettings(model=user_config.tts.model),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
@ -258,7 +258,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
similarity_boost=0.75,
),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
@ -284,7 +284,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
),
),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
@ -299,7 +299,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
speed=user_config.tts.speed,
),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.CAMB.value:
@ -312,7 +312,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
voice_id=voice_id,
model=user_config.tts.model,
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
)
# Set language directly as BCP-47 code (bypasses Language enum conversion)
tts._settings.language = language
@ -327,7 +327,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
speed=user_config.tts.speed,
),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.RIME.value:
@ -352,7 +352,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
api_key=user_config.tts.api_key,
settings=RimeTTSSettings(**settings_kwargs),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
@ -382,7 +382,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
language=pipecat_language,
),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
else: