mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
chore: refactor and add tests (#130)
* chore: add tests for end call * Update pipecat module * fix: allow interruptions from deepgram flux * Add VadUserTurnStrategy * chore: add test for voicemail detection
This commit is contained in:
parent
2aedb839ff
commit
033fde8946
15 changed files with 2106 additions and 542 deletions
|
|
@ -17,23 +17,11 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -44,96 +32,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
# Define prompts for test nodes
|
||||
START_NODE_PROMPT = "Start Node System Prompt"
|
||||
AGENT_NODE_PROMPT = "Agent Node System Prompt"
|
||||
END_NODE_PROMPT = "End Node System Prompt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_with_extraction_on_start() -> WorkflowGraph:
|
||||
"""Create a three-node workflow where only the start node has extraction enabled.
|
||||
|
||||
The workflow has:
|
||||
- Start node with extraction_enabled=True and extraction_variables set
|
||||
- Agent node with extraction_enabled=False (default)
|
||||
- End node with extraction_enabled=False (default)
|
||||
|
||||
This is used to test that variable extraction is triggered for the correct node
|
||||
during transitions.
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_NODE_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract the user's name from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_name",
|
||||
type=VariableType.string,
|
||||
prompt="The name the user provided",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_NODE_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_NODE_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
class TestVariableExtractionDuringTransitions:
|
||||
|
|
@ -141,7 +40,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extraction_called_for_source_node_not_target_node(
|
||||
self, three_node_workflow_with_extraction_on_start: WorkflowGraph
|
||||
self, three_node_workflow_extraction_start_only: WorkflowGraph
|
||||
):
|
||||
"""Test that when transitioning from START to AGENT, extraction is called for START node.
|
||||
|
||||
|
|
@ -183,7 +82,8 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -195,7 +95,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
)
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
workflow = three_node_workflow_with_extraction_on_start
|
||||
workflow = three_node_workflow_extraction_start_only
|
||||
|
||||
# Create PipecatEngine with the workflow
|
||||
engine = PipecatEngine(
|
||||
|
|
@ -209,7 +109,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Patch _perform_variable_extraction_if_needed to track calls
|
||||
original_perform_extraction = engine._perform_variable_extraction_if_needed
|
||||
|
||||
async def tracked_perform_extraction(node):
|
||||
async def tracked_perform_extraction(node, run_in_background=True):
|
||||
extraction_calls.append(
|
||||
{
|
||||
"node_id": node.id if node else None,
|
||||
|
|
@ -228,7 +128,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
|
@ -236,7 +136,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(allow_interruptions=False),
|
||||
params=PipelineParams(),
|
||||
enable_rtvi=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue