mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: allow interruption on start_node
This commit is contained in:
parent
5fe1c8ce2f
commit
7e438ad049
10 changed files with 1013 additions and 144 deletions
|
|
@ -6,6 +6,8 @@ from api.services.workflow.disposition_mapper import (
|
|||
)
|
||||
from api.services.workflow.workflow import Node, WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
|
|
@ -99,6 +101,9 @@ class PipecatEngine:
|
|||
# Controls whether user input should be muted
|
||||
self._mute_pipeline: bool = False
|
||||
|
||||
# Tracks whether the bot is currently speaking (for allow_interrupt logic)
|
||||
self._bot_is_speaking: bool = False
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
||||
|
|
@ -614,10 +619,30 @@ class PipecatEngine:
|
|||
"""
|
||||
Callback for CallbackUserMuteStrategy to determine if the user should be muted.
|
||||
|
||||
This method tracks bot speaking state from frames and mutes the user when:
|
||||
- The pipeline is being shut down (_mute_pipeline is True), OR
|
||||
- The bot is speaking AND the current node has allow_interrupt=False
|
||||
|
||||
Returns:
|
||||
True if the user should be muted, False otherwise.
|
||||
"""
|
||||
return self._mute_pipeline
|
||||
# Track bot speaking state from frames
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_is_speaking = True
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_is_speaking = False
|
||||
|
||||
# Always mute if pipeline is shutting down
|
||||
if self._mute_pipeline:
|
||||
return True
|
||||
|
||||
# Mute if bot is speaking and current node doesn't allow interruption
|
||||
if self._bot_is_speaking and self._current_node:
|
||||
# If we should not allow interruption, mute the pipeline
|
||||
if not self._current_node.allow_interrupt:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def create_user_idle_handler(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
|
||||
class ContextCapturingMockLLM(MockLLMService):
|
||||
|
|
@ -129,10 +130,17 @@ async def run_pipeline_and_capture_context(
|
|||
llm = ContextCapturingMockLLM(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
|
||||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
|
||||
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
mock_transport = MockTransport(
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_sample_rate=16000,
|
||||
),
|
||||
)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.turns.user_mute import (
|
||||
CallbackUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
|
|
@ -125,15 +126,17 @@ async def create_engine_with_tracking(
|
|||
Tuple of (engine, tts, transport, task)
|
||||
"""
|
||||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
|
||||
|
||||
# Create MockTransport with audio generation to simulate real pipeline
|
||||
mock_transport = MockTransport(
|
||||
generate_audio=generate_audio,
|
||||
audio_interval_ms=20,
|
||||
audio_sample_rate=16000,
|
||||
audio_num_channels=1,
|
||||
emit_bot_speaking=True,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_sample_rate=16000,
|
||||
),
|
||||
)
|
||||
|
||||
# Create LLM context
|
||||
|
|
|
|||
|
|
@ -7,12 +7,11 @@ in the PipecatEngine. The key scenario being tested:
|
|||
2. At the same time, user starts and stops speaking (triggered by FunctionCallResultFrame)
|
||||
3. The pipeline should handle both events correctly
|
||||
|
||||
The tests use a custom input transport that injects UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame when triggered by a FunctionCallResultFrame observer.
|
||||
The tests use a UserSpeechInjector processor that injects UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame when triggered by a FunctionCallResultFrame.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -20,13 +19,9 @@ import pytest
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
LLMContextFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
|
|
@ -42,8 +37,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.turns.user_mute import (
|
||||
CallbackUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
|
|
@ -58,76 +53,35 @@ from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
|||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class UserSpeechInjectingInputTransport(FrameProcessor):
|
||||
"""Mock input transport that injects user speaking frames on FunctionCallResultFrame.
|
||||
class UserSpeechInjector(FrameProcessor):
|
||||
"""Processor that injects user speaking frames on FunctionCallResultFrame.
|
||||
|
||||
This transport generates audio frames and automatically injects UserStartedSpeakingFrame
|
||||
and UserStoppedSpeakingFrame when it sees the first FunctionCallResultFrame flowing
|
||||
upstream through the pipeline.
|
||||
When this processor sees the first FunctionCallResultFrame flowing upstream,
|
||||
it injects UserStartedSpeakingFrame, TranscriptionFrame, and UserStoppedSpeakingFrame
|
||||
downstream to simulate user speech during a function call.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Optional[TransportParams] = None,
|
||||
*,
|
||||
generate_audio: bool = False,
|
||||
audio_interval_ms: int = 20,
|
||||
sample_rate: int = 16000,
|
||||
num_channels: int = 1,
|
||||
user_speech_initial_delay: float = 0.01,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the user speech injector.
|
||||
|
||||
Args:
|
||||
user_speech_initial_delay: Delay in seconds before injecting
|
||||
UserStartedSpeakingFrame after seeing FunctionCallResultFrame.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._params = params or TransportParams()
|
||||
self._generate_audio = generate_audio
|
||||
self._audio_interval_ms = audio_interval_ms
|
||||
self._sample_rate = sample_rate
|
||||
self._num_channels = num_channels
|
||||
self._user_speech_initial_delay = user_speech_initial_delay
|
||||
self._audio_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
self._function_call_result_count = 0
|
||||
|
||||
async def _generate_audio_frames(self):
|
||||
"""Generate audio frames at regular intervals."""
|
||||
samples_per_frame = int(self._sample_rate * self._audio_interval_ms / 1000)
|
||||
bytes_per_frame = samples_per_frame * self._num_channels * 2
|
||||
silence_audio = bytes(bytes_per_frame)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
frame = InputAudioRawFrame(
|
||||
audio=silence_audio,
|
||||
sample_rate=self._sample_rate,
|
||||
num_channels=self._num_channels,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await asyncio.sleep(self._audio_interval_ms / 1000)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def _start_tasks(self):
|
||||
"""Start audio generation task."""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
if self._generate_audio:
|
||||
self._audio_task = asyncio.create_task(self._generate_audio_frames())
|
||||
|
||||
def _stop_tasks(self):
|
||||
"""Stop all background tasks."""
|
||||
self._running = False
|
||||
if self._audio_task and not self._audio_task.done():
|
||||
self._audio_task.cancel()
|
||||
self._audio_task = None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
self._start_tasks()
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
self._stop_tasks()
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
if isinstance(frame, FunctionCallResultFrame):
|
||||
# When we see FunctionCallResultFrame #1 flowing upstream,
|
||||
# inject user speaking frames downstream
|
||||
self._function_call_result_count += 1
|
||||
|
|
@ -160,68 +114,21 @@ class UserSpeechInjectingInputTransport(FrameProcessor):
|
|||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def cleanup(self):
|
||||
self._stop_tasks()
|
||||
await super().cleanup()
|
||||
|
||||
|
||||
class UserSpeechInjectingTransport(BaseTransport):
|
||||
"""Transport that injects user speaking frames on first FunctionCallResultFrame."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Optional[TransportParams] = None,
|
||||
*,
|
||||
input_name: Optional[str] = None,
|
||||
output_name: Optional[str] = None,
|
||||
emit_bot_speaking: bool = True,
|
||||
generate_audio: bool = False,
|
||||
audio_interval_ms: int = 20,
|
||||
audio_sample_rate: int = 16000,
|
||||
audio_num_channels: int = 1,
|
||||
user_speech_initial_delay: float = 0.01,
|
||||
):
|
||||
super().__init__(input_name=input_name, output_name=output_name)
|
||||
self._params = params or TransportParams()
|
||||
self._input = UserSpeechInjectingInputTransport(
|
||||
self._params,
|
||||
name=self._input_name,
|
||||
generate_audio=generate_audio,
|
||||
audio_interval_ms=audio_interval_ms,
|
||||
sample_rate=audio_sample_rate,
|
||||
num_channels=audio_num_channels,
|
||||
user_speech_initial_delay=user_speech_initial_delay,
|
||||
)
|
||||
self._output = MockOutputTransport(
|
||||
self._params,
|
||||
emit_bot_speaking=emit_bot_speaking,
|
||||
name=self._output_name,
|
||||
)
|
||||
|
||||
def input(self) -> UserSpeechInjectingInputTransport:
|
||||
return self._input
|
||||
|
||||
def output(self) -> FrameProcessor:
|
||||
return self._output
|
||||
|
||||
|
||||
async def create_test_pipeline(
|
||||
workflow: WorkflowGraph,
|
||||
mock_llm: MockLLMService,
|
||||
generate_audio: bool = True,
|
||||
user_speech_initial_delay: float = 0.01,
|
||||
) -> tuple[PipecatEngine, UserSpeechInjectingTransport, PipelineTask]:
|
||||
) -> tuple[PipecatEngine, MockTransport, PipelineTask]:
|
||||
"""Create a PipecatEngine with full pipeline for testing node switch scenarios.
|
||||
|
||||
The transport's input automatically injects UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame when it sees the first FunctionCallResultFrame
|
||||
flowing upstream through the pipeline.
|
||||
The pipeline includes a UserSpeechInjector processor that injects
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame when it sees
|
||||
the first FunctionCallResultFrame flowing upstream.
|
||||
|
||||
Args:
|
||||
workflow: The workflow graph to use.
|
||||
mock_llm: The mock LLM service.
|
||||
generate_audio: If True, the mock transport generates InputAudioRawFrame
|
||||
every 20ms to simulate real audio input.
|
||||
user_speech_initial_delay: Delay in seconds before injecting
|
||||
UserStartedSpeakingFrame after seeing FunctionCallResultFrame.
|
||||
|
||||
|
|
@ -229,15 +136,20 @@ async def create_test_pipeline(
|
|||
Tuple of (engine, transport, task)
|
||||
"""
|
||||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
|
||||
|
||||
# Create custom transport that injects user speaking frames on FunctionCallResultFrame #1
|
||||
transport = UserSpeechInjectingTransport(
|
||||
generate_audio=generate_audio,
|
||||
audio_interval_ms=20,
|
||||
audio_sample_rate=16000,
|
||||
audio_num_channels=1,
|
||||
emit_bot_speaking=True,
|
||||
# Create MockTransport
|
||||
transport = MockTransport(
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_sample_rate=16000,
|
||||
),
|
||||
)
|
||||
|
||||
# Create user speech injector processor
|
||||
user_speech_injector = UserSpeechInjector(
|
||||
user_speech_initial_delay=user_speech_initial_delay,
|
||||
)
|
||||
|
||||
|
|
@ -280,12 +192,13 @@ async def create_test_pipeline(
|
|||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Create the pipeline:
|
||||
# transport.input() -> user_aggregator -> LLM -> TTS -> transport.output() -> assistant_aggregator
|
||||
# The transport input watches for FunctionCallResultFrame flowing upstream
|
||||
# transport.input() -> user_speech_injector -> user_aggregator -> LLM -> TTS -> transport.output() -> assistant_aggregator
|
||||
# The user_speech_injector watches for FunctionCallResultFrame flowing upstream
|
||||
# and injects user speaking frames when it sees the first one
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
user_speech_injector,
|
||||
user_context_aggregator,
|
||||
mock_llm,
|
||||
tts,
|
||||
|
|
@ -325,17 +238,15 @@ class TestNodeSwitchWithUserSpeech:
|
|||
This test creates the scenario where:
|
||||
1. LLM generates text and calls collect_info to transition from start to agent
|
||||
2. When FunctionCallResultFrame #1 is seen, UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame are automatically injected from the pipeline source
|
||||
UserStoppedSpeakingFrame are automatically injected by UserSpeechInjector
|
||||
3. The pipeline processes both events concurrently
|
||||
|
||||
The FunctionCallResultObserver in the pipeline detects the first function call
|
||||
result and triggers the transport to inject user speaking frames.
|
||||
The UserSpeechInjector processor in the pipeline detects the first function call
|
||||
result and injects user speaking frames.
|
||||
|
||||
This test is parameterized with two scenarios:
|
||||
- delayed_user_speech: 10ms delay before UserStartedSpeakingFrame (user_speech_initial_delay=0.01)
|
||||
- immediate_user_speech: No delay before UserStartedSpeakingFrame (user_speech_initial_delay=0)
|
||||
|
||||
This is a scenario creation test - no specific assertions yet.
|
||||
"""
|
||||
# Step 0 (Start node): greet user then call collect_info to transition to agent
|
||||
step_0_chunks = MockLLMService.create_mixed_chunks(
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
|
||||
async def run_pipeline_with_tool_calls(
|
||||
|
|
@ -64,10 +65,17 @@ async def run_pipeline_with_tool_calls(
|
|||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
|
||||
# Create MockTTSService to generate TTS frames
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
|
||||
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
mock_transport = MockTransport(
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_sample_rate=16000,
|
||||
),
|
||||
)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
|
||||
class TestVariableExtractionDuringTransitions:
|
||||
|
|
@ -80,10 +81,17 @@ class TestVariableExtractionDuringTransitions:
|
|||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
|
||||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
|
||||
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
mock_transport = MockTransport(
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_sample_rate=16000,
|
||||
),
|
||||
)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
|
|||
399
api/tests/test_tts_endframe_with_audio_write_failure.py
Normal file
399
api/tests/test_tts_endframe_with_audio_write_failure.py
Normal file
|
|
@ -0,0 +1,399 @@
|
|||
"""Tests for TTS pause_frame_processing with audio write failure scenarios.
|
||||
|
||||
This module tests a scenario where:
|
||||
1. TTS service has pause_frame_processing=True
|
||||
2. Output transport's write_audio_frame returns False (simulating failure)
|
||||
3. TTS pauses frame processing while generating audio
|
||||
4. Audio write failures occur but BotStoppedSpeakingFrame is never sent
|
||||
5. TTS remains paused indefinitely
|
||||
6. end_call_with_reason is called and hangs because EndFrame can't be processed
|
||||
|
||||
The root cause is that when write_audio_frame fails consecutively in _audio_task_handler,
|
||||
it breaks out of the loop without calling _bot_stopped_speaking(), leaving the TTS
|
||||
in a paused state that blocks all subsequent frame processing including EndFrame.
|
||||
|
||||
Two test scenarios are covered:
|
||||
1. Bot started speaking, then audio write fails (fail_after_n_frames > 0)
|
||||
- BotStartedSpeakingFrame is emitted
|
||||
- Some audio is written successfully
|
||||
- Write starts failing, _audio_task_handler breaks out
|
||||
- _bot_stopped_speaking() is NOT called (the bug)
|
||||
- TTS remains paused
|
||||
|
||||
2. Bot never started speaking because write failed immediately (fail_after_n_frames = 0)
|
||||
- Audio write fails from the first frame
|
||||
- _bot_currently_speaking() is called but write fails
|
||||
- _audio_task_handler breaks out after consecutive failures
|
||||
- _bot_stopped_speaking() is NOT called (the bug)
|
||||
- TTS remains paused
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.turns.user_mute import (
|
||||
CallbackUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
)
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
|
||||
async def create_test_pipeline_with_failing_transport(
    workflow: WorkflowGraph,
    mock_llm: MockLLMService,
    fail_after_n_frames: int = 0,
) -> tuple[PipecatEngine, MockTTSService, MockTransport, PipelineTask]:
    """Assemble a PipecatEngine pipeline whose output transport rejects audio.

    The pipeline is wired as
    ``transport.input() -> user aggregator -> LLM -> TTS -> transport.output()
    -> assistant aggregator``. The TTS service is created with
    ``pause_frame_processing=True`` and the transport is created with
    ``audio_write_succeeds=False``, which is the combination the tests in this
    module rely on to reproduce the "TTS stays paused after audio write
    failures" scenario described in the module docstring.

    Args:
        workflow: The workflow graph handed to the engine.
        mock_llm: The mock LLM service; used both by the engine and as a
            pipeline stage.
        fail_after_n_frames: Forwarded to ``MockTransport``; the number of
            audio frames allowed to succeed before writes start failing.
            ``0`` means writes fail from the first frame.

    Returns:
        Tuple of (engine, tts, transport, task).
    """
    # TTS that pauses frame processing while it is producing audio.
    tts = MockTTSService(
        mock_audio_duration_ms=200,
        frame_delay=0.001,
        pause_frame_processing=True,
    )

    # Low failure threshold so the transport gives up quickly in tests.
    transport_params = TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_in_sample_rate=16000,
        audio_out_sample_rate=16000,
        audio_out_max_consecutive_failures=2,
        audio_out_sleep_between_failures=0.25,
    )

    # No generated input audio; output writes are configured to fail.
    transport = MockTransport(
        generate_audio=False,
        audio_write_succeeds=False,
        fail_after_n_frames=fail_after_n_frames,
        params=transport_params,
    )

    llm_context = LLMContext()

    engine = PipecatEngine(
        llm=mock_llm,
        context=llm_context,
        workflow=workflow,
        call_context_vars={"customer_name": "Test User"},
        workflow_run_id=1,
    )

    # The engine's own callback decides muting alongside the stock strategy.
    mute_strategies = [
        MuteUntilFirstBotCompleteUserMuteStrategy(),
        CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
    ]
    user_side_params = LLMUserAggregatorParams(
        user_mute_strategies=mute_strategies,
    )
    assistant_side_params = LLMAssistantAggregatorParams(expect_stripped_words=True)

    aggregators = LLMContextAggregatorPair(
        llm_context,
        assistant_params=assistant_side_params,
        user_params=user_side_params,
    )
    user_aggregator = aggregators.user()
    assistant_aggregator = aggregators.assistant()

    stages = [
        transport.input(),
        user_aggregator,
        mock_llm,
        tts,
        transport.output(),
        assistant_aggregator,
    ]
    pipeline = Pipeline(stages)

    task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
    engine.set_task(task)

    return engine, tts, transport, task
|
||||
|
||||
|
||||
class TestTTSPauseWithAudioWriteFailure:
    """Test scenarios where TTS pause_frame_processing interacts with audio write failures."""

    async def _run_call_and_hang_up(
        self,
        engine: PipecatEngine,
        task: PipelineTask,
        *,
        run_timeout: float = 3.0,
        cleanup_timeout: float = 1.0,
    ) -> bool:
        """Run the pipeline, trigger one LLM turn, end the call, and report a hang.

        Shared harness for the tests in this class. It patches the DB-facing
        helpers, starts the pipeline runner and a second coroutine that
        initializes the engine, queues an ``LLMContextFrame`` and then calls
        ``end_call_with_reason``. Both coroutines are awaited with a timeout.

        Args:
            engine: Engine returned by the pipeline factory.
            task: Pipeline task returned by the pipeline factory.
            run_timeout: Seconds to wait for both coroutines to finish.
            cleanup_timeout: Seconds granted to cancelled coroutines to unwind.

        Returns:
            True if at least one coroutine was still pending after
            ``run_timeout`` (i.e. the pipeline hung), False otherwise.

        Raises:
            Exception: Whatever a completed coroutine raised. Surfacing it here
                keeps a real error from being masked by the timeout assertion
                or lost as an unretrieved task exception.
        """
        with patch(
            "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
            new_callable=AsyncMock,
            return_value=1,
        ):
            with patch(
                "api.services.workflow.pipecat_engine.apply_disposition_mapping",
                new_callable=AsyncMock,
                return_value="completed",
            ):
                with patch.object(
                    VariableExtractionManager,
                    "_perform_extraction",
                    new_callable=AsyncMock,
                    return_value={},
                ):
                    runner = PipelineRunner()

                    async def run_pipeline():
                        await runner.run(task)

                    async def initialize_and_end_call():
                        await asyncio.sleep(0.01)
                        await engine.initialize()

                        # Start LLM generation - this will trigger TTS
                        await engine.llm.queue_frame(LLMContextFrame(engine.context))

                        # Sleep so that processing is paused in TTS Service
                        await asyncio.sleep(0.1)

                        await engine.end_call_with_reason(
                            EndTaskReason.USER_HANGUP.value,
                            abort_immediately=False,
                        )

                    # Create tasks explicitly for better control
                    pipeline_task = asyncio.create_task(run_pipeline())
                    end_call_task = asyncio.create_task(initialize_and_end_call())
                    started = (pipeline_task, end_call_task)

                    done, pending = await asyncio.wait(
                        started,
                        timeout=run_timeout,
                        return_when=asyncio.ALL_COMPLETED,
                    )

                    # Anything still pending means we hit the timeout.
                    timed_out = bool(pending)
                    if pending:
                        for leftover in pending:
                            leftover.cancel()

                        # Give limited time for cleanup
                        try:
                            await asyncio.wait_for(
                                asyncio.gather(*pending, return_exceptions=True),
                                timeout=cleanup_timeout,
                            )
                        except asyncio.TimeoutError:
                            pass  # Cleanup took too long, continue anyway

                    # asyncio.wait never raises task exceptions itself, so
                    # re-raise them explicitly (in a fixed order) instead of
                    # letting them vanish.
                    for finished in started:
                        if finished not in done or finished.cancelled():
                            continue
                        error = finished.exception()
                        if error is not None:
                            raise error

                    return timed_out

    @pytest.mark.asyncio
    async def test_bot_never_started_speaking_write_fails_immediately(
        self, simple_workflow: WorkflowGraph
    ):
        """Test scenario where bot never starts speaking because write fails immediately.

        Scenario:
        1. LLM generates text response
        2. TTS starts generating audio with pause_frame_processing=True
        3. TTS pauses frame processing (waits for BotStoppedSpeakingFrame)
        4. MediaSender tries to write audio, calls _bot_currently_speaking
        5. write_audio_frame returns False immediately
        6. After consecutive failures, _audio_task_handler breaks out
        7. BUG: _bot_stopped_speaking() is NOT called
        8. TTS remains paused, blocking EndFrame
        9. Pipeline hangs on end_call_with_reason

        This test verifies the hang behavior by using a timeout.
        Note: Uses audio_out_max_consecutive_failures=2 for faster test execution.
        """
        # Create LLM response that will trigger TTS
        step_0_chunks = MockLLMService.create_text_chunks(
            "Hello! This is a test message that should cause TTS to pause."
        )
        llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001)

        (
            engine,
            tts,
            transport,
            task,
        ) = await create_test_pipeline_with_failing_transport(
            simple_workflow,
            llm,
            fail_after_n_frames=0,  # Fail immediately - bot never starts speaking
        )

        test_timed_out = await self._run_call_and_hang_up(engine, task)

        # Verify audio write was attempted but failed
        output_transport = transport._output
        assert output_transport._write_attempts > 0, (
            "Audio write should have been attempted"
        )
        assert output_transport._frames_written == 0, (
            "No frames should have been written successfully"
        )

        assert test_timed_out is False, (
            "Test timed out - pipeline hung due to TTS being paused. "
            "BotStoppedSpeakingFrame was not sent before CancelTaskFrame."
        )

    @pytest.mark.asyncio
    async def test_bot_started_speaking_then_write_fails(
        self, simple_workflow: WorkflowGraph
    ):
        """Test scenario where bot starts speaking, then audio write fails mid-stream.

        This tests a more realistic scenario where the transport starts working
        but then encounters issues (e.g., client disconnect mid-stream).

        Scenario:
        1. LLM generates text response
        2. TTS starts generating audio with pause_frame_processing=True
        3. First N audio frames are written successfully
        4. BotStartedSpeakingFrame is emitted
        5. Subsequent writes start failing
        6. After consecutive failures, _audio_task_handler breaks out
        7. BUG: _bot_stopped_speaking() is NOT called
        8. TTS remains paused, blocking EndFrame

        Note: Uses audio_out_max_consecutive_failures=2 for faster test execution.
        """
        step_0_chunks = MockLLMService.create_text_chunks(
            "This is a longer message to ensure multiple audio frames are generated."
        )
        llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001)

        # Allow first 3 frames to succeed, then fail
        # This simulates bot starting to speak, then transport disconnecting
        (
            engine,
            tts,
            transport,
            task,
        ) = await create_test_pipeline_with_failing_transport(
            simple_workflow,
            llm,
            fail_after_n_frames=3,  # Bot starts speaking, then fails
        )

        test_timed_out = await self._run_call_and_hang_up(engine, task)

        # Verify some frames were written successfully before failure
        output_transport = transport._output
        assert output_transport._frames_written == 3, (
            f"Expected 3 successful writes, got {output_transport._frames_written}"
        )
        assert output_transport._write_attempts > 3, (
            "Should have attempted more writes after initial successes"
        )

        assert test_timed_out is False, (
            "Test timed out - pipeline hung due to TTS being paused. "
            "BotStoppedSpeakingFrame was not sent before CancelTaskFrame."
        )
|
||||
|
|
@ -26,6 +26,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
|
||||
async def run_pipeline_with_user_idle(
|
||||
|
|
@ -57,10 +58,17 @@ async def run_pipeline_with_user_idle(
|
|||
)
|
||||
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
tts = MockTTSService(mock_audio_duration_ms=10)
|
||||
tts = MockTTSService(mock_audio_duration_ms=40)
|
||||
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=True)
|
||||
mock_transport = MockTransport(
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_sample_rate=16000,
|
||||
audio_out_sample_rate=16000,
|
||||
),
|
||||
)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
|
|||
499
api/tests/test_user_muting_during_bot_speech.py
Normal file
499
api/tests/test_user_muting_during_bot_speech.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""Tests for verifying user muting behavior based on bot speaking state.
|
||||
|
||||
This module tests the user muting behavior with different allow_interrupt settings:
|
||||
|
||||
1. Pipeline is always muted until first BotStoppedSpeaking
|
||||
2. When allow_interrupt=True, pipeline is NOT muted after second BotStartedSpeaking
|
||||
3. When allow_interrupt=False, pipeline IS muted during second bot speech
|
||||
|
||||
The observer is placed BEFORE user_aggregator to check mute status when
|
||||
bot speaking events flow upstream.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.turns.user_mute import (
|
||||
CallbackUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
)
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class BotSpeakingObserverProcessor(FrameProcessor):
    """Record the user aggregator's mute flag whenever bot speaking frames pass upstream.

    The processor is meant to sit BEFORE the user aggregator in the pipeline,
    so that an upstream bot speaking frame has already gone through the
    aggregator by the time it is observed here.

    Pipeline structure:
        transport.input() -> observer -> user_aggregator -> llm -> tts -> transport.output()

    UPSTREAM flow: transport.output() -> tts -> llm -> user_aggregator -> observer -> transport.input()
    """

    def __init__(self, user_aggregator: LLMUserAggregator, **kwargs):
        super().__init__(**kwargs)
        self.user_aggregator = user_aggregator

        # Number of upstream start/stop frames seen so far.
        self.bot_started_count = 0
        self.bot_stopped_count = 0

        # Mute flag sampled at each upstream start/stop frame, in arrival order.
        self.mute_status_on_bot_started: List[bool] = []
        self.mute_status_on_bot_stopped: List[bool] = []

        # Events tests can await to synchronize on the first two bot turns.
        self.first_bot_started = asyncio.Event()
        self.first_bot_stopped = asyncio.Event()
        self.second_bot_started = asyncio.Event()
        self.second_bot_stopped = asyncio.Event()

    @staticmethod
    def _signal(occurrence: int, first: asyncio.Event, second: asyncio.Event) -> None:
        """Set the event matching the first or second occurrence; ignore later ones."""
        if occurrence == 1:
            first.set()
        elif occurrence == 2:
            second.set()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Sample the mute flag on upstream bot speaking frames, then forward the frame."""
        await super().process_frame(frame, direction)

        going_upstream = direction == FrameDirection.UPSTREAM

        if going_upstream and isinstance(frame, BotStartedSpeakingFrame):
            self.bot_started_count += 1
            # Reads the aggregator's private flag directly; there is no
            # public accessor used anywhere in this module.
            self.mute_status_on_bot_started.append(
                self.user_aggregator._user_is_muted
            )
            self._signal(
                self.bot_started_count,
                self.first_bot_started,
                self.second_bot_started,
            )
        elif going_upstream and isinstance(frame, BotStoppedSpeakingFrame):
            self.bot_stopped_count += 1
            self.mute_status_on_bot_stopped.append(
                self.user_aggregator._user_is_muted
            )
            self._signal(
                self.bot_stopped_count,
                self.first_bot_stopped,
                self.second_bot_stopped,
            )

        # Every frame is passed on unchanged, in its original direction.
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
def set_workflow_allow_interrupt_in_start_node(
    workflow: WorkflowGraph, allow_interrupt: bool
) -> None:
    """Set ``allow_interrupt`` on the start node(s) of the workflow.

    Only nodes whose ``is_start`` attribute is truthy are modified; every
    other node keeps its current ``allow_interrupt`` value. The workflow is
    mutated in place.

    Args:
        workflow: Workflow whose ``nodes`` mapping is scanned.
        allow_interrupt: Value assigned to each start node.
    """
    for node in workflow.nodes.values():
        if node.is_start:
            node.allow_interrupt = allow_interrupt
|
||||
|
||||
|
||||
async def create_engine_for_mute_test(
    workflow: WorkflowGraph,
    mock_llm: MockLLMService,
    tts_duration_ms: int = 100,
) -> tuple[
    PipecatEngine,
    MockTTSService,
    MockTransport,
    PipelineTask,
    LLMUserAggregator,
    BotSpeakingObserverProcessor,
]:
    """Build a PipecatEngine test pipeline with the observer placed BEFORE the user aggregator.

    Pipeline structure:
        transport.input() -> observer -> user_aggregator -> mock_llm -> tts -> transport.output() -> assistant_aggregator

    Args:
        workflow: Workflow graph handed to the engine.
        mock_llm: Mock LLM used both as the engine's LLM and as a pipeline stage.
        tts_duration_ms: Passed to the mock TTS as ``mock_audio_duration_ms``.

    Returns:
        Tuple of (engine, tts, transport, task, user_aggregator, observer)
    """
    tts = MockTTSService(mock_audio_duration_ms=tts_duration_ms, frame_delay=0.01)

    transport_params = TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_in_sample_rate=16000,
        audio_out_sample_rate=16000,
    )
    mock_transport = MockTransport(generate_audio=False, params=transport_params)

    context = LLMContext()
    engine = PipecatEngine(
        llm=mock_llm,
        context=context,
        workflow=workflow,
        call_context_vars={"customer_name": "Test User"},
        workflow_run_id=1,
    )

    # Aggregator configuration. The callback strategy delegates the mute
    # decision to the engine, which is why the engine is created first.
    assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
    user_params = LLMUserAggregatorParams(
        user_turn_strategies=ExternalUserTurnStrategies(),
        user_mute_strategies=[
            MuteUntilFirstBotCompleteUserMuteStrategy(),
            CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
        ],
    )

    aggregator_pair = LLMContextAggregatorPair(
        context, assistant_params=assistant_params, user_params=user_params
    )
    user_context_aggregator = aggregator_pair.user()
    assistant_context_aggregator = aggregator_pair.assistant()

    # The observer keeps a reference to the user aggregator so it can sample
    # that aggregator's mute flag.
    observer = BotSpeakingObserverProcessor(user_context_aggregator)

    # Observer precedes user_aggregator: upstream bot speaking frames pass
    # through the aggregator first and only then reach the observer.
    stages = [
        mock_transport.input(),
        observer,
        user_context_aggregator,
        mock_llm,
        tts,
        mock_transport.output(),
        assistant_context_aggregator,
    ]
    pipeline = Pipeline(stages)

    task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
    engine.set_task(task)

    return engine, tts, mock_transport, task, user_context_aggregator, observer
|
||||
|
||||
|
||||
async def queue_user_speaking_and_transcript_frames(task):
    """Queue a simulated user turn on ``task``: started speaking, transcript, stopped speaking.

    Control is yielded to the event loop (``asyncio.sleep(0)``) after each of
    the first two frames, so other tasks get a chance to run between them.
    """
    started = UserStartedSpeakingFrame()
    await task.queue_frame(started)
    await asyncio.sleep(0)

    # The timestamp is taken only now, after the yield above.
    transcript = TranscriptionFrame("User Speech", "user_id", time_now_iso8601())
    await task.queue_frame(transcript)
    await asyncio.sleep(0)

    stopped = UserStoppedSpeakingFrame()
    await task.queue_frame(stopped)
|
||||
|
||||
|
||||
class TestUserMutingDuringBotSpeech:
    """Test user muting behavior based on bot speaking state."""

    @staticmethod
    async def _wait_for(event: asyncio.Event) -> None:
        """Wait up to 5 seconds for ``event``; raises ``asyncio.TimeoutError`` otherwise."""
        await asyncio.wait_for(event.wait(), timeout=5.0)

    @classmethod
    async def _drive_two_turns(cls, task, observer) -> None:
        """Let the first bot turn finish, queue a user turn, then wait for the second bot turn."""
        # Wait for first bot stopped (first response complete)
        await cls._wait_for(observer.first_bot_stopped)

        # Queue user speaking frames for second llm generation
        await queue_user_speaking_and_transcript_frames(task)

        # Wait for the second bot turn to start and then stop
        await cls._wait_for(observer.second_bot_started)
        await cls._wait_for(observer.second_bot_stopped)

    @staticmethod
    async def _run_scenario(workflow: WorkflowGraph, second_response: str, drive):
        """Run the pipeline alongside a driver coroutine and return the observer.

        Args:
            workflow: Workflow graph handed to ``create_engine_for_mute_test``.
            second_response: Text of the second mocked LLM step (the first is
                always "Hello!").
            drive: Coroutine function called as ``drive(task, observer)`` once
                the first LLM completion has been queued.

        Returns:
            The ``BotSpeakingObserverProcessor`` holding the recorded mute states.
        """
        llm = MockLLMService(
            mock_steps=[
                MockLLMService.create_text_chunks("Hello!"),
                MockLLMService.create_text_chunks(second_response),
            ],
            chunk_delay=0.001,
        )

        (
            engine,
            _tts,
            _transport,
            task,
            _user_aggregator,
            observer,
        ) = await create_engine_for_mute_test(workflow, llm, tts_duration_ms=50)

        async def run_test():
            try:
                await asyncio.sleep(0.01)
                await engine.initialize()

                # Trigger first LLM completion
                await engine.llm.queue_frame(LLMContextFrame(engine.context))

                await drive(task, observer)
            finally:
                # Always cancel. gather() below swallows exceptions, so if a
                # wait timed out and the task were left running, the pipeline
                # coroutine would never finish and the test would hang
                # instead of failing on its assertions.
                await task.cancel()

        with patch(
            "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
            new_callable=AsyncMock,
            return_value=1,
        ), patch(
            "api.services.workflow.pipecat_engine.apply_disposition_mapping",
            new_callable=AsyncMock,
            return_value="completed",
        ), patch.object(
            VariableExtractionManager,
            "_perform_extraction",
            new_callable=AsyncMock,
            return_value={},
        ):
            runner = PipelineRunner()
            await asyncio.gather(
                runner.run(task),
                run_test(),
                return_exceptions=True,
            )

        return observer

    @pytest.mark.asyncio
    async def test_muted_until_first_bot_stopped_speaking(
        self, simple_workflow: WorkflowGraph
    ):
        """Test that pipeline is always muted until first BotStoppedSpeaking.

        Both allow_interrupt=True and allow_interrupt=False should be muted
        during the first bot response due to MuteUntilFirstBotCompleteUserMuteStrategy.
        This test exercises the allow_interrupt=False configuration.
        """
        set_workflow_allow_interrupt_in_start_node(
            simple_workflow, allow_interrupt=False
        )

        async def drive(task, observer):
            # Wait for first bot started
            await self._wait_for(observer.first_bot_started)

            # Queue a user turn while the first bot response is in progress
            await queue_user_speaking_and_transcript_frames(task)

            # Wait for first bot stopped
            await self._wait_for(observer.first_bot_stopped)

        observer = await self._run_scenario(
            simple_workflow, "How can I help?", drive
        )

        # VERIFY: Muted at first BotStartedSpeaking
        assert len(observer.mute_status_on_bot_started) >= 1
        assert observer.mute_status_on_bot_started[0] is True, (
            "Pipeline should be muted at first BotStartedSpeaking"
        )

        # VERIFY: Unmuted at first BotStoppedSpeaking
        assert len(observer.mute_status_on_bot_stopped) >= 1
        assert observer.mute_status_on_bot_stopped[0] is False, (
            "Pipeline should be unmuted at first BotStoppedSpeaking"
        )

    @pytest.mark.asyncio
    async def test_allow_interrupt_true_not_muted_after_second_bot_started(
        self, simple_workflow: WorkflowGraph
    ):
        """Test that when allow_interrupt=True, pipeline is NOT muted after second BotStartedSpeaking.

        After first bot response completes:
        - User speaks and triggers second LLM response
        - When second BotStartedSpeaking arrives, user should NOT be muted
          because allow_interrupt=True allows interruption
        """
        set_workflow_allow_interrupt_in_start_node(
            simple_workflow, allow_interrupt=True
        )

        observer = await self._run_scenario(
            simple_workflow, "I can help with that.", self._drive_two_turns
        )

        # VERIFY: First bot started - should be muted (MuteUntilFirstBotComplete)
        assert len(observer.mute_status_on_bot_started) >= 2
        assert observer.mute_status_on_bot_started[0] is True, (
            "Pipeline should be muted at first BotStartedSpeaking"
        )

        # VERIFY: Second bot started - should NOT be muted (allow_interrupt=True)
        assert observer.mute_status_on_bot_started[1] is False, (
            "Pipeline should NOT be muted at second BotStartedSpeaking when allow_interrupt=True"
        )

    @pytest.mark.asyncio
    async def test_allow_interrupt_false_muted_during_second_bot_speech(
        self, simple_workflow: WorkflowGraph
    ):
        """Test that when allow_interrupt=False, pipeline IS muted during second bot speech.

        After first bot response completes:
        - User speaks and triggers second LLM response
        - When second BotStartedSpeaking arrives, user SHOULD be muted
          because allow_interrupt=False prevents interruption
        - When second BotStoppedSpeaking arrives, user should be unmuted
        """
        set_workflow_allow_interrupt_in_start_node(
            simple_workflow, allow_interrupt=False
        )

        observer = await self._run_scenario(
            simple_workflow, "I can help with that.", self._drive_two_turns
        )

        # VERIFY: First bot started - should be muted (MuteUntilFirstBotComplete)
        assert len(observer.mute_status_on_bot_started) >= 2
        assert observer.mute_status_on_bot_started[0] is True, (
            "Pipeline should be muted at first BotStartedSpeaking"
        )

        # VERIFY: Second bot started - SHOULD be muted (allow_interrupt=False)
        assert observer.mute_status_on_bot_started[1] is True, (
            "Pipeline should be muted at second BotStartedSpeaking when allow_interrupt=False"
        )

        # VERIFY: Second bot stopped - should be unmuted
        assert len(observer.mute_status_on_bot_stopped) >= 2
        assert observer.mute_status_on_bot_stopped[1] is False, (
            "Pipeline should be unmuted at second BotStoppedSpeaking"
        )
|
||||
Loading…
Add table
Add a link
Reference in a new issue