mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: allow uploading recording as part of node transition
This commit is contained in:
parent
bb5f56bfb7
commit
65c76ca7ff
36 changed files with 2255 additions and 201 deletions
|
|
@ -11,12 +11,17 @@ from api.services.pipecat.in_memory_buffers import (
|
|||
InMemoryLogsBuffer,
|
||||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.pipecat.recording_playback import queue_recording_audio
|
||||
from api.services.pipecat.tracing_config import get_trace_url
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
from api.utils.hold_audio import play_hold_audio_loop
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame, TTSSpeakFrame
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
|
@ -32,6 +37,7 @@ def register_event_handlers(
|
|||
pipeline_metrics_aggregator: PipelineMetricsAggregator,
|
||||
audio_config=AudioConfig,
|
||||
pre_call_fetch_task: asyncio.Task | None = None,
|
||||
fetch_recording_audio=None,
|
||||
):
|
||||
"""Register all event handlers for transport and task events.
|
||||
|
||||
|
|
@ -112,12 +118,31 @@ def register_event_handlers(
|
|||
# so that render_template() has the complete _call_context_vars.
|
||||
await engine.set_node(engine.workflow.start_node_id)
|
||||
|
||||
greeting = engine.get_start_greeting()
|
||||
if greeting:
|
||||
logger.debug(
|
||||
"Both pipeline_started and client_connected received - playing greeting via TTS"
|
||||
)
|
||||
await task.queue_frame(TTSSpeakFrame(greeting))
|
||||
greeting_info = engine.get_start_greeting()
|
||||
if greeting_info:
|
||||
greeting_type, greeting_value = greeting_info
|
||||
if (
|
||||
greeting_type == "audio"
|
||||
and greeting_value
|
||||
and fetch_recording_audio
|
||||
):
|
||||
logger.debug(f"Playing audio greeting recording: {greeting_value}")
|
||||
audio_data = await fetch_recording_audio(greeting_value)
|
||||
if audio_data:
|
||||
await queue_recording_audio(
|
||||
audio_data,
|
||||
sample_rate=audio_config.pipeline_sample_rate or 16000,
|
||||
queue_frame=task.queue_frame,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to fetch audio greeting {greeting_value}, "
|
||||
"falling back to LLM generation"
|
||||
)
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
else:
|
||||
logger.debug("Playing text greeting via TTS")
|
||||
await task.queue_frame(TTSSpeakFrame(greeting_value))
|
||||
else:
|
||||
logger.debug(
|
||||
"Both pipeline_started and client_connected received - triggering initial LLM generation"
|
||||
|
|
|
|||
|
|
@ -27,9 +27,13 @@ from .audio_file_cache import (
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cache_path(recording_id: str, sample_rate: int) -> str:
|
||||
def _cache_path(
|
||||
organization_id: int, workflow_id: int, recording_id: str, sample_rate: int
|
||||
) -> str:
|
||||
"""Return the on-disk path for a cached PCM file."""
|
||||
return os.path.join(CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
|
||||
return os.path.join(
|
||||
CACHE_DIR, f"{organization_id}_{workflow_id}_{recording_id}_{sample_rate}.pcm"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -39,18 +43,20 @@ def _cache_path(recording_id: str, sample_rate: int) -> str:
|
|||
|
||||
def create_recording_audio_fetcher(
|
||||
organization_id: int,
|
||||
workflow_id: int,
|
||||
pipeline_sample_rate: int,
|
||||
) -> Callable[[str], Awaitable[Optional[bytes]]]:
|
||||
"""Create an async callback that returns raw PCM bytes for a recording_id.
|
||||
|
||||
The returned callable:
|
||||
1. Checks the filesystem cache (keyed by ``recording_id`` + sample rate).
|
||||
1. Checks the filesystem cache (keyed by org/workflow/recording + sample rate).
|
||||
2. On miss, looks up the recording in the DB, downloads the audio file
|
||||
from S3/MinIO, converts it to 16-bit mono PCM at *pipeline_sample_rate*,
|
||||
trims leading/trailing silence, caches the result on disk, and returns it.
|
||||
|
||||
Args:
|
||||
organization_id: Organization owning the recordings.
|
||||
workflow_id: Workflow the recordings belong to.
|
||||
pipeline_sample_rate: Target PCM sample rate for the pipeline.
|
||||
|
||||
Returns:
|
||||
|
|
@ -68,7 +74,9 @@ def create_recording_audio_fetcher(
|
|||
return _storage_cache[backend]
|
||||
|
||||
async def fetch(recording_id: str) -> Optional[bytes]:
|
||||
cached = _cache_path(recording_id, pipeline_sample_rate)
|
||||
cached = _cache_path(
|
||||
organization_id, workflow_id, recording_id, pipeline_sample_rate
|
||||
)
|
||||
|
||||
# 1. Serve from filesystem cache
|
||||
if os.path.exists(cached):
|
||||
|
|
@ -77,7 +85,7 @@ def create_recording_audio_fetcher(
|
|||
|
||||
# 2. DB lookup
|
||||
recording = await db_client.get_recording_by_recording_id(
|
||||
recording_id, organization_id
|
||||
recording_id, organization_id, workflow_id
|
||||
)
|
||||
if not recording:
|
||||
logger.warning(f"Recording {recording_id} not found in database")
|
||||
|
|
@ -112,8 +120,8 @@ async def warm_recording_cache(
|
|||
from api.services.storage import get_storage_for_backend
|
||||
|
||||
try:
|
||||
recordings = await db_client.get_recordings_for_workflow(
|
||||
workflow_id, organization_id
|
||||
recordings = await db_client.get_recordings(
|
||||
organization_id=organization_id, workflow_id=workflow_id
|
||||
)
|
||||
if not recordings:
|
||||
return
|
||||
|
|
@ -122,7 +130,11 @@ async def warm_recording_cache(
|
|||
uncached = [
|
||||
r
|
||||
for r in recordings
|
||||
if not os.path.exists(_cache_path(r.recording_id, pipeline_sample_rate))
|
||||
if not os.path.exists(
|
||||
_cache_path(
|
||||
organization_id, workflow_id, r.recording_id, pipeline_sample_rate
|
||||
)
|
||||
)
|
||||
]
|
||||
if not uncached:
|
||||
logger.debug(f"Recording cache already warm for workflow {workflow_id}")
|
||||
|
|
@ -187,7 +199,12 @@ async def _download_and_convert(
|
|||
pcm_data = _trim_silence(pcm_data, sample_rate)
|
||||
|
||||
# Write to disk cache
|
||||
cached = _cache_path(recording.recording_id, sample_rate)
|
||||
cached = _cache_path(
|
||||
recording.organization_id,
|
||||
recording.workflow_id,
|
||||
recording.recording_id,
|
||||
sample_rate,
|
||||
)
|
||||
write_cache_file(cached, pcm_data)
|
||||
|
||||
return pcm_data
|
||||
|
|
|
|||
41
api/services/pipecat/recording_playback.py
Normal file
41
api/services/pipecat/recording_playback.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Shared helper for pushing pre-recorded audio frames into a pipeline."""
|
||||
|
||||
import uuid
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
|
||||
|
||||
async def queue_recording_audio(
    audio_data: bytes,
    *,
    sample_rate: int,
    queue_frame: Callable[[Frame], Awaitable[None]],
) -> None:
    """Queue a pre-recorded clip as a TTSStarted / TTSAudioRaw / TTSStopped triple.

    Intended for playing pre-recorded PCM audio through the pipeline from
    code that sits outside the RecordingRouterProcessor (which pushes frames
    through its own ``push_frame`` path instead).

    The whole clip is sent as a single single-channel ``TTSAudioRawFrame``;
    all three frames carry the same freshly generated ``context_id``.

    Args:
        audio_data: Raw PCM bytes, expected to be 16-bit mono.
        sample_rate: Sample rate to stamp on the audio frame (e.g. 16000).
        queue_frame: Awaitable callback that accepts one frame; typically
            ``task.queue_frame``.
    """
    # One id shared by all three frames so they are tagged as one utterance.
    utterance_id = str(uuid.uuid4())

    started = TTSStartedFrame(context_id=utterance_id)
    await queue_frame(started)

    clip = TTSAudioRawFrame(
        audio=audio_data,
        sample_rate=sample_rate,
        num_channels=1,
        context_id=utterance_id,
    )
    await queue_frame(clip)

    stopped = TTSStoppedFrame(context_id=utterance_id)
    await queue_frame(stopped)
|
||||
|
|
@ -828,6 +828,15 @@ async def _run_pipeline(
|
|||
voicemail_detector = None
|
||||
recording_router = None
|
||||
|
||||
# Create recording audio fetcher (used by recording router, audio greetings,
|
||||
# and audio transition speech)
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_id=workflow_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
engine.set_fetch_recording_audio(fetch_audio)
|
||||
|
||||
if not is_realtime:
|
||||
# Create voicemail detector if enabled in workflow configurations
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
|
|
@ -868,10 +877,6 @@ async def _run_pipeline(
|
|||
|
||||
# Create recording router if workflow has active recordings
|
||||
if has_recordings:
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
|
|
@ -973,6 +978,7 @@ async def _run_pipeline(
|
|||
pipeline_metrics_aggregator=pipeline_metrics_aggregator,
|
||||
audio_config=audio_config,
|
||||
pre_call_fetch_task=pre_call_fetch_task,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
|
||||
register_audio_data_handler(audio_buffer, workflow_run_id, in_memory_audio_buffer)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue