chore: fix tests

This commit is contained in:
Abhishek Kumar 2026-04-10 17:05:44 +05:30
parent 74dbafb055
commit 51fde746ba
3 changed files with 36 additions and 18 deletions

View file

@ -17,6 +17,7 @@ from typing import Awaitable, Callable, Optional
from loguru import logger
from api.services.pipecat.recording_audio_cache import RecordingAudio
from api.services.workflow.pipecat_engine_context_composer import (
RECORDING_MARKER,
TTS_MARKER,
@ -48,14 +49,14 @@ class RecordingRouterProcessor(FrameProcessor):
Args:
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
fetch_recording_audio: Async callback that takes a recording_id and
returns raw 16-bit mono PCM bytes, or None on failure.
returns a RecordingAudio (audio + transcript), or None on failure.
"""
def __init__(
self,
*,
audio_sample_rate: int,
fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
fetch_recording_audio: Callable[..., Awaitable[Optional[RecordingAudio]]],
**kwargs,
):
super().__init__(**kwargs)

View file

@ -13,6 +13,7 @@ from typing import Optional
import pytest
from api.services.pipecat.recording_audio_cache import RecordingAudio
from api.services.pipecat.recording_router_processor import (
RecordingRouterProcessor,
)
@ -37,9 +38,9 @@ from pipecat.tests import run_test
FAKE_AUDIO = b"\x00\x01" * 8000 # 1 second of 16-bit mono @ 16 kHz
async def _fake_fetch(recording_id: str) -> Optional[bytes]:
async def _fake_fetch(recording_id: str) -> Optional[RecordingAudio]:
"""Stub that returns fake PCM audio for any recording_id."""
return FAKE_AUDIO
return RecordingAudio(audio=FAKE_AUDIO)
def _make_processor(**kwargs) -> RecordingRouterProcessor:
@ -189,7 +190,7 @@ class TestMixedMarkerSuppression:
async def tracking_fetch(recording_id: str):
fetched_ids.append(recording_id)
return FAKE_AUDIO
return RecordingAudio(audio=FAKE_AUDIO)
processor = _make_processor(fetch=tracking_fetch)

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, Mock, patch
import pytest
from api.services.pipecat.recording_audio_cache import RecordingAudio
from api.services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
@ -51,7 +52,7 @@ END_PROMPT = "End Call System Prompt"
TEXT_GREETING = "Hello, welcome to our service!"
TEXT_TRANSITION = "Thank you for calling, goodbye!"
AUDIO_GREETING_ID = "rec-greeting-001"
AUDIO_TRANSITION_ID = "rec-transition-001"
AUDIO_TRANSITION_ID = "101"
FAKE_PCM_AUDIO = b"\x00\x01" * 1000 # Fake 16-bit mono PCM data
@ -204,16 +205,18 @@ async def run_pipeline_and_capture_frames(
workflow_run_id=1,
)
transport_output = mock_transport.output()
if fetch_recording_audio:
engine.set_fetch_recording_audio(fetch_recording_audio)
engine.set_transport_output(transport_output)
pipeline = Pipeline(
[llm, tts, mock_transport.output(), context_aggregator.assistant()]
)
pipeline = Pipeline([llm, tts, transport_output, context_aggregator.assistant()])
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
engine.set_task(task)
# Spy on task.queue_frame to capture all frames queued by the engine
# Spy on task.queue_frame and transport_output.queue_frame to capture
# all frames queued by the engine (audio transitions go via transport output)
queued_frames: list[Frame] = []
original_queue_frame = task.queue_frame
@ -223,6 +226,15 @@ async def run_pipeline_and_capture_frames(
task.queue_frame = capturing_queue_frame
if fetch_recording_audio:
original_transport_queue = transport_output.queue_frame
async def _spy_transport_queue(frame, *args, **kwargs):
queued_frames.append(frame)
await original_transport_queue(frame, *args, **kwargs)
transport_output.queue_frame = _spy_transport_queue
with (
patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
@ -424,7 +436,7 @@ class TestTransitionSpeech:
},
]
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO))
llm, context, queued_frames = await run_pipeline_and_capture_frames(
workflow=audio_workflow,
@ -437,7 +449,7 @@ class TestTransitionSpeech:
assert llm.get_current_step() == 2
# Verify fetch was called with the correct recording ID
mock_fetch.assert_called_once_with(AUDIO_TRANSITION_ID)
mock_fetch.assert_called_once_with(recording_pk=int(AUDIO_TRANSITION_ID))
# Verify the three-frame audio sequence was queued
started = [f for f in queued_frames if isinstance(f, TTSStartedFrame)]
@ -491,6 +503,10 @@ class TestPlayConfigMessage:
engine._queued_frames.append(frame)
engine.task.queue_frame = mock_queue_frame
# Also capture frames queued via transport_output.queue_frame (audio playback)
engine._transport_output = Mock()
engine._transport_output.queue_frame = mock_queue_frame
return engine
@pytest.mark.asyncio
@ -510,16 +526,16 @@ class TestPlayConfigMessage:
@pytest.mark.asyncio
async def test_audio_queues_started_raw_stopped_frames(self, mock_engine):
"""messageType='audio' queues TTSStarted + TTSAudioRaw + TTSStopped."""
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO))
mock_engine._fetch_recording_audio = mock_fetch
manager = CustomToolManager(mock_engine)
config = {"messageType": "audio", "audioRecordingId": "rec-end-001"}
config = {"messageType": "audio", "audioRecordingId": "201"}
result = await manager._play_config_message(config)
assert result is True
mock_fetch.assert_called_once_with("rec-end-001")
mock_fetch.assert_called_once_with(recording_pk=201)
frames = mock_engine._queued_frames
assert len(frames) == 3
@ -553,7 +569,7 @@ class TestPlayConfigMessage:
mock_engine._fetch_recording_audio = None
manager = CustomToolManager(mock_engine)
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
config = {"messageType": "audio", "audioRecordingId": "301"}
result = await manager._play_config_message(config)
@ -567,12 +583,12 @@ class TestPlayConfigMessage:
mock_engine._fetch_recording_audio = mock_fetch
manager = CustomToolManager(mock_engine)
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
config = {"messageType": "audio", "audioRecordingId": "301"}
result = await manager._play_config_message(config)
assert result is False
mock_fetch.assert_called_once_with("rec-123")
mock_fetch.assert_called_once_with(recording_pk=301)
assert len(mock_engine._queued_frames) == 0
@pytest.mark.asyncio