mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* chore: bump pipecat version and fix tests * chore: add github workflow to run tests * fix: install reqirements.dev.txt in test script * fix: fix api-test action * feat: add integration test * test: add integration tests * test: add test for function call mute strategy
280 lines
11 KiB
Python
280 lines
11 KiB
Python
"""Tests verifying user is muted while a transition function is executing.
|
|
|
|
When the LLM calls a transition function (registered via
|
|
``_register_transition_function_with_llm``), pipecat broadcasts a
|
|
``FunctionCallsStartedFrame`` that ``FunctionCallUserMuteStrategy`` uses to
|
|
mute the user until a ``FunctionCallResultFrame`` arrives. These tests assert
|
|
that mute behavior holds end-to-end through the engine's transition flow,
|
|
so that user audio doesn't race the node switch / extraction / context update
|
|
that runs inside the transition function.
|
|
"""
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from pipecat.frames.frames import LLMContextFrame
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.pipeline.runner import PipelineRunner
|
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
|
from pipecat.processors.aggregators.llm_response_universal import (
|
|
LLMAssistantAggregatorParams,
|
|
LLMContextAggregatorPair,
|
|
LLMUserAggregatorParams,
|
|
)
|
|
from pipecat.tests.mock_transport import MockTransport
|
|
from pipecat.transports.base_transport import TransportParams
|
|
from pipecat.turns.user_mute import (
|
|
CallbackUserMuteStrategy,
|
|
FunctionCallUserMuteStrategy,
|
|
MuteUntilFirstBotCompleteUserMuteStrategy,
|
|
)
|
|
|
|
from api.services.workflow.pipecat_engine import PipecatEngine
|
|
from api.services.workflow.pipecat_engine_variable_extractor import (
|
|
VariableExtractionManager,
|
|
)
|
|
from api.services.workflow.workflow import WorkflowGraph
|
|
from pipecat.tests import MockLLMService, MockTTSService
|
|
|
|
|
|
async def _build_engine_and_pipeline(
    workflow: WorkflowGraph,
    mock_llm: MockLLMService,
):
    """Assemble a ``PipecatEngine`` and a mock-backed pipeline for mute tests.

    The pipeline is ``transport in -> user aggregator -> LLM -> TTS ->
    transport out -> assistant aggregator``. The user aggregator is given the
    mute-strategy stack that is intended to match the non-realtime wiring in
    run_pipeline.py.

    Args:
        workflow: Workflow graph handed to the engine.
        mock_llm: Mock LLM service used both by the engine and as the
            pipeline's LLM processor.

    Returns:
        The tuple ``(engine, task, function_call_mute_strategy,
        user_context_aggregator)``.
    """
    mock_tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)

    transport_params = TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_in_sample_rate=16000,
        audio_out_sample_rate=16000,
    )
    mock_transport = MockTransport(params=transport_params)

    llm_context = LLMContext()
    engine = PipecatEngine(
        llm=mock_llm,
        context=llm_context,
        workflow=workflow,
        call_context_vars={"customer_name": "Test User"},
        workflow_run_id=1,
    )

    # Kept as a named object (and returned) so tests can inspect which
    # function calls the strategy currently considers in flight.
    fc_mute_strategy = FunctionCallUserMuteStrategy()

    # Strategy order follows the stack the original author describes as the
    # run_pipeline.py non-realtime configuration.
    aggregator_pair = LLMContextAggregatorPair(
        llm_context,
        user_params=LLMUserAggregatorParams(
            user_mute_strategies=[
                MuteUntilFirstBotCompleteUserMuteStrategy(),
                fc_mute_strategy,
                CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
            ]
        ),
        assistant_params=LLMAssistantAggregatorParams(),
    )
    user_side = aggregator_pair.user()
    assistant_side = aggregator_pair.assistant()

    processors = [
        mock_transport.input(),
        user_side,
        mock_llm,
        mock_tts,
        mock_transport.output(),
        assistant_side,
    ]
    task = PipelineTask(Pipeline(processors), params=PipelineParams(), enable_rtvi=False)
    engine.set_task(task)

    return engine, task, fc_mute_strategy, user_side
|
|
|
|
|
|
class TestTransitionFunctionMutesUser:
    """Verify the user is muted while transition functions execute."""

    @staticmethod
    def _end_call_llm() -> MockLLMService:
        """Return a mock LLM scripted with one step: a call to ``end_call``.

        The tool call id is ``call_end_1``; tests assert against that id.
        """
        end_call_chunks = MockLLMService.create_function_call_chunks(
            function_name="end_call",
            arguments={},
            tool_call_id="call_end_1",
        )
        return MockLLMService(mock_steps=[end_call_chunks], chunk_delay=0.001)

    @staticmethod
    async def _run_engine_to_completion(engine, task) -> None:
        """Run ``task`` while driving ``engine`` from its start node.

        The organization lookup, disposition mapping and variable extraction
        are replaced with ``AsyncMock`` objects for the duration of the run.
        The whole run is bounded by a 10 second ``asyncio.wait_for`` timeout.

        Args:
            engine: Engine built by ``_build_engine_and_pipeline``.
            task: Pipeline task built alongside ``engine``.

        Raises:
            asyncio.TimeoutError: If the run does not finish within 10 seconds.
        """
        # NOTE(review): the "module:attr" colon form of this target is kept
        # verbatim; it presumably relies on patch resolving targets through
        # pkgutil.resolve_name -- confirm against the supported Python versions.
        org_id_patch = patch(
            "api.db:db_client.get_organization_id_by_workflow_run_id",
            new_callable=AsyncMock,
            return_value=1,
        )
        disposition_patch = patch(
            "api.services.workflow.pipecat_engine.apply_disposition_mapping",
            new_callable=AsyncMock,
            return_value="completed",
        )
        extraction_patch = patch.object(
            VariableExtractionManager,
            "_perform_extraction",
            new_callable=AsyncMock,
            return_value={"user_intent": "end call"},
        )

        with org_id_patch, disposition_patch, extraction_patch:
            runner = PipelineRunner()

            async def run_pipeline():
                await runner.run(task)

            async def initialize_engine():
                # Short delay before touching the engine; presumably gives the
                # runner time to start the pipeline first.
                await asyncio.sleep(0.01)
                await engine.initialize()
                await engine.set_node(engine.workflow.start_node_id)
                await engine.llm.queue_frame(LLMContextFrame(engine.context))

            await asyncio.wait_for(
                asyncio.gather(run_pipeline(), initialize_engine()),
                timeout=10.0,
            )

    @pytest.mark.asyncio
    async def test_user_is_muted_during_transition_function(
        self, simple_workflow: WorkflowGraph
    ):
        """The user must be muted from the moment a transition function starts
        until its result is delivered.

        Scenario:
        1. LLM calls the ``end_call`` transition function (start → end edge).
        2. Wrap the registered handler so we can read mute state from inside it.
        3. VERIFY: the function-call mute strategy has the call in flight.
        4. VERIFY: the user aggregator's ``_user_is_muted`` flag is True.
        """
        llm = self._end_call_llm()

        (
            engine,
            task,
            function_call_mute_strategy,
            user_context_aggregator,
        ) = await _build_engine_and_pipeline(simple_workflow, llm)

        captured_states: list[dict] = []

        # Wrap register_function so we can introspect mute state from inside
        # the transition handler. We must wrap *after* the engine is created
        # but *before* set_node registers the transition functions.
        original_register_function = llm.register_function

        def wrapping_register_function(name, func, *args, **kwargs):
            async def wrapped(function_call_params):
                # Yield once so the user aggregator has a chance to drain
                # the broadcasted FunctionCallsStartedFrame and update its
                # mute state before we sample it.
                await asyncio.sleep(0.02)
                in_flight = function_call_mute_strategy._function_call_in_progress
                captured_states.append(
                    {
                        "name": name,
                        "function_call_in_progress": bool(in_flight),
                        "user_is_muted": user_context_aggregator._user_is_muted,
                        # Copy, so later mutation of the strategy's
                        # collection cannot change the recorded snapshot.
                        "tool_call_ids": set(in_flight),
                    }
                )
                return await func(function_call_params)

            return original_register_function(name, wrapped, *args, **kwargs)

        llm.register_function = wrapping_register_function

        await self._run_engine_to_completion(engine, task)

        assert len(captured_states) == 1, (
            "Expected the transition function to be invoked exactly once, "
            f"got {len(captured_states)}: {captured_states}"
        )
        state = captured_states[0]
        assert state["name"] == "end_call"
        assert state["function_call_in_progress"], (
            "FunctionCallUserMuteStrategy should have the transition call in "
            f"progress while the handler runs (state={state})"
        )
        assert "call_end_1" in state["tool_call_ids"], (
            f"Expected tool_call_id 'call_end_1' to be tracked, got {state['tool_call_ids']}"
        )
        assert state["user_is_muted"], (
            "User aggregator's _user_is_muted should be True during the "
            f"transition function (state={state})"
        )

    @pytest.mark.asyncio
    async def test_user_is_unmuted_after_transition_function_returns(
        self, simple_workflow: WorkflowGraph
    ):
        """After the transition function's result is delivered, the function-call
        mute strategy should clear its in-progress set. Other strategies in the
        stack (CallbackUserMuteStrategy via engine.should_mute_user) may still
        keep the pipeline muted because end_call_with_reason fires when the
        engine reaches the End node, but the function-call strategy itself
        must release its hold.
        """
        llm = self._end_call_llm()

        # The user aggregator is not inspected in this test.
        engine, task, function_call_mute_strategy, _ = await _build_engine_and_pipeline(
            simple_workflow, llm
        )

        await self._run_engine_to_completion(engine, task)

        assert function_call_mute_strategy._function_call_in_progress == set(), (
            "FunctionCallUserMuteStrategy should have cleared its in-progress "
            "set after the transition function's result was delivered, got "
            f"{function_call_mute_strategy._function_call_in_progress}"
        )
|