mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
chore: refactor and add tests (#130)
* chore: add tests for end call * Update pipecat module * fix: allow interruptions from deepgram flux * Add VadUserTurnStrategy * chore: add test for voicemail detection
This commit is contained in:
parent
2aedb839ff
commit
033fde8946
15 changed files with 2106 additions and 542 deletions
|
|
@ -6,27 +6,20 @@ import pytest
|
|||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
START_CALL_SYSTEM_PROMPT = "start_call_system_prompt"
|
||||
END_CALL_SYSTEM_PROMPT = "end_call_system_prompt"
|
||||
START_CALL_SYSTEM_PROMPT = "Start Call System Prompt"
|
||||
AGENT_SYSTEM_PROMPT = "Agent Node System Prompt"
|
||||
END_CALL_SYSTEM_PROMPT = "End Call System Prompt"
|
||||
|
||||
# Default workflow definition for mocking database WorkflowModel
|
||||
DEFAULT_WORKFLOW_DEFINITION = {
|
||||
|
|
@ -110,57 +103,6 @@ class MockUserConfig:
|
|||
embeddings: Optional[Any] = None
|
||||
|
||||
|
||||
class MockTransportProcessor(FrameProcessor):
    """Test double that turns TTS frames into bot-speaking notifications.

    For each recognised TTS frame a matching bot-speaking frame is pushed
    twice: once with ``push_frame``'s default direction and once with
    ``FrameDirection.UPSTREAM``. The TTS frame itself is always forwarded
    afterwards in the direction it arrived with.

    Frame mapping:
        - TTSStartedFrame  -> BotStartedSpeakingFrame
        - TTSAudioRawFrame -> BotSpeakingFrame (only if ``emit_bot_speaking``)
        - TTSStoppedFrame  -> BotStoppedSpeakingFrame

    Args:
        emit_bot_speaking: When True (the default), a BotSpeakingFrame is
            emitted for every TTSAudioRawFrame; this is what user idle
            tracking needs in order to start conversation tracking.
        **kwargs: Passed through unchanged to ``FrameProcessor.__init__``.
    """

    def __init__(self, *, emit_bot_speaking: bool = True, **kwargs):
        super().__init__(**kwargs)
        self._emit_bot_speaking = emit_bot_speaking

    async def _push_bot_frame_both_ways(self, frame_type):
        """Push one new ``frame_type`` instance per direction (default, then upstream)."""
        await self.push_frame(frame_type())
        await self.push_frame(frame_type(), direction=FrameDirection.UPSTREAM)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Emit the bot-speaking frame matching ``frame`` (if any), then forward ``frame``."""
        await super().process_frame(frame, direction)

        if isinstance(frame, TTSStartedFrame):
            # The bot has started speaking.
            await self._push_bot_frame_both_ways(BotStartedSpeakingFrame)
        elif isinstance(frame, TTSAudioRawFrame):
            # Audio chunks only produce BotSpeakingFrame when explicitly enabled.
            if self._emit_bot_speaking:
                await self._push_bot_frame_both_ways(BotSpeakingFrame)
        elif isinstance(frame, TTSStoppedFrame):
            # The bot has stopped speaking.
            await self._push_bot_frame_both_ways(BotStoppedSpeakingFrame)

        # The triggering frame always continues on its original path.
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
|
@ -299,14 +241,14 @@ def simple_workflow() -> WorkflowGraph:
|
|||
"""Create a simple two-node workflow for testing.
|
||||
|
||||
The workflow has:
|
||||
- Start node with a prompt
|
||||
- Start node with extraction enabled (extracts user_intent)
|
||||
- End node with a prompt
|
||||
- One edge connecting them with label "End Call"
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="1",
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
|
|
@ -315,10 +257,19 @@ def simple_workflow() -> WorkflowGraph:
|
|||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract user information from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_intent",
|
||||
type=VariableType.string,
|
||||
prompt="The user's intent or reason for calling",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="2",
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
|
|
@ -327,14 +278,15 @@ def simple_workflow() -> WorkflowGraph:
|
|||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="1-2",
|
||||
source="1",
|
||||
target="2",
|
||||
id="start-end",
|
||||
source="start",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When the user says to end the call, end the call",
|
||||
|
|
@ -350,37 +302,59 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
"""Create a three-node workflow for testing with an intermediate agent node.
|
||||
|
||||
The workflow has:
|
||||
- Start node
|
||||
- Agent node (for collecting information)
|
||||
- End node
|
||||
- Start node with extraction enabled (extracts greeting_type)
|
||||
- Agent node with extraction enabled (extracts user_name)
|
||||
- End node (no extraction)
|
||||
|
||||
Edges:
|
||||
- Start -> Agent (label: "Collect Info")
|
||||
- Agent -> End (label: "End Call")
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="1",
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract greeting information from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="greeting_type",
|
||||
type=VariableType.string,
|
||||
prompt="The type of greeting used",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="2",
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt="Help the user with their request. Ask clarifying questions if needed.",
|
||||
allow_interrupt=True,
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract user details from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_name",
|
||||
type=VariableType.string,
|
||||
prompt="The user's name",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="3",
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
|
|
@ -389,26 +363,187 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="1-2",
|
||||
source="1",
|
||||
target="2",
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When the user wants help, collect their information",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="2-3",
|
||||
source="2",
|
||||
target="3",
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When the user is done or wants to end the call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
||||
|
||||
@pytest.fixture
def three_node_workflow_extraction_start_only() -> WorkflowGraph:
    """Build a start -> agent -> end workflow where only the start node extracts.

    Intended for tests that check which node variable extraction is triggered
    for during a transition: the agent node has extraction switched off so a
    test can tell extraction for the SOURCE node apart from extraction for the
    TARGET node.

    Nodes:
        - ``start``: extraction enabled, extracts ``user_name``
        - ``agent``: extraction disabled
        - ``end``: extraction disabled

    Edges:
        - ``start`` -> ``agent`` (label "Collect Info")
        - ``agent`` -> ``end`` (label "End Call")
    """
    start_node = RFNodeDTO(
        id="start",
        type=NodeType.startNode,
        position=Position(x=0, y=0),
        data=NodeDataDTO(
            name="Start Call",
            prompt=START_CALL_SYSTEM_PROMPT,
            is_start=True,
            allow_interrupt=False,
            add_global_prompt=False,
            extraction_enabled=True,
            extraction_prompt="Extract the user's name from the conversation.",
            extraction_variables=[
                ExtractionVariableDTO(
                    name="user_name",
                    type=VariableType.string,
                    prompt="The name the user provided",
                ),
            ],
        ),
    )

    agent_node = RFNodeDTO(
        id="agent",
        type=NodeType.agentNode,
        position=Position(x=0, y=200),
        data=NodeDataDTO(
            name="Collect Info",
            prompt=AGENT_SYSTEM_PROMPT,
            allow_interrupt=False,
            add_global_prompt=False,
            # Deliberately off: this is the node the tests contrast against.
            extraction_enabled=False,
        ),
    )

    end_node = RFNodeDTO(
        id="end",
        type=NodeType.endNode,
        position=Position(x=0, y=400),
        data=NodeDataDTO(
            name="End Call",
            prompt=END_CALL_SYSTEM_PROMPT,
            is_end=True,
            allow_interrupt=False,
            add_global_prompt=False,
            extraction_enabled=False,
        ),
    )

    start_to_agent = RFEdgeDTO(
        id="start-agent",
        source="start",
        target="agent",
        data=EdgeDataDTO(
            label="Collect Info",
            condition="When user has been greeted, proceed to collect information",
        ),
    )

    agent_to_end = RFEdgeDTO(
        id="agent-end",
        source="agent",
        target="end",
        data=EdgeDataDTO(
            label="End Call",
            condition="When information collection is complete, end the call",
        ),
    )

    flow = ReactFlowDTO(
        nodes=[start_node, agent_node, end_node],
        edges=[start_to_agent, agent_to_end],
    )
    return WorkflowGraph(flow)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
||||
"""Create a three-node workflow without variable extraction
|
||||
|
||||
The workflow has:
|
||||
- Start node with extraction DISABLED
|
||||
- Agent node with extraction DISABLED
|
||||
- End node (no extraction)
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
|
|
|
|||
|
|
@ -19,18 +19,13 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from api.tests.conftest import (
|
||||
AGENT_SYSTEM_PROMPT,
|
||||
END_CALL_SYSTEM_PROMPT,
|
||||
START_CALL_SYSTEM_PROMPT,
|
||||
)
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -41,86 +36,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
# Define prompts for test nodes
|
||||
START_NODE_PROMPT = "Start Node System Prompt"
|
||||
AGENT_NODE_PROMPT = "Agent Node System Prompt"
|
||||
END_NODE_PROMPT = "End Node System Prompt"
|
||||
|
||||
|
||||
@pytest.fixture
def three_node_workflow_for_context_test() -> WorkflowGraph:
    """Build a start -> agent -> end workflow for context-update tests.

    Each node carries its own distinct system prompt constant so a test can
    tell which node's prompt is present in the LLM context at a given step.

    Nodes:
        - ``start``: prompt ``START_NODE_PROMPT``
        - ``agent``: prompt ``AGENT_NODE_PROMPT``
        - ``end``: prompt ``END_NODE_PROMPT``

    Edges:
        - ``start`` -> ``agent`` (label "Collect Info")
        - ``agent`` -> ``end`` (label "End Call")
    """
    start_node = RFNodeDTO(
        id="start",
        type=NodeType.startNode,
        position=Position(x=0, y=0),
        data=NodeDataDTO(
            name="Start Call",
            prompt=START_NODE_PROMPT,
            is_start=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )

    agent_node = RFNodeDTO(
        id="agent",
        type=NodeType.agentNode,
        position=Position(x=0, y=200),
        data=NodeDataDTO(
            name="Collect Info",
            prompt=AGENT_NODE_PROMPT,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )

    end_node = RFNodeDTO(
        id="end",
        type=NodeType.endNode,
        position=Position(x=0, y=400),
        data=NodeDataDTO(
            name="End Call",
            prompt=END_NODE_PROMPT,
            is_end=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )

    transitions = [
        RFEdgeDTO(
            id="start-agent",
            source="start",
            target="agent",
            data=EdgeDataDTO(
                label="Collect Info",
                condition="When user has been greeted, proceed to collect information",
            ),
        ),
        RFEdgeDTO(
            id="agent-end",
            source="agent",
            target="end",
            data=EdgeDataDTO(
                label="End Call",
                condition="When information collection is complete, end the call",
            ),
        ),
    ]

    flow = ReactFlowDTO(
        nodes=[start_node, agent_node, end_node],
        edges=transitions,
    )
    return WorkflowGraph(flow)
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
class ContextCapturingMockLLM(MockLLMService):
|
||||
|
|
@ -215,7 +131,8 @@ async def run_pipeline_and_capture_context(
|
|||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -251,15 +168,13 @@ async def run_pipeline_and_capture_context(
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
|
||||
)
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -294,7 +209,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_transition_updates_context_before_next_completion(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that a single transition function call updates context before next LLM generation.
|
||||
|
||||
|
|
@ -329,7 +244,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
set_node_delay=0.05, # Introduce 50ms delay in set_node
|
||||
)
|
||||
|
|
@ -341,17 +256,17 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
# Verify step 0 (start node) had start node's system prompt
|
||||
step_0_prompt = llm.get_system_prompt_at_step(0)
|
||||
assert START_NODE_PROMPT in step_0_prompt, (
|
||||
assert START_CALL_SYSTEM_PROMPT in step_0_prompt, (
|
||||
f"Step 0 should have start node prompt, got: {step_0_prompt[:100]}"
|
||||
)
|
||||
|
||||
# Verify step 1 (agent node) had:
|
||||
# 1. The agent node's system prompt (not start node's)
|
||||
step_1_prompt = llm.get_system_prompt_at_step(1)
|
||||
assert AGENT_NODE_PROMPT in step_1_prompt, (
|
||||
assert AGENT_SYSTEM_PROMPT in step_1_prompt, (
|
||||
f"Step 1 should have agent node prompt, got: {step_1_prompt[:100]}"
|
||||
)
|
||||
assert START_NODE_PROMPT not in step_1_prompt, (
|
||||
assert START_CALL_SYSTEM_PROMPT not in step_1_prompt, (
|
||||
"Step 1 should NOT have start node prompt anymore"
|
||||
)
|
||||
|
||||
|
|
@ -371,7 +286,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_transitions_maintain_correct_context(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that sequential transitions maintain correct context at each step.
|
||||
|
||||
|
|
@ -403,7 +318,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
set_node_delay=0.05, # Introduce 50ms delay in set_node
|
||||
)
|
||||
|
|
@ -414,13 +329,13 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
)
|
||||
|
||||
# Step 0: Start node - should have start prompt
|
||||
assert START_NODE_PROMPT in llm.get_system_prompt_at_step(0)
|
||||
assert START_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(0)
|
||||
|
||||
# Step 1: Agent node - should have agent prompt
|
||||
assert AGENT_NODE_PROMPT in llm.get_system_prompt_at_step(1)
|
||||
assert AGENT_SYSTEM_PROMPT in llm.get_system_prompt_at_step(1)
|
||||
|
||||
# Step 2: End node - should have end prompt
|
||||
assert END_NODE_PROMPT in llm.get_system_prompt_at_step(2)
|
||||
assert END_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(2)
|
||||
|
||||
# Verify each subsequent step has the previous tool results
|
||||
step_1_ctx = llm.get_context_at_step(1)
|
||||
|
|
@ -445,7 +360,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_messages_preserve_conversation_history(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that conversation history is preserved across node transitions.
|
||||
|
||||
|
|
@ -474,7 +389,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
)
|
||||
|
||||
|
|
|
|||
1097
api/tests/test_pipecat_engine_end_call.py
Normal file
1097
api/tests/test_pipecat_engine_end_call.py
Normal file
File diff suppressed because it is too large
Load diff
393
api/tests/test_pipecat_engine_node_switch_with_user_speech.py
Normal file
393
api/tests/test_pipecat_engine_node_switch_with_user_speech.py
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
"""Tests for verifying behavior when node switch and user speech happen simultaneously.
|
||||
|
||||
This module tests the interaction between node transitions and user speaking events
|
||||
in the PipecatEngine. The key scenario being tested:
|
||||
|
||||
1. LLM calls a transition function to move from one node to another
|
||||
2. At the same time, user starts and stops speaking (triggered by FunctionCallResultFrame)
|
||||
3. The pipeline should handle both events correctly
|
||||
|
||||
The tests use a custom input transport that injects UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame when triggered by a FunctionCallResultFrame observer.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
LLMContextFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.turns.user_mute import (
|
||||
CallbackUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
)
|
||||
from pipecat.turns.user_start import (
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import (
|
||||
ExternalUserTurnStopStrategy,
|
||||
)
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class UserSpeechInjectingInputTransport(FrameProcessor):
    """Mock input transport that fakes two user turns after a function call result.

    The first ``FunctionCallResultFrame`` that reaches ``process_frame``
    triggers two simulated user turns. Each turn is a
    ``UserStartedSpeakingFrame``, a ``TranscriptionFrame`` and a
    ``UserStoppedSpeakingFrame`` pushed with ``push_frame``'s default
    direction. Later ``FunctionCallResultFrame``s are counted but trigger
    nothing. Optionally, silent ``InputAudioRawFrame``s are produced at a
    fixed interval between ``StartFrame`` and ``EndFrame``/``CancelFrame``.

    Args:
        params: Transport parameters; a default ``TransportParams()`` is used
            when omitted.
        generate_audio: Whether to emit silent audio frames while running.
        audio_interval_ms: Milliseconds of audio per generated frame, also the
            pause between generated frames.
        sample_rate: Sample rate of the generated audio frames.
        num_channels: Channel count of the generated audio frames.
        user_speech_initial_delay: Seconds to wait before the first simulated
            user turn begins.
        **kwargs: Passed through unchanged to ``FrameProcessor.__init__``.
    """

    def __init__(
        self,
        params: Optional[TransportParams] = None,
        *,
        generate_audio: bool = False,
        audio_interval_ms: int = 20,
        sample_rate: int = 16000,
        num_channels: int = 1,
        user_speech_initial_delay: float = 0.01,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._params = params or TransportParams()

        # Audio generation settings.
        self._generate_audio = generate_audio
        self._audio_interval_ms = audio_interval_ms
        self._sample_rate = sample_rate
        self._num_channels = num_channels

        # Speech injection settings and bookkeeping.
        self._user_speech_initial_delay = user_speech_initial_delay
        self._function_call_result_count = 0

        # Background audio task state.
        self._audio_task: Optional[asyncio.Task] = None
        self._running = False

    async def _generate_audio_frames(self):
        """Push zero-filled audio frames until stopped or cancelled."""
        frame_samples = int(self._sample_rate * self._audio_interval_ms / 1000)
        # Two bytes per sample per channel (presumably 16-bit PCM).
        silence = bytes(frame_samples * self._num_channels * 2)

        while self._running:
            try:
                await self.push_frame(
                    InputAudioRawFrame(
                        audio=silence,
                        sample_rate=self._sample_rate,
                        num_channels=self._num_channels,
                    )
                )
                await asyncio.sleep(self._audio_interval_ms / 1000)
            except asyncio.CancelledError:
                # Cancellation from _stop_tasks simply ends the loop.
                break

    def _start_tasks(self):
        """Mark the transport as running and start audio generation if enabled."""
        if self._running:
            return

        self._running = True
        if self._generate_audio:
            self._audio_task = asyncio.create_task(self._generate_audio_frames())

    def _stop_tasks(self):
        """Mark the transport as stopped and cancel the audio task if still active."""
        self._running = False
        pending = self._audio_task
        if pending and not pending.done():
            pending.cancel()
        self._audio_task = None

    async def _inject_user_turn(self, delay: float, transcript: str):
        """Wait ``delay`` seconds, then push started / transcription / stopped frames.

        A zero-length sleep separates the three frames so other tasks get a
        chance to run between them.
        """
        await asyncio.sleep(delay)
        await self.push_frame(UserStartedSpeakingFrame())

        await asyncio.sleep(0)
        await self.push_frame(TranscriptionFrame(transcript, "abc", time_now_iso8601()))

        await asyncio.sleep(0)
        await self.push_frame(UserStoppedSpeakingFrame())

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """React to lifecycle and function-call-result frames, then forward ``frame``."""
        await super().process_frame(frame, direction)

        if isinstance(frame, StartFrame):
            self._start_tasks()
        elif isinstance(frame, (EndFrame, CancelFrame)):
            self._stop_tasks()
        elif isinstance(frame, FunctionCallResultFrame):
            self._function_call_result_count += 1
            if self._function_call_result_count == 1:
                # First turn: timed to land close to the LLM call that the
                # function call itself produces, simulating the race.
                await self._inject_user_turn(
                    self._user_speech_initial_delay, "First User Speech"
                )
                # Second turn: generates a further LLM call shortly afterwards.
                await self._inject_user_turn(0.1, "Second User Speech")

        await self.push_frame(frame, direction)

    async def cleanup(self):
        """Stop background work before running the base-class cleanup."""
        self._stop_tasks()
        await super().cleanup()
|
||||
|
||||
|
||||
class UserSpeechInjectingTransport(BaseTransport):
    """Transport pairing a speech-injecting input with a mock output.

    The input side is a ``UserSpeechInjectingInputTransport`` and the output
    side is a ``MockOutputTransport``; both receive the same transport
    parameters.

    Args:
        params: Transport parameters; a default ``TransportParams()`` is used
            when omitted.
        input_name: Name handed to ``BaseTransport`` for the input processor.
        output_name: Name handed to ``BaseTransport`` for the output processor.
        emit_bot_speaking: Forwarded to ``MockOutputTransport``.
        generate_audio: Forwarded to the input transport.
        audio_interval_ms: Forwarded to the input transport.
        audio_sample_rate: Forwarded to the input transport as ``sample_rate``.
        audio_num_channels: Forwarded to the input transport as ``num_channels``.
        user_speech_initial_delay: Forwarded to the input transport.
    """

    def __init__(
        self,
        params: Optional[TransportParams] = None,
        *,
        input_name: Optional[str] = None,
        output_name: Optional[str] = None,
        emit_bot_speaking: bool = True,
        generate_audio: bool = False,
        audio_interval_ms: int = 20,
        audio_sample_rate: int = 16000,
        audio_num_channels: int = 1,
        user_speech_initial_delay: float = 0.01,
    ):
        super().__init__(input_name=input_name, output_name=output_name)

        shared_params = params or TransportParams()
        self._params = shared_params

        # Input is created before output, both from the same params object.
        self._input = UserSpeechInjectingInputTransport(
            shared_params,
            name=self._input_name,
            generate_audio=generate_audio,
            audio_interval_ms=audio_interval_ms,
            sample_rate=audio_sample_rate,
            num_channels=audio_num_channels,
            user_speech_initial_delay=user_speech_initial_delay,
        )
        self._output = MockOutputTransport(
            shared_params,
            emit_bot_speaking=emit_bot_speaking,
            name=self._output_name,
        )

    def input(self) -> UserSpeechInjectingInputTransport:
        """Return the speech-injecting input processor."""
        return self._input

    def output(self) -> FrameProcessor:
        """Return the mock output processor."""
        return self._output
|
||||
|
||||
|
||||
async def create_test_pipeline(
    workflow: WorkflowGraph,
    mock_llm: MockLLMService,
    generate_audio: bool = True,
    user_speech_initial_delay: float = 0.01,
) -> tuple[PipecatEngine, UserSpeechInjectingTransport, PipelineTask]:
    """Assemble a PipecatEngine and pipeline for node-switch/user-speech tests.

    The pipeline order is:
    transport input -> user aggregator -> LLM -> TTS -> transport output ->
    assistant aggregator. The transport input is the piece that injects
    simulated user turns once it sees its first FunctionCallResultFrame.

    Args:
        workflow: Workflow graph the engine runs.
        mock_llm: Mock LLM service, used both by the engine and as a pipeline
            stage.
        generate_audio: If True, the transport input emits an
            InputAudioRawFrame every 20ms to simulate real audio input.
        user_speech_initial_delay: Seconds the transport input waits after
            the first FunctionCallResultFrame before injecting
            UserStartedSpeakingFrame.

    Returns:
        Tuple of ``(engine, transport, task)``; the task has already been
        registered on the engine via ``engine.set_task``.
    """
    tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)

    # Transport whose input side injects user speech on the first
    # FunctionCallResultFrame.
    transport = UserSpeechInjectingTransport(
        generate_audio=generate_audio,
        audio_interval_ms=20,
        audio_sample_rate=16000,
        audio_num_channels=1,
        emit_bot_speaking=True,
        user_speech_initial_delay=user_speech_initial_delay,
    )

    context = LLMContext()
    engine = PipecatEngine(
        llm=mock_llm,
        context=context,
        workflow=workflow,
        call_context_vars={"customer_name": "Test User"},
        workflow_run_id=1,
    )

    # Turn and mute strategies are meant to match run_pipeline.py.
    user_params = LLMUserAggregatorParams(
        user_turn_strategies=UserTurnStrategies(
            start=[TranscriptionUserTurnStartStrategy()],
            stop=[ExternalUserTurnStopStrategy()],
        ),
        user_mute_strategies=[
            MuteUntilFirstBotCompleteUserMuteStrategy(),
            CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
        ],
    )

    aggregators = LLMContextAggregatorPair(
        context,
        assistant_params=LLMAssistantAggregatorParams(expect_stripped_words=True),
        user_params=user_params,
    )
    user_aggregator = aggregators.user()
    assistant_aggregator = aggregators.assistant()

    stages = [
        transport.input(),
        user_aggregator,
        mock_llm,
        tts,
        transport.output(),
        assistant_aggregator,
    ]
    task = PipelineTask(Pipeline(stages), params=PipelineParams(), enable_rtvi=False)

    engine.set_task(task)
    return engine, transport, task
|
||||
|
||||
|
||||
class TestNodeSwitchWithUserSpeech:
|
||||
"""Test scenarios where node switch and user speech happen simultaneously."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"user_speech_initial_delay,scenario_name",
|
||||
[
|
||||
(0.01, "delayed"),
|
||||
(0, "immediate"),
|
||||
],
|
||||
ids=["delayed_user_speech", "immediate_user_speech"],
|
||||
)
|
||||
async def test_node_switch_with_concurrent_user_speech(
|
||||
self,
|
||||
three_node_workflow_no_variable_extraction: WorkflowGraph,
|
||||
user_speech_initial_delay: float,
|
||||
scenario_name: str,
|
||||
):
|
||||
"""Test scenario: node transition happens while user is speaking.
|
||||
|
||||
This test creates the scenario where:
|
||||
1. LLM generates text and calls collect_info to transition from start to agent
|
||||
2. When FunctionCallResultFrame #1 is seen, UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame are automatically injected from the pipeline source
|
||||
3. The pipeline processes both events concurrently
|
||||
|
||||
The FunctionCallResultObserver in the pipeline detects the first function call
|
||||
result and triggers the transport to inject user speaking frames.
|
||||
|
||||
This test is parameterized with two scenarios:
|
||||
- delayed_user_speech: 10ms delay before UserStartedSpeakingFrame (user_speech_initial_delay=0.01)
|
||||
- immediate_user_speech: No delay before UserStartedSpeakingFrame (user_speech_initial_delay=0)
|
||||
|
||||
This is primarily a scenario creation test; the only assertion is on the number of LLM generations.
|
||||
"""
|
||||
# Step 0 (Start node): greet user then call collect_info to transition to agent
|
||||
step_0_chunks = MockLLMService.create_mixed_chunks(
|
||||
text="Hello!",
|
||||
function_name="collect_info",
|
||||
arguments={},
|
||||
tool_call_id="call_transition_1",
|
||||
)
|
||||
|
||||
step_1_chunks = MockLLMService.create_text_chunks(
|
||||
text="Step 1 with some longer text that should cause multiple chunks to be created."
|
||||
)
|
||||
|
||||
step_2_chunks = MockLLMService.create_function_call_chunks(
|
||||
function_name="end_call",
|
||||
arguments={},
|
||||
tool_call_id="call_transition_2",
|
||||
)
|
||||
|
||||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
|
||||
engine, _transport, task = await create_test_pipeline(
|
||||
three_node_workflow_no_variable_extraction,
|
||||
llm,
|
||||
user_speech_initial_delay=user_speech_initial_delay,
|
||||
)
|
||||
|
||||
# Patch DB calls
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=1,
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
# Start the LLM generation - user speech will be injected
|
||||
# automatically when FunctionCallResultFrame #1 is seen
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
await asyncio.gather(run_pipeline(), initialize_engine())
|
||||
|
||||
# 4 generations in total, one of which was cancelled due to an interruption
|
||||
assert llm.get_current_step() == 4
|
||||
|
|
@ -12,7 +12,7 @@ import pytest
|
|||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT, MockTransportProcessor
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -23,6 +23,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
async def run_pipeline_with_tool_calls(
|
||||
|
|
@ -65,7 +66,8 @@ async def run_pipeline_with_tool_calls(
|
|||
# Create MockTTSService to generate TTS frames
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -91,15 +93,13 @@ async def run_pipeline_with_tool_calls(
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create a real pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
|
||||
)
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,23 +17,11 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -44,96 +32,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
# Define prompts for test nodes
|
||||
START_NODE_PROMPT = "Start Node System Prompt"
|
||||
AGENT_NODE_PROMPT = "Agent Node System Prompt"
|
||||
END_NODE_PROMPT = "End Node System Prompt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_with_extraction_on_start() -> WorkflowGraph:
|
||||
"""Create a three-node workflow where only the start node has extraction enabled.
|
||||
|
||||
The workflow has:
|
||||
- Start node with extraction_enabled=True and extraction_variables set
|
||||
- Agent node with extraction_enabled=False (default)
|
||||
- End node with extraction_enabled=False (default)
|
||||
|
||||
This is used to test that variable extraction is triggered for the correct node
|
||||
during transitions.
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_NODE_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract the user's name from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_name",
|
||||
type=VariableType.string,
|
||||
prompt="The name the user provided",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_NODE_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_NODE_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
class TestVariableExtractionDuringTransitions:
|
||||
|
|
@ -141,7 +40,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extraction_called_for_source_node_not_target_node(
|
||||
self, three_node_workflow_with_extraction_on_start: WorkflowGraph
|
||||
self, three_node_workflow_extraction_start_only: WorkflowGraph
|
||||
):
|
||||
"""Test that when transitioning from START to AGENT, extraction is called for START node.
|
||||
|
||||
|
|
@ -183,7 +82,8 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -195,7 +95,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
)
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
workflow = three_node_workflow_with_extraction_on_start
|
||||
workflow = three_node_workflow_extraction_start_only
|
||||
|
||||
# Create PipecatEngine with the workflow
|
||||
engine = PipecatEngine(
|
||||
|
|
@ -209,7 +109,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Patch _perform_variable_extraction_if_needed to track calls
|
||||
original_perform_extraction = engine._perform_variable_extraction_if_needed
|
||||
|
||||
async def tracked_perform_extraction(node):
|
||||
async def tracked_perform_extraction(node, run_in_background=True):
|
||||
extraction_calls.append(
|
||||
{
|
||||
"node_id": node.id if node else None,
|
||||
|
|
@ -228,7 +128,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
|
@ -236,7 +136,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(allow_interruptions=False),
|
||||
params=PipelineParams(),
|
||||
enable_rtvi=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import pytest
|
|||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -26,6 +25,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
async def run_pipeline_with_user_idle(
|
||||
|
|
@ -52,14 +52,15 @@ async def run_pipeline_with_user_idle(
|
|||
if mock_steps is None:
|
||||
mock_steps = MockLLMService.create_multi_step_responses(
|
||||
MockLLMService.create_text_chunks("Hello, how can I help you today?"),
|
||||
num_text_steps=3, # Initial + 2 idle responses
|
||||
num_text_steps=4, # Initial + 2 idle responses + 1 variable extraction
|
||||
step_prefix="Response",
|
||||
)
|
||||
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
tts = MockTTSService(mock_audio_duration_ms=10)
|
||||
|
||||
mock_transport = MockTransportProcessor()
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=True)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -99,15 +100,13 @@ async def run_pipeline_with_user_idle(
|
|||
user_context_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
mock_transport,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
|
||||
)
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -145,8 +144,6 @@ async def run_pipeline_with_user_idle(
|
|||
async def wait_for_idle_to_trigger():
|
||||
# Wait long enough for idle timeouts to trigger
|
||||
await asyncio.sleep(total_wait_time)
|
||||
# Cancel the task if it's still running
|
||||
await task.cancel()
|
||||
|
||||
# Run all concurrently
|
||||
await asyncio.gather(
|
||||
|
|
|
|||
238
api/tests/test_voicemail_detector.py
Normal file
238
api/tests/test_voicemail_detector.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""Tests for understanding voicemail detector behavior with user aggregator and LLM.
|
||||
|
||||
This module tests the interaction between the voicemail detector, user aggregator,
|
||||
and LLM in a pipeline. It demonstrates how the voicemail detector classifies
|
||||
incoming speech as CONVERSATION or VOICEMAIL and how the main LLM responds.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from pipecat.extensions.voicemail.voicemail_detector import VoicemailDetector
|
||||
from pipecat.frames.frames import (
|
||||
EndTaskFrame,
|
||||
Frame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService
|
||||
from pipecat.turns.user_start import (
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
VADUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import (
|
||||
ExternalUserTurnStopStrategy,
|
||||
)
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class FrameInjector(FrameProcessor):
    """Pass-through processor that lets a test push extra frames into the pipeline.

    Frames arriving through the pipeline are forwarded unchanged, in the
    direction they were travelling. ``inject_frame`` pushes a frame
    immediately from this processor's position; nothing is buffered, so the
    processor keeps no state of its own.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Forward every incoming frame unchanged in its original direction."""
        # Let the base class do its own handling before the frame is forwarded.
        await super().process_frame(frame, direction)
        await self.push_frame(frame, direction)

    async def inject_frame(
        self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
    ) -> None:
        """Push ``frame`` from this processor in ``direction`` (downstream by default)."""
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class FrameCounter:
    """Tally of user speaking start/stop frames seen by a wrapped processor."""

    def __init__(self):
        self.user_started_speaking_count = 0
        self.user_stopped_speaking_count = 0

    def wrap_process_frame(self, original_process_frame):
        """Return a replacement for ``original_process_frame`` that counts frames.

        The returned coroutine function increments the matching counter when
        it receives a ``UserStoppedSpeakingFrame`` or a
        ``UserStartedSpeakingFrame``, then awaits ``original_process_frame``
        with the same arguments and returns its result.
        """

        async def counting_process_frame(frame: Frame, direction: FrameDirection):
            # The stop frame is checked first; at most one counter changes per frame.
            if isinstance(frame, UserStoppedSpeakingFrame):
                self.user_stopped_speaking_count += 1
            elif isinstance(frame, UserStartedSpeakingFrame):
                self.user_started_speaking_count += 1

            outcome = await original_process_frame(frame, direction)
            return outcome

        return counting_process_frame
|
||||
|
||||
|
||||
class TestVoicemailDetectorWithUserAggregator:
|
||||
"""Test scenarios with voicemail detector and user aggregator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voicemail_detector_conversation_flow(self):
|
||||
"""Test: Voicemail detector classifies as CONVERSATION and main LLM responds.
|
||||
|
||||
This test bench shows the flow:
|
||||
1. User starts speaking, sends transcription, stops speaking
|
||||
2. Voicemail detector's internal LLM classifies as "CONVERSATION"
|
||||
3. Main LLM generates response text
|
||||
4. Second user turn with transcription
|
||||
5. An EndTaskFrame is injected upstream to end the pipeline
|
||||
|
||||
Pipeline structure mirrors run_pipeline.py:
|
||||
injector -> voicemail_detector.detector() -> user_aggregator -> main_llm
|
||||
-> assistant_aggregator (voicemail_detector.gate() is not included in this test pipeline)
|
||||
"""
|
||||
context = LLMContext()
|
||||
|
||||
# Create user turn strategies
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
)
|
||||
|
||||
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
|
||||
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params, user_params=user_params
|
||||
)
|
||||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Create mock LLM for main conversation
|
||||
# Step 0: First response after CONVERSATION classification
|
||||
# Step 1: Second user turn (the LLM is invoked again, but no mock step is configured for it)
|
||||
# The pipeline is ended by an injected EndTaskFrame rather than by an end_call function call
|
||||
main_llm_steps = [
|
||||
MockLLMService.create_text_chunks(text="Hello! I'm here to help you today.")
|
||||
]
|
||||
main_llm = MockLLMService(mock_steps=main_llm_steps, chunk_delay=0.001)
|
||||
|
||||
# Create mock LLM for voicemail classification
|
||||
# First response: "CONVERSATION" to close the voicemail gate
|
||||
voicemail_classification_steps = [
|
||||
MockLLMService.create_text_chunks(text="CONVERSATION"),
|
||||
]
|
||||
voicemail_llm = MockLLMService(
|
||||
mock_steps=voicemail_classification_steps, chunk_delay=0.001
|
||||
)
|
||||
|
||||
# Create voicemail detector with the classification LLM
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
voicemail_response_delay=0,
|
||||
)
|
||||
|
||||
# Set up frame counter to track UserStoppedSpeakingFrame in voicemail detector's user aggregator
|
||||
voicemail_user_aggregator = voicemail_detector._context_aggregator.user()
|
||||
frame_counter = FrameCounter()
|
||||
original_process_frame = voicemail_user_aggregator.process_frame
|
||||
voicemail_user_aggregator.process_frame = frame_counter.wrap_process_frame(
|
||||
original_process_frame
|
||||
)
|
||||
|
||||
# Build pipeline similar to run_pipeline.py structure
|
||||
injector = FrameInjector()
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
injector,
|
||||
voicemail_detector.detector(), # Classification parallel pipeline
|
||||
user_context_aggregator,
|
||||
main_llm,
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def inject_frames():
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# === First user turn ===
|
||||
# This triggers voicemail classification AND main LLM response
|
||||
await injector.inject_frame(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject_frame(
|
||||
TranscriptionFrame("First User Speech", "user-123", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
await injector.inject_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
# Wait for voicemail classification and main LLM response
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# === Second user turn ===
|
||||
await injector.inject_frame(UserStartedSpeakingFrame())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject_frame(
|
||||
TranscriptionFrame(
|
||||
"Second User Speech",
|
||||
"user-123",
|
||||
time_now_iso8601(),
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
await injector.inject_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
await injector.inject_frame(
|
||||
EndTaskFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await asyncio.gather(run_pipeline(), inject_frames())
|
||||
|
||||
# Assert voicemail LLM was called once for classification
|
||||
assert voicemail_llm.get_current_step() == 1
|
||||
|
||||
# Assert main LLM was called twice (once per user turn)
|
||||
assert main_llm.get_current_step() == 2
|
||||
|
||||
# Assert voicemail detector's user aggregator saw UserStoppedSpeakingFrame only once
|
||||
# (because the classifier gate closes after CONVERSATION classification,
|
||||
# blocking subsequent frames from reaching the voicemail branch)
|
||||
assert frame_counter.user_stopped_speaking_count == 1, (
|
||||
f"Expected voicemail detector's user aggregator to see UserStoppedSpeakingFrame once, "
|
||||
f"but saw it {frame_counter.user_stopped_speaking_count} times"
|
||||
)
|
||||
|
||||
# We should see no more than 2 UserStartedSpeakingFrames: one from the downstream FrameInjector
|
||||
# and one from upstream main pipeline's LLMUserAggregator
|
||||
assert frame_counter.user_started_speaking_count <= 2, (
|
||||
f"Expected voicemail detector's user aggregator to see UserStartedSpeakingFrame at most twice, "
|
||||
f"but saw it {frame_counter.user_started_speaking_count} times"
|
||||
)
|
||||
|
||||
# Assert the classifier gate is closed after classification
|
||||
assert voicemail_detector._classifier_gate._gate_opened is False, (
|
||||
"Expected classifier gate to be closed after CONVERSATION classification"
|
||||
)
|
||||
|
||||
# Assert the classifier upstream gate is also closed after classification
|
||||
assert voicemail_detector._classifier_upstream_gate._gate_open is False, (
|
||||
"Expected classifier upstream gate to be closed after CONVERSATION classification"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue