mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
chore: fix tests
This commit is contained in:
parent
ffe9a99401
commit
d34a23d36c
3 changed files with 36 additions and 18 deletions
|
|
@ -17,6 +17,7 @@ from typing import Awaitable, Callable, Optional
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
RECORDING_MARKER,
|
||||
TTS_MARKER,
|
||||
|
|
@ -48,14 +49,14 @@ class RecordingRouterProcessor(FrameProcessor):
|
|||
Args:
|
||||
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
|
||||
fetch_recording_audio: Async callback that takes a recording_id and
|
||||
returns raw 16-bit mono PCM bytes, or None on failure.
|
||||
returns a RecordingAudio (audio + transcript), or None on failure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
audio_sample_rate: int,
|
||||
fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
|
||||
fetch_recording_audio: Callable[..., Awaitable[Optional[RecordingAudio]]],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import Optional
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||
from api.services.pipecat.recording_router_processor import (
|
||||
RecordingRouterProcessor,
|
||||
)
|
||||
|
|
@ -37,9 +38,9 @@ from pipecat.tests import run_test
|
|||
FAKE_AUDIO = b"\x00\x01" * 8000 # 1 second of 16-bit mono @ 16 kHz
|
||||
|
||||
|
||||
async def _fake_fetch(recording_id: str) -> Optional[bytes]:
|
||||
async def _fake_fetch(recording_id: str) -> Optional[RecordingAudio]:
|
||||
"""Stub that returns fake PCM audio for any recording_id."""
|
||||
return FAKE_AUDIO
|
||||
return RecordingAudio(audio=FAKE_AUDIO)
|
||||
|
||||
|
||||
def _make_processor(**kwargs) -> RecordingRouterProcessor:
|
||||
|
|
@ -189,7 +190,7 @@ class TestMixedMarkerSuppression:
|
|||
|
||||
async def tracking_fetch(recording_id: str):
|
||||
fetched_ids.append(recording_id)
|
||||
return FAKE_AUDIO
|
||||
return RecordingAudio(audio=FAKE_AUDIO)
|
||||
|
||||
processor = _make_processor(fetch=tracking_fetch)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
|
|
@ -51,7 +52,7 @@ END_PROMPT = "End Call System Prompt"
|
|||
TEXT_GREETING = "Hello, welcome to our service!"
|
||||
TEXT_TRANSITION = "Thank you for calling, goodbye!"
|
||||
AUDIO_GREETING_ID = "rec-greeting-001"
|
||||
AUDIO_TRANSITION_ID = "rec-transition-001"
|
||||
AUDIO_TRANSITION_ID = "101"
|
||||
FAKE_PCM_AUDIO = b"\x00\x01" * 1000 # Fake 16-bit mono PCM data
|
||||
|
||||
|
||||
|
|
@ -204,16 +205,18 @@ async def run_pipeline_and_capture_frames(
|
|||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
transport_output = mock_transport.output()
|
||||
|
||||
if fetch_recording_audio:
|
||||
engine.set_fetch_recording_audio(fetch_recording_audio)
|
||||
engine.set_transport_output(transport_output)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[llm, tts, mock_transport.output(), context_aggregator.assistant()]
|
||||
)
|
||||
pipeline = Pipeline([llm, tts, transport_output, context_aggregator.assistant()])
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
engine.set_task(task)
|
||||
|
||||
# Spy on task.queue_frame to capture all frames queued by the engine
|
||||
# Spy on task.queue_frame and transport_output.queue_frame to capture
|
||||
# all frames queued by the engine (audio transitions go via transport output)
|
||||
queued_frames: list[Frame] = []
|
||||
original_queue_frame = task.queue_frame
|
||||
|
||||
|
|
@ -223,6 +226,15 @@ async def run_pipeline_and_capture_frames(
|
|||
|
||||
task.queue_frame = capturing_queue_frame
|
||||
|
||||
if fetch_recording_audio:
|
||||
original_transport_queue = transport_output.queue_frame
|
||||
|
||||
async def _spy_transport_queue(frame, *args, **kwargs):
|
||||
queued_frames.append(frame)
|
||||
await original_transport_queue(frame, *args, **kwargs)
|
||||
|
||||
transport_output.queue_frame = _spy_transport_queue
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
||||
|
|
@ -424,7 +436,7 @@ class TestTransitionSpeech:
|
|||
},
|
||||
]
|
||||
|
||||
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
|
||||
mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO))
|
||||
|
||||
llm, context, queued_frames = await run_pipeline_and_capture_frames(
|
||||
workflow=audio_workflow,
|
||||
|
|
@ -437,7 +449,7 @@ class TestTransitionSpeech:
|
|||
assert llm.get_current_step() == 2
|
||||
|
||||
# Verify fetch was called with the correct recording ID
|
||||
mock_fetch.assert_called_once_with(AUDIO_TRANSITION_ID)
|
||||
mock_fetch.assert_called_once_with(recording_pk=int(AUDIO_TRANSITION_ID))
|
||||
|
||||
# Verify the three-frame audio sequence was queued
|
||||
started = [f for f in queued_frames if isinstance(f, TTSStartedFrame)]
|
||||
|
|
@ -491,6 +503,10 @@ class TestPlayConfigMessage:
|
|||
engine._queued_frames.append(frame)
|
||||
|
||||
engine.task.queue_frame = mock_queue_frame
|
||||
|
||||
# Also capture frames queued via transport_output.queue_frame (audio playback)
|
||||
engine._transport_output = Mock()
|
||||
engine._transport_output.queue_frame = mock_queue_frame
|
||||
return engine
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -510,16 +526,16 @@ class TestPlayConfigMessage:
|
|||
@pytest.mark.asyncio
|
||||
async def test_audio_queues_started_raw_stopped_frames(self, mock_engine):
|
||||
"""messageType='audio' queues TTSStarted + TTSAudioRaw + TTSStopped."""
|
||||
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
|
||||
mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO))
|
||||
mock_engine._fetch_recording_audio = mock_fetch
|
||||
|
||||
manager = CustomToolManager(mock_engine)
|
||||
config = {"messageType": "audio", "audioRecordingId": "rec-end-001"}
|
||||
config = {"messageType": "audio", "audioRecordingId": "201"}
|
||||
|
||||
result = await manager._play_config_message(config)
|
||||
|
||||
assert result is True
|
||||
mock_fetch.assert_called_once_with("rec-end-001")
|
||||
mock_fetch.assert_called_once_with(recording_pk=201)
|
||||
|
||||
frames = mock_engine._queued_frames
|
||||
assert len(frames) == 3
|
||||
|
|
@ -553,7 +569,7 @@ class TestPlayConfigMessage:
|
|||
mock_engine._fetch_recording_audio = None
|
||||
|
||||
manager = CustomToolManager(mock_engine)
|
||||
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
|
||||
config = {"messageType": "audio", "audioRecordingId": "301"}
|
||||
|
||||
result = await manager._play_config_message(config)
|
||||
|
||||
|
|
@ -567,12 +583,12 @@ class TestPlayConfigMessage:
|
|||
mock_engine._fetch_recording_audio = mock_fetch
|
||||
|
||||
manager = CustomToolManager(mock_engine)
|
||||
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
|
||||
config = {"messageType": "audio", "audioRecordingId": "301"}
|
||||
|
||||
result = await manager._play_config_message(config)
|
||||
|
||||
assert result is False
|
||||
mock_fetch.assert_called_once_with("rec-123")
|
||||
mock_fetch.assert_called_once_with(recording_pk=301)
|
||||
assert len(mock_engine._queued_frames) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue