mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
chore: refactor and add tests (#130)
* chore: add tests for end call * Update pipecat module * fix: allow interruptions from deepgram flux * Add VadUserTurnStrategy * chore: add test for voicemail detection
This commit is contained in:
parent
2aedb839ff
commit
033fde8946
15 changed files with 2106 additions and 542 deletions
|
|
@ -19,18 +19,13 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from api.tests.conftest import (
|
||||
AGENT_SYSTEM_PROMPT,
|
||||
END_CALL_SYSTEM_PROMPT,
|
||||
START_CALL_SYSTEM_PROMPT,
|
||||
)
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -41,86 +36,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
# Define prompts for test nodes
|
||||
START_NODE_PROMPT = "Start Node System Prompt"
|
||||
AGENT_NODE_PROMPT = "Agent Node System Prompt"
|
||||
END_NODE_PROMPT = "End Node System Prompt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_for_context_test() -> WorkflowGraph:
|
||||
"""Create a three-node workflow for testing context updates during transitions.
|
||||
|
||||
The workflow has:
|
||||
- Start node with prompt to greet user
|
||||
- Agent node with prompt to collect information
|
||||
- End node with prompt to say goodbye
|
||||
|
||||
Edges:
|
||||
- Start -> Agent (label: "Collect Info")
|
||||
- Agent -> End (label: "End Call")
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_NODE_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_NODE_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_NODE_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
class ContextCapturingMockLLM(MockLLMService):
|
||||
|
|
@ -215,7 +131,8 @@ async def run_pipeline_and_capture_context(
|
|||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -251,15 +168,13 @@ async def run_pipeline_and_capture_context(
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
|
||||
)
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -294,7 +209,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_transition_updates_context_before_next_completion(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that a single transition function call updates context before next LLM generation.
|
||||
|
||||
|
|
@ -329,7 +244,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
set_node_delay=0.05, # Introduce 50ms delay in set_node
|
||||
)
|
||||
|
|
@ -341,17 +256,17 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
# Verify step 0 (start node) had start node's system prompt
|
||||
step_0_prompt = llm.get_system_prompt_at_step(0)
|
||||
assert START_NODE_PROMPT in step_0_prompt, (
|
||||
assert START_CALL_SYSTEM_PROMPT in step_0_prompt, (
|
||||
f"Step 0 should have start node prompt, got: {step_0_prompt[:100]}"
|
||||
)
|
||||
|
||||
# Verify step 1 (agent node) had:
|
||||
# 1. The agent node's system prompt (not start node's)
|
||||
step_1_prompt = llm.get_system_prompt_at_step(1)
|
||||
assert AGENT_NODE_PROMPT in step_1_prompt, (
|
||||
assert AGENT_SYSTEM_PROMPT in step_1_prompt, (
|
||||
f"Step 1 should have agent node prompt, got: {step_1_prompt[:100]}"
|
||||
)
|
||||
assert START_NODE_PROMPT not in step_1_prompt, (
|
||||
assert START_CALL_SYSTEM_PROMPT not in step_1_prompt, (
|
||||
"Step 1 should NOT have start node prompt anymore"
|
||||
)
|
||||
|
||||
|
|
@ -371,7 +286,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_transitions_maintain_correct_context(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that sequential transitions maintain correct context at each step.
|
||||
|
||||
|
|
@ -403,7 +318,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
set_node_delay=0.05, # Introduce 50ms delay in set_node
|
||||
)
|
||||
|
|
@ -414,13 +329,13 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
)
|
||||
|
||||
# Step 0: Start node - should have start prompt
|
||||
assert START_NODE_PROMPT in llm.get_system_prompt_at_step(0)
|
||||
assert START_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(0)
|
||||
|
||||
# Step 1: Agent node - should have agent prompt
|
||||
assert AGENT_NODE_PROMPT in llm.get_system_prompt_at_step(1)
|
||||
assert AGENT_SYSTEM_PROMPT in llm.get_system_prompt_at_step(1)
|
||||
|
||||
# Step 2: End node - should have end prompt
|
||||
assert END_NODE_PROMPT in llm.get_system_prompt_at_step(2)
|
||||
assert END_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(2)
|
||||
|
||||
# Verify each subsequent step has the previous tool results
|
||||
step_1_ctx = llm.get_context_at_step(1)
|
||||
|
|
@ -445,7 +360,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_messages_preserve_conversation_history(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that conversation history is preserved across node transitions.
|
||||
|
||||
|
|
@ -474,7 +389,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue