mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
chore: fix tests
This commit is contained in:
parent
ffe9a99401
commit
d34a23d36c
3 changed files with 36 additions and 18 deletions
|
|
@ -17,6 +17,7 @@ from typing import Awaitable, Callable, Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||||
from api.services.workflow.pipecat_engine_context_composer import (
|
from api.services.workflow.pipecat_engine_context_composer import (
|
||||||
RECORDING_MARKER,
|
RECORDING_MARKER,
|
||||||
TTS_MARKER,
|
TTS_MARKER,
|
||||||
|
|
@ -48,14 +49,14 @@ class RecordingRouterProcessor(FrameProcessor):
|
||||||
Args:
|
Args:
|
||||||
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
|
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
|
||||||
fetch_recording_audio: Async callback that takes a recording_id and
|
fetch_recording_audio: Async callback that takes a recording_id and
|
||||||
returns raw 16-bit mono PCM bytes, or None on failure.
|
returns a RecordingAudio (audio + transcript), or None on failure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
audio_sample_rate: int,
|
audio_sample_rate: int,
|
||||||
fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
|
fetch_recording_audio: Callable[..., Awaitable[Optional[RecordingAudio]]],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||||
from api.services.pipecat.recording_router_processor import (
|
from api.services.pipecat.recording_router_processor import (
|
||||||
RecordingRouterProcessor,
|
RecordingRouterProcessor,
|
||||||
)
|
)
|
||||||
|
|
@ -37,9 +38,9 @@ from pipecat.tests import run_test
|
||||||
FAKE_AUDIO = b"\x00\x01" * 8000 # 1 second of 16-bit mono @ 16 kHz
|
FAKE_AUDIO = b"\x00\x01" * 8000 # 1 second of 16-bit mono @ 16 kHz
|
||||||
|
|
||||||
|
|
||||||
async def _fake_fetch(recording_id: str) -> Optional[bytes]:
|
async def _fake_fetch(recording_id: str) -> Optional[RecordingAudio]:
|
||||||
"""Stub that returns fake PCM audio for any recording_id."""
|
"""Stub that returns fake PCM audio for any recording_id."""
|
||||||
return FAKE_AUDIO
|
return RecordingAudio(audio=FAKE_AUDIO)
|
||||||
|
|
||||||
|
|
||||||
def _make_processor(**kwargs) -> RecordingRouterProcessor:
|
def _make_processor(**kwargs) -> RecordingRouterProcessor:
|
||||||
|
|
@ -189,7 +190,7 @@ class TestMixedMarkerSuppression:
|
||||||
|
|
||||||
async def tracking_fetch(recording_id: str):
|
async def tracking_fetch(recording_id: str):
|
||||||
fetched_ids.append(recording_id)
|
fetched_ids.append(recording_id)
|
||||||
return FAKE_AUDIO
|
return RecordingAudio(audio=FAKE_AUDIO)
|
||||||
|
|
||||||
processor = _make_processor(fetch=tracking_fetch)
|
processor = _make_processor(fetch=tracking_fetch)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||||
from api.services.workflow.dto import (
|
from api.services.workflow.dto import (
|
||||||
EdgeDataDTO,
|
EdgeDataDTO,
|
||||||
NodeDataDTO,
|
NodeDataDTO,
|
||||||
|
|
@ -51,7 +52,7 @@ END_PROMPT = "End Call System Prompt"
|
||||||
TEXT_GREETING = "Hello, welcome to our service!"
|
TEXT_GREETING = "Hello, welcome to our service!"
|
||||||
TEXT_TRANSITION = "Thank you for calling, goodbye!"
|
TEXT_TRANSITION = "Thank you for calling, goodbye!"
|
||||||
AUDIO_GREETING_ID = "rec-greeting-001"
|
AUDIO_GREETING_ID = "rec-greeting-001"
|
||||||
AUDIO_TRANSITION_ID = "rec-transition-001"
|
AUDIO_TRANSITION_ID = "101"
|
||||||
FAKE_PCM_AUDIO = b"\x00\x01" * 1000 # Fake 16-bit mono PCM data
|
FAKE_PCM_AUDIO = b"\x00\x01" * 1000 # Fake 16-bit mono PCM data
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -204,16 +205,18 @@ async def run_pipeline_and_capture_frames(
|
||||||
workflow_run_id=1,
|
workflow_run_id=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
transport_output = mock_transport.output()
|
||||||
|
|
||||||
if fetch_recording_audio:
|
if fetch_recording_audio:
|
||||||
engine.set_fetch_recording_audio(fetch_recording_audio)
|
engine.set_fetch_recording_audio(fetch_recording_audio)
|
||||||
|
engine.set_transport_output(transport_output)
|
||||||
|
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline([llm, tts, transport_output, context_aggregator.assistant()])
|
||||||
[llm, tts, mock_transport.output(), context_aggregator.assistant()]
|
|
||||||
)
|
|
||||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||||
engine.set_task(task)
|
engine.set_task(task)
|
||||||
|
|
||||||
# Spy on task.queue_frame to capture all frames queued by the engine
|
# Spy on task.queue_frame and transport_output.queue_frame to capture
|
||||||
|
# all frames queued by the engine (audio transitions go via transport output)
|
||||||
queued_frames: list[Frame] = []
|
queued_frames: list[Frame] = []
|
||||||
original_queue_frame = task.queue_frame
|
original_queue_frame = task.queue_frame
|
||||||
|
|
||||||
|
|
@ -223,6 +226,15 @@ async def run_pipeline_and_capture_frames(
|
||||||
|
|
||||||
task.queue_frame = capturing_queue_frame
|
task.queue_frame = capturing_queue_frame
|
||||||
|
|
||||||
|
if fetch_recording_audio:
|
||||||
|
original_transport_queue = transport_output.queue_frame
|
||||||
|
|
||||||
|
async def _spy_transport_queue(frame, *args, **kwargs):
|
||||||
|
queued_frames.append(frame)
|
||||||
|
await original_transport_queue(frame, *args, **kwargs)
|
||||||
|
|
||||||
|
transport_output.queue_frame = _spy_transport_queue
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
||||||
|
|
@ -424,7 +436,7 @@ class TestTransitionSpeech:
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
|
mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO))
|
||||||
|
|
||||||
llm, context, queued_frames = await run_pipeline_and_capture_frames(
|
llm, context, queued_frames = await run_pipeline_and_capture_frames(
|
||||||
workflow=audio_workflow,
|
workflow=audio_workflow,
|
||||||
|
|
@ -437,7 +449,7 @@ class TestTransitionSpeech:
|
||||||
assert llm.get_current_step() == 2
|
assert llm.get_current_step() == 2
|
||||||
|
|
||||||
# Verify fetch was called with the correct recording ID
|
# Verify fetch was called with the correct recording ID
|
||||||
mock_fetch.assert_called_once_with(AUDIO_TRANSITION_ID)
|
mock_fetch.assert_called_once_with(recording_pk=int(AUDIO_TRANSITION_ID))
|
||||||
|
|
||||||
# Verify the three-frame audio sequence was queued
|
# Verify the three-frame audio sequence was queued
|
||||||
started = [f for f in queued_frames if isinstance(f, TTSStartedFrame)]
|
started = [f for f in queued_frames if isinstance(f, TTSStartedFrame)]
|
||||||
|
|
@ -491,6 +503,10 @@ class TestPlayConfigMessage:
|
||||||
engine._queued_frames.append(frame)
|
engine._queued_frames.append(frame)
|
||||||
|
|
||||||
engine.task.queue_frame = mock_queue_frame
|
engine.task.queue_frame = mock_queue_frame
|
||||||
|
|
||||||
|
# Also capture frames queued via transport_output.queue_frame (audio playback)
|
||||||
|
engine._transport_output = Mock()
|
||||||
|
engine._transport_output.queue_frame = mock_queue_frame
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -510,16 +526,16 @@ class TestPlayConfigMessage:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_audio_queues_started_raw_stopped_frames(self, mock_engine):
|
async def test_audio_queues_started_raw_stopped_frames(self, mock_engine):
|
||||||
"""messageType='audio' queues TTSStarted + TTSAudioRaw + TTSStopped."""
|
"""messageType='audio' queues TTSStarted + TTSAudioRaw + TTSStopped."""
|
||||||
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
|
mock_fetch = AsyncMock(return_value=RecordingAudio(audio=FAKE_PCM_AUDIO))
|
||||||
mock_engine._fetch_recording_audio = mock_fetch
|
mock_engine._fetch_recording_audio = mock_fetch
|
||||||
|
|
||||||
manager = CustomToolManager(mock_engine)
|
manager = CustomToolManager(mock_engine)
|
||||||
config = {"messageType": "audio", "audioRecordingId": "rec-end-001"}
|
config = {"messageType": "audio", "audioRecordingId": "201"}
|
||||||
|
|
||||||
result = await manager._play_config_message(config)
|
result = await manager._play_config_message(config)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_fetch.assert_called_once_with("rec-end-001")
|
mock_fetch.assert_called_once_with(recording_pk=201)
|
||||||
|
|
||||||
frames = mock_engine._queued_frames
|
frames = mock_engine._queued_frames
|
||||||
assert len(frames) == 3
|
assert len(frames) == 3
|
||||||
|
|
@ -553,7 +569,7 @@ class TestPlayConfigMessage:
|
||||||
mock_engine._fetch_recording_audio = None
|
mock_engine._fetch_recording_audio = None
|
||||||
|
|
||||||
manager = CustomToolManager(mock_engine)
|
manager = CustomToolManager(mock_engine)
|
||||||
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
|
config = {"messageType": "audio", "audioRecordingId": "301"}
|
||||||
|
|
||||||
result = await manager._play_config_message(config)
|
result = await manager._play_config_message(config)
|
||||||
|
|
||||||
|
|
@ -567,12 +583,12 @@ class TestPlayConfigMessage:
|
||||||
mock_engine._fetch_recording_audio = mock_fetch
|
mock_engine._fetch_recording_audio = mock_fetch
|
||||||
|
|
||||||
manager = CustomToolManager(mock_engine)
|
manager = CustomToolManager(mock_engine)
|
||||||
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
|
config = {"messageType": "audio", "audioRecordingId": "301"}
|
||||||
|
|
||||||
result = await manager._play_config_message(config)
|
result = await manager._play_config_message(config)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
mock_fetch.assert_called_once_with("rec-123")
|
mock_fetch.assert_called_once_with(recording_pk=301)
|
||||||
assert len(mock_engine._queued_frames) == 0
|
assert len(mock_engine._queued_frames) == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue