chore: refactor and add tests (#130)

* chore: add tests for end call

* Update pipecat module

* fix: allow interruptions from deepgram flux

* Add VadUserTurnStrategy

* chore: add test for voicemail detection
This commit is contained in:
Abhishek 2026-01-27 18:20:23 +05:30 committed by GitHub
parent 2aedb839ff
commit 033fde8946
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 2106 additions and 542 deletions

View file

@ -17,23 +17,11 @@ from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.dto import (
EdgeDataDTO,
ExtractionVariableDTO,
NodeDataDTO,
NodeType,
Position,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
VariableType,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import MockTransportProcessor
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@ -44,96 +32,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.tests import MockLLMService, MockTTSService
# Define prompts for test nodes
START_NODE_PROMPT = "Start Node System Prompt"
AGENT_NODE_PROMPT = "Agent Node System Prompt"
END_NODE_PROMPT = "End Node System Prompt"
@pytest.fixture
def three_node_workflow_with_extraction_on_start() -> WorkflowGraph:
"""Create a three-node workflow where only the start node has extraction enabled.
The workflow has:
- Start node with extraction_enabled=True and extraction_variables set
- Agent node with extraction_enabled=False (default)
- End node with extraction_enabled=False (default)
This is used to test that variable extraction is triggered for the correct node
during transitions.
"""
dto = ReactFlowDTO(
nodes=[
RFNodeDTO(
id="start",
type=NodeType.startNode,
position=Position(x=0, y=0),
data=NodeDataDTO(
name="Start Call",
prompt=START_NODE_PROMPT,
is_start=True,
allow_interrupt=False,
add_global_prompt=False,
extraction_enabled=True,
extraction_prompt="Extract the user's name from the conversation.",
extraction_variables=[
ExtractionVariableDTO(
name="user_name",
type=VariableType.string,
prompt="The name the user provided",
),
],
),
),
RFNodeDTO(
id="agent",
type=NodeType.agentNode,
position=Position(x=0, y=200),
data=NodeDataDTO(
name="Collect Info",
prompt=AGENT_NODE_PROMPT,
allow_interrupt=False,
add_global_prompt=False,
extraction_enabled=False, # Explicitly disabled
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position=Position(x=0, y=400),
data=NodeDataDTO(
name="End Call",
prompt=END_NODE_PROMPT,
is_end=True,
allow_interrupt=False,
add_global_prompt=False,
extraction_enabled=False, # Explicitly disabled
),
),
],
edges=[
RFEdgeDTO(
id="start-agent",
source="start",
target="agent",
data=EdgeDataDTO(
label="Collect Info",
condition="When user has been greeted, proceed to collect information",
),
),
RFEdgeDTO(
id="agent-end",
source="agent",
target="end",
data=EdgeDataDTO(
label="End Call",
condition="When information collection is complete, end the call",
),
),
],
)
return WorkflowGraph(dto)
from pipecat.tests.mock_transport import MockTransport
class TestVariableExtractionDuringTransitions:
@ -141,7 +40,7 @@ class TestVariableExtractionDuringTransitions:
@pytest.mark.asyncio
async def test_extraction_called_for_source_node_not_target_node(
self, three_node_workflow_with_extraction_on_start: WorkflowGraph
self, three_node_workflow_extraction_start_only: WorkflowGraph
):
"""Test that when transitioning from START to AGENT, extraction is called for START node.
@ -183,7 +82,8 @@ class TestVariableExtractionDuringTransitions:
# Create MockTTSService
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
# Create MockTransport for simulating transport behavior
mock_transport = MockTransport(emit_bot_speaking=False)
# Create LLM context
context = LLMContext()
@ -195,7 +95,7 @@ class TestVariableExtractionDuringTransitions:
)
assistant_context_aggregator = context_aggregator.assistant()
workflow = three_node_workflow_with_extraction_on_start
workflow = three_node_workflow_extraction_start_only
# Create PipecatEngine with the workflow
engine = PipecatEngine(
@ -209,7 +109,7 @@ class TestVariableExtractionDuringTransitions:
# Patch _perform_variable_extraction_if_needed to track calls
original_perform_extraction = engine._perform_variable_extraction_if_needed
async def tracked_perform_extraction(node):
async def tracked_perform_extraction(node, run_in_background=True):
extraction_calls.append(
{
"node_id": node.id if node else None,
@ -228,7 +128,7 @@ class TestVariableExtractionDuringTransitions:
[
llm,
tts,
mock_transport_emulator,
mock_transport.output(),
assistant_context_aggregator,
]
)
@ -236,7 +136,7 @@ class TestVariableExtractionDuringTransitions:
# Create pipeline task
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=False),
params=PipelineParams(),
enable_rtvi=False,
)