diff --git a/api/services/pipecat/recording_router_processor.py b/api/services/pipecat/recording_router_processor.py index 5db501a..23ceb67 100644 --- a/api/services/pipecat/recording_router_processor.py +++ b/api/services/pipecat/recording_router_processor.py @@ -17,6 +17,7 @@ from typing import Awaitable, Callable, Optional from loguru import logger +from api.services.pipecat.recording_audio_cache import RecordingAudio from api.services.workflow.pipecat_engine_context_composer import ( RECORDING_MARKER, TTS_MARKER, @@ -48,14 +49,14 @@ class RecordingRouterProcessor(FrameProcessor): Args: audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame. fetch_recording_audio: Async callback that takes a recording_id and - returns raw 16-bit mono PCM bytes, or None on failure. + returns a RecordingAudio (audio + transcript), or None on failure. """ def __init__( self, *, audio_sample_rate: int, - fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]], + fetch_recording_audio: Callable[..., Awaitable[Optional[RecordingAudio]]], **kwargs, ): super().__init__(**kwargs) diff --git a/api/tests/test_recording_router_processor.py b/api/tests/test_recording_router_processor.py index 24b76c2..5ef2057 100644 --- a/api/tests/test_recording_router_processor.py +++ b/api/tests/test_recording_router_processor.py @@ -13,6 +13,7 @@ from typing import Optional import pytest +from api.services.pipecat.recording_audio_cache import RecordingAudio from api.services.pipecat.recording_router_processor import ( RecordingRouterProcessor, ) @@ -37,9 +38,9 @@ from pipecat.tests import run_test FAKE_AUDIO = b"\x00\x01" * 8000 # 1 second of 16-bit mono @ 16 kHz -async def _fake_fetch(recording_id: str) -> Optional[bytes]: +async def _fake_fetch(recording_id: str) -> Optional[RecordingAudio]: """Stub that returns fake PCM audio for any recording_id.""" - return FAKE_AUDIO + return RecordingAudio(audio=FAKE_AUDIO) def _make_processor(**kwargs) -> RecordingRouterProcessor: @@ -189,7 +190,7 @@ class TestMixedMarkerSuppression: async def tracking_fetch(recording_id: str): fetched_ids.append(recording_id) - return FAKE_AUDIO + return RecordingAudio(audio=FAKE_AUDIO) processor = _make_processor(fetch=tracking_fetch) diff --git a/api/tests/test_text_and_audio_playback.py b/api/tests/test_text_and_audio_playback.py index ab2c95b..3a8b1a6 100644 --- a/api/tests/test_text_and_audio_playback.py +++ b/api/tests/test_text_and_audio_playback.py @@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from api.services.pipecat.recording_audio_cache import RecordingAudio from api.services.workflow.dto import ( EdgeDataDTO, NodeDataDTO, @@ -51,7 +52,7 @@ END_PROMPT = "End Call System Prompt" TEXT_GREETING = "Hello, welcome to our service!" TEXT_TRANSITION = "Thank you for calling, goodbye!" AUDIO_GREETING_ID = "rec-greeting-001" -AUDIO_TRANSITION_ID = "rec-transition-001" +AUDIO_TRANSITION_ID = "101" FAKE_PCM_AUDIO = b"\x00\x01" * 1000 # Fake 16-bit mono PCM data @@ -204,16 +205,18 @@ async def run_pipeline_and_capture_frames( workflow_run_id=1, ) + transport_output = mock_transport.output() + if fetch_recording_audio: engine.set_fetch_recording_audio(fetch_recording_audio) + engine.set_transport_output(transport_output) - pipeline = Pipeline( - [llm, tts, mock_transport.output(), context_aggregator.assistant()] - ) + pipeline = Pipeline([llm, tts, transport_output, context_aggregator.assistant()]) task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) engine.set_task(task) - # Spy on task.queue_frame to capture all frames queued by the engine + # Spy on task.queue_frame and transport_output.queue_frame to capture + # all frames queued by the engine (audio transitions go via transport output) queued_frames: list[Frame] = [] original_queue_frame = task.queue_frame @@ -223,6 +226,15 @@ async def run_pipeline_and_capture_frames( task.queue_frame = capturing_queue_frame + if fetch_recording_audio: + original_transport_queue = transport_output.queue_frame + + async def _spy_transport_queue(frame, *args, **kwargs): + queued_frames.append(frame) + await original_transport_queue(frame, *args, **kwargs) + + transport_output.queue_frame = _spy_transport_queue + with ( patch( "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", @@ -424,7 +436,7 @@ class TestTransitionSpeech: }, ] - mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO) + mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO)) llm, context, queued_frames = await run_pipeline_and_capture_frames( workflow=audio_workflow, @@ -437,7 +449,7 @@ class TestTransitionSpeech: assert llm.get_current_step() == 2 # Verify fetch was called with the correct recording ID - mock_fetch.assert_called_once_with(AUDIO_TRANSITION_ID) + mock_fetch.assert_called_once_with(recording_pk=int(AUDIO_TRANSITION_ID)) # Verify the three-frame audio sequence was queued started = [f for f in queued_frames if isinstance(f, TTSStartedFrame)] @@ -491,6 +503,10 @@ class TestPlayConfigMessage: engine._queued_frames.append(frame) engine.task.queue_frame = mock_queue_frame + + # Also capture frames queued via transport_output.queue_frame (audio playback) + engine._transport_output = Mock() + engine._transport_output.queue_frame = mock_queue_frame return engine @pytest.mark.asyncio @@ -510,16 +526,16 @@ class TestPlayConfigMessage: @pytest.mark.asyncio async def test_audio_queues_started_raw_stopped_frames(self, mock_engine): """messageType='audio' queues TTSStarted + TTSAudioRaw + TTSStopped.""" - mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO) + mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO)) mock_engine._fetch_recording_audio = mock_fetch manager = CustomToolManager(mock_engine) - config = {"messageType": "audio", "audioRecordingId": "rec-end-001"} + config = {"messageType": "audio", "audioRecordingId": "201"} result = await manager._play_config_message(config) assert result is True - mock_fetch.assert_called_once_with("rec-end-001") + mock_fetch.assert_called_once_with(recording_pk=201) frames = mock_engine._queued_frames assert len(frames) == 3 @@ -553,7 +569,7 @@ class TestPlayConfigMessage: mock_engine._fetch_recording_audio = None manager = CustomToolManager(mock_engine) - config = {"messageType": "audio", "audioRecordingId": "rec-123"} + config = {"messageType": "audio", "audioRecordingId": "301"} result = await manager._play_config_message(config) @@ -567,12 +583,12 @@ class TestPlayConfigMessage: mock_engine._fetch_recording_audio = mock_fetch manager = CustomToolManager(mock_engine) - config = {"messageType": "audio", "audioRecordingId": "rec-123"} + config = {"messageType": "audio", "audioRecordingId": "301"} result = await manager._play_config_message(config) assert result is False - mock_fetch.assert_called_once_with("rec-123") + mock_fetch.assert_called_once_with(recording_pk=301) assert len(mock_engine._queued_frames) == 0 @pytest.mark.asyncio