dograh/api/tests/test_pipecat_engine_set_node.py
Abhishek Kumar 4f2a629340 Initial Commit 🚀 🚀
2025-09-09 14:37:32 +05:30

536 lines
18 KiB
Python

import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pipecat.frames.frames import (
EndFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import Edge, Node, WorkflowGraph
class TestPipecatEngineSetNode:
"""Test cases for PipecatEngine.set_node method refactoring."""
@pytest.fixture
def mock_workflow(self):
"""Create a mock workflow with various node types."""
workflow = Mock(spec=WorkflowGraph)
workflow.nodes = {}
workflow.start_node_id = "start_node"
workflow.global_node_id = None
return workflow
@pytest.fixture
def mock_dependencies(self, mock_workflow):
"""Create mock dependencies for PipecatEngine initialization."""
task = AsyncMock()
task.queue_frames = AsyncMock()
task.queue_frame = AsyncMock()
llm = AsyncMock()
llm.register_function = Mock()
llm.push_frame = AsyncMock()
context = Mock(spec=OpenAILLMContext)
context.set_node_name = Mock()
return {
"task": task,
"llm": llm,
"context": context,
"tts": Mock(),
"transport": Mock(),
"workflow": mock_workflow,
"call_context_vars": {"test_var": "test_value"},
}
@pytest.fixture
def engine(self, mock_dependencies):
"""Create a PipecatEngine instance."""
# Add audio_buffer and workflow_run_id to dependencies
mock_dependencies["audio_buffer"] = None
mock_dependencies["workflow_run_id"] = 123
engine = PipecatEngine(**mock_dependencies)
# Mock the builtin function registration
engine._register_builtin_functions = AsyncMock()
return engine
def create_node(self, node_id, **kwargs):
"""Helper to create a node with default values."""
defaults = {
"name": f"Node {node_id}",
"prompt": f"Prompt for {node_id}",
"is_static": False,
"is_start": False,
"is_end": False,
"allow_interrupt": True,
"extraction_enabled": False,
"extraction_prompt": "",
"extraction_variables": [],
"add_global_prompt": True,
"wait_for_user_response": False,
"detect_voicemail": False,
}
defaults.update(kwargs)
data = Mock(spec=NodeDataDTO)
for key, value in defaults.items():
setattr(data, key, value)
node = Mock(spec=Node)
node.id = node_id
node.data = data
node.out_edges = []
# Copy attributes from data to node
for key, value in defaults.items():
setattr(node, key, value)
return node
def create_edge(
self, source, target, label="Continue", condition="Always continue"
):
"""Helper to create an edge."""
data = Mock(spec=EdgeDataDTO)
data.label = label
data.condition = condition
edge = Mock(spec=Edge)
edge.source = source
edge.target = target
edge.data = data
edge.get_function_name = Mock(return_value=label.lower().replace(" ", "_"))
return edge
# ===== START NODE TESTS =====
@pytest.mark.asyncio
async def test_start_node_static_immediate_execution(self, engine, mock_workflow):
"""Test: Basic static start node executes immediately."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
prompt="Welcome to our service!",
)
next_node = self.create_node("next_node", is_static=False)
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Execute
await engine.set_node("start_node")
# Verify
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert len(frames) == 3
assert isinstance(frames[0], LLMFullResponseStartFrame)
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Welcome to our service!"
assert isinstance(frames[2], LLMFullResponseEndFrame)
# Static start nodes now set pending transition after context push
assert engine._pending_control_transition_after_context_push is not None
# Should not have set detect_voicemail for static start without it
assert not engine._detect_voicemail
@pytest.mark.asyncio
async def test_start_node_with_detect_voicemail_no_audio_buffer(
self, engine, mock_workflow
):
"""Test: Start node with voicemail detection but no audio buffer logs warning."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
detect_voicemail=True,
prompt="Hello, this is a business call.",
)
mock_workflow.nodes = {"start_node": start_node}
# Engine has no audio buffer (None)
assert engine._audio_buffer is None
# Execute
await engine.set_node("start_node")
# Verify
# Should NOT set voicemail detection flag since no audio buffer
assert engine._detect_voicemail is False
assert engine._voicemail_detector is None
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Hello, this is a business call."
@pytest.mark.asyncio
async def test_start_node_non_static_with_detect_voicemail(
self, engine, mock_workflow
):
"""Test: Non-static start node with voicemail detection without audio buffer."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=False, # Non-static
detect_voicemail=True,
prompt="You are an AI assistant. Start the conversation.",
)
mock_workflow.nodes = {"start_node": start_node}
# Mock the context update method
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test prompt"}, [])
)
# Execute
await engine.set_node("start_node")
# Verify
# Should NOT set voicemail detection flags (no audio buffer)
assert engine._detect_voicemail is False
assert engine._voicemail_detector is None
# Should update LLM context for non-static node
engine._update_llm_context.assert_called_once()
# Should queue context frame
engine.task.queue_frame.assert_called_once()
frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(frame, OpenAILLMContextFrame)
@pytest.mark.asyncio
async def test_start_node_static_with_wait_for_user_response(
self, engine, mock_workflow
):
"""Test: Static start node with wait_for_user_response."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
wait_for_user_response=True,
prompt="Please tell me your name.",
)
next_node = self.create_node("next_node")
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Execute
await engine.set_node("start_node")
# Verify
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
# Should have a pending control transition that will start the timer
assert engine._pending_control_transition_after_context_push is not None
# Timer task should not exist yet
assert (
not hasattr(engine, "_user_response_timeout_task")
or engine._user_response_timeout_task is None
)
# Simulate context push to start the timer
await engine.flush_pending_transitions(source="context_push")
# Now the timeout task should be created
assert engine._user_response_timeout_task is not None
assert not engine._user_response_timeout_task.done()
# Clean up the task
engine._user_response_timeout_task.cancel()
@pytest.mark.asyncio
async def test_start_node_non_static(self, engine, mock_workflow):
"""Test: Non-static start node sends context to LLM."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=False,
prompt="You are a helpful assistant. Greet the user.",
)
mock_workflow.nodes = {"start_node": start_node}
# Mock the context update method
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test prompt"}, [])
)
# Execute
await engine.set_node("start_node")
# Verify
# Should set context name
engine.context.set_node_name.assert_called_once_with("Node start_node")
# Should update LLM context
engine._update_llm_context.assert_called_once()
# Should queue context frame
engine.task.queue_frame.assert_called_once()
frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(frame, OpenAILLMContextFrame)
# ===== AGENT NODE TESTS =====
@pytest.mark.asyncio
async def test_agent_node_static(self, engine, mock_workflow):
"""Test: Static agent node plays TTS and transitions."""
# Setup
agent_node = self.create_node(
"agent_node", is_static=True, prompt="Processing your request..."
)
next_node = self.create_node("next_node")
edge = self.create_edge("agent_node", "next_node")
agent_node.out_edges = [edge]
mock_workflow.nodes = {"agent_node": agent_node, "next_node": next_node}
# Execute
await engine.set_node("agent_node")
# Verify
# Should queue TTS
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Processing your request..."
# Should have pending transition
assert engine._pending_control_transition_after_context_push is not None
@pytest.mark.asyncio
async def test_agent_node_non_static(self, engine, mock_workflow):
"""Test: Non-static agent node sends context to LLM."""
# Setup
agent_node = self.create_node(
"agent_node",
is_static=False,
prompt="Analyze the user's request and respond appropriately.",
)
decision_node = self.create_node("decision_node")
edge = self.create_edge("agent_node", "decision_node", "analyze_complete")
agent_node.out_edges = [edge]
mock_workflow.nodes = {"agent_node": agent_node, "decision_node": decision_node}
# Mock methods
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=(
{"role": "system", "content": "Test"},
[{"name": "test_func"}],
)
)
# Execute
await engine.set_node("agent_node")
# Verify
# Should register transition function
engine.llm.register_function.assert_called_once()
call_args = engine.llm.register_function.call_args
assert call_args[0][0] == "analyze_complete"
assert callable(call_args[0][1]) # Check it's a function
assert call_args[1]["cancel_on_interruption"] is True
# Should update context and send frame
engine._update_llm_context.assert_called_once()
engine.task.queue_frame.assert_called_once()
@pytest.mark.asyncio
async def test_agent_node_with_interruption_control(self, engine, mock_workflow):
"""Test: Agent node respects allow_interrupt flag."""
# Setup
no_interrupt_node = self.create_node(
"no_interrupt",
is_static=True,
allow_interrupt=False,
prompt="Please wait while I process...",
)
mock_workflow.nodes = {"no_interrupt": no_interrupt_node}
# Execute
await engine.set_node("no_interrupt")
# Verify current node is set (for STT mute callback)
assert engine._current_node == no_interrupt_node
assert engine._current_node.allow_interrupt is False
# ===== END NODE TESTS =====
@pytest.mark.asyncio
async def test_end_node_static(self, engine, mock_workflow):
"""Test: Static end node plays final message and schedules end task."""
# Setup
end_node = self.create_node(
"end_node",
is_static=True,
is_end=True,
prompt="Thank you for calling. Goodbye!",
)
mock_workflow.nodes = {"end_node": end_node}
# Execute
await engine.set_node("end_node")
# Verify
# Should queue TTS
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert frames[1].text == "Thank you for calling. Goodbye!"
# Should have pending end task
assert engine._pending_control_transition_after_context_push is not None
# Execute the pending transition
await engine._pending_control_transition_after_context_push()
# Should have sent EndFrame via task.queue_frame
# The second call should be the EndFrame (first was TTS frames)
assert engine.task.queue_frame.call_count >= 1
end_frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(end_frame, EndFrame)
@pytest.mark.asyncio
async def test_end_node_with_extraction(self, engine, mock_workflow):
"""Test: End node with variable extraction."""
# Setup
end_node = self.create_node(
"end_node",
is_end=True,
is_static=False,
extraction_enabled=True,
extraction_variables=["user_name", "satisfaction_level"],
extraction_prompt="Extract user name and satisfaction",
)
mock_workflow.nodes = {"end_node": end_node}
# Mock the extraction manager
engine._variable_extraction_manager = Mock()
engine._perform_variable_extraction_if_needed = AsyncMock()
# Mock context update and composition methods
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test"}, [])
)
# Execute
await engine.set_node("end_node")
# Verify
# Should trigger extraction
engine._perform_variable_extraction_if_needed.assert_called_once_with(end_node)
# Should have pending end task
assert engine._pending_control_transition_after_context_push is not None
# ===== CALLBACK INTEGRATION TESTS =====
@pytest.mark.asyncio
async def test_user_stopped_speaking_during_response_wait(
self, engine, mock_workflow
):
"""Test: User stops speaking triggers transition during wait_for_response."""
# Setup
start_node = self.create_node(
"start_node", is_start=True, is_static=True, wait_for_user_response=True
)
next_node = self.create_node("next_node")
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Set current node to start node
engine._current_node = start_node
engine._user_response_timeout_task = asyncio.create_task(asyncio.sleep(3))
# Create callback and execute
callback = engine.create_user_stopped_speaking_callback()
# Mock set_node to avoid recursion
with patch.object(engine, "set_node", new=AsyncMock()) as mock_set_node:
await callback()
# Verify
mock_set_node.assert_called_once_with("next_node")
assert engine._queue_context_frame is False # Should be set to False
@pytest.mark.asyncio
async def test_context_push_callback_executes_pending_transitions(self, engine):
"""Test: flush_pending_transitions executes deferred transitions."""
# Setup pending transitions
mock_generated_transition = AsyncMock()
mock_control_transition = AsyncMock()
engine._pending_generated_transition_after_context_push = (
mock_generated_transition
)
engine._pending_control_transition_after_context_push = mock_control_transition
# Execute
await engine.flush_pending_transitions(source="context_push")
# Verify both transitions were executed
mock_generated_transition.assert_called_once()
mock_control_transition.assert_called_once()
# Verify they were cleared
assert engine._pending_generated_transition_after_context_push is None
assert engine._pending_control_transition_after_context_push is None
# ===== COMPLEX SCENARIO TESTS =====
# Add helper for testing with real async behavior
def ANY(cls=None):
"""Helper for matching any argument in mock calls."""
class AnyMatcher:
def __init__(self, cls):
self.cls = cls
def __eq__(self, other):
if self.cls:
return isinstance(other, self.cls)
return True
return AnyMatcher(cls)