dograh/api/tests/test_pipecat_engine_transition_mute.py
Abhishek d4b6afb020
feat: add logs in campaigns for failure or pausing (#265)
* feat: add logs in campaigns on failure

* chore: bump pipecat

* chore: update format.sh

* chore: fix github workflow

* fix: fix formatting errors
2026-05-05 19:23:50 +05:30

280 lines
11 KiB
Python

"""Tests verifying user is muted while a transition function is executing.
When the LLM calls a transition function (registered via
``_register_transition_function_with_llm``), pipecat broadcasts a
``FunctionCallsStartedFrame`` that ``FunctionCallUserMuteStrategy`` uses to
mute the user until a ``FunctionCallResultFrame`` arrives. These tests assert
that mute behavior holds end-to-end through the engine's transition flow,
so that user audio doesn't race the node switch / extraction / context update
that runs inside the transition function.
"""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
CallbackUserMuteStrategy,
FunctionCallUserMuteStrategy,
MuteUntilFirstBotCompleteUserMuteStrategy,
)
async def _build_engine_and_pipeline(
workflow: WorkflowGraph,
mock_llm: MockLLMService,
):
"""Set up engine + pipeline mirroring the non-realtime production wiring.
Returns (engine, task, function_call_mute_strategy, user_context_aggregator).
"""
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
transport = MockTransport(
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_sample_rate=16000,
audio_out_sample_rate=16000,
),
)
context = LLMContext()
engine = PipecatEngine(
llm=mock_llm,
context=context,
workflow=workflow,
call_context_vars={"customer_name": "Test User"},
workflow_run_id=1,
)
# Hold a reference so the test can introspect the in-progress set.
function_call_mute_strategy = FunctionCallUserMuteStrategy()
# Match run_pipeline.py's non-realtime mute-strategy stack so the test
# exercises the same wiring that would be active in a real call.
user_mute_strategies = [
MuteUntilFirstBotCompleteUserMuteStrategy(),
function_call_mute_strategy,
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
]
user_params = LLMUserAggregatorParams(user_mute_strategies=user_mute_strategies)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params, user_params=user_params
)
user_context_aggregator = context_aggregator.user()
assistant_context_aggregator = context_aggregator.assistant()
pipeline = Pipeline(
[
transport.input(),
user_context_aggregator,
mock_llm,
tts,
transport.output(),
assistant_context_aggregator,
]
)
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
engine.set_task(task)
return engine, task, function_call_mute_strategy, user_context_aggregator
class TestTransitionFunctionMutesUser:
"""Verify the user is muted while transition functions execute."""
@pytest.mark.asyncio
async def test_user_is_muted_during_transition_function(
self, simple_workflow: WorkflowGraph
):
"""The user must be muted from the moment a transition function starts
until its result is delivered.
Scenario:
1. LLM calls the ``end_call`` transition function (start → end edge).
2. Wrap the registered handler so we can read mute state from inside it.
3. VERIFY: the function-call mute strategy has the call in flight.
4. VERIFY: the user aggregator's ``_user_is_muted`` flag is True.
"""
step_0_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_end_1",
)
llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001)
(
engine,
task,
function_call_mute_strategy,
user_context_aggregator,
) = await _build_engine_and_pipeline(simple_workflow, llm)
captured_states: list[dict] = []
# Wrap register_function so we can introspect mute state from inside
# the transition handler. We must wrap *after* the engine is created
# but *before* set_node registers the transition functions.
original_register_function = llm.register_function
def wrapping_register_function(name, func, *args, **kwargs):
async def wrapped(function_call_params):
# Yield once so the user aggregator has a chance to drain
# the broadcasted FunctionCallsStartedFrame and update its
# mute state before we sample it.
await asyncio.sleep(0.02)
captured_states.append(
{
"name": name,
"function_call_in_progress": bool(
function_call_mute_strategy._function_call_in_progress
),
"user_is_muted": user_context_aggregator._user_is_muted,
"tool_call_ids": set(
function_call_mute_strategy._function_call_in_progress
),
}
)
return await func(function_call_params)
return original_register_function(name, wrapped, *args, **kwargs)
llm.register_function = wrapping_register_function
with patch(
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
new_callable=AsyncMock,
return_value="completed",
):
with patch.object(
VariableExtractionManager,
"_perform_extraction",
new_callable=AsyncMock,
return_value={"user_intent": "end call"},
):
runner = PipelineRunner()
async def run_pipeline():
await runner.run(task)
async def initialize_engine():
await asyncio.sleep(0.01)
await engine.initialize()
await engine.set_node(engine.workflow.start_node_id)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
await asyncio.wait_for(
asyncio.gather(run_pipeline(), initialize_engine()),
timeout=10.0,
)
assert len(captured_states) == 1, (
f"Expected the transition function to be invoked exactly once, "
f"got {len(captured_states)}: {captured_states}"
)
state = captured_states[0]
assert state["name"] == "end_call"
assert state["function_call_in_progress"], (
"FunctionCallUserMuteStrategy should have the transition call in "
f"progress while the handler runs (state={state})"
)
assert "call_end_1" in state["tool_call_ids"], (
f"Expected tool_call_id 'call_end_1' to be tracked, got {state['tool_call_ids']}"
)
assert state["user_is_muted"], (
"User aggregator's _user_is_muted should be True during the "
f"transition function (state={state})"
)
@pytest.mark.asyncio
async def test_user_is_unmuted_after_transition_function_returns(
self, simple_workflow: WorkflowGraph
):
"""After the transition function's result is delivered, the function-call
mute strategy should clear its in-progress set. Other strategies in the
stack (CallbackUserMuteStrategy via engine.should_mute_user) may still
keep the pipeline muted because end_call_with_reason fires when the
engine reaches the End node, but the function-call strategy itself
must release its hold.
"""
step_0_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_end_1",
)
llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001)
(
engine,
task,
function_call_mute_strategy,
_user_context_aggregator,
) = await _build_engine_and_pipeline(simple_workflow, llm)
with patch(
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
new_callable=AsyncMock,
return_value="completed",
):
with patch.object(
VariableExtractionManager,
"_perform_extraction",
new_callable=AsyncMock,
return_value={"user_intent": "end call"},
):
runner = PipelineRunner()
async def run_pipeline():
await runner.run(task)
async def initialize_engine():
await asyncio.sleep(0.01)
await engine.initialize()
await engine.set_node(engine.workflow.start_node_id)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
await asyncio.wait_for(
asyncio.gather(run_pipeline(), initialize_engine()),
timeout=10.0,
)
assert function_call_mute_strategy._function_call_in_progress == set(), (
"FunctionCallUserMuteStrategy should have cleared its in-progress "
"set after the transition function's result was delivered, got "
f"{function_call_mute_strategy._function_call_in_progress}"
)