mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: simplify pipecat engine execution (#54)
This commit is contained in:
parent
99a768f291
commit
6ce25a589c
20 changed files with 52 additions and 1405 deletions
|
|
@ -1,179 +0,0 @@
|
|||
### - This test has some weird loop which keeps on increasing the context size
|
||||
|
||||
# import asyncio
|
||||
# import json
|
||||
# import unittest
|
||||
# from types import SimpleNamespace
|
||||
# from unittest import mock
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
# from pipecat.frames.frames import (
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMGeneratedTextFrame,
|
||||
# LLMTextFrame,
|
||||
# )
|
||||
# from pipecat.pipeline.pipeline import Pipeline
|
||||
# from pipecat.processors.aggregators.openai_llm_context import (
|
||||
# OpenAILLMContext,
|
||||
# OpenAILLMContextFrame,
|
||||
# )
|
||||
# from pipecat.services.llm_service import (
|
||||
# FunctionCallParams,
|
||||
# FunctionCallResultProperties,
|
||||
# )
|
||||
# from pipecat.services.openai.llm import OpenAILLMService
|
||||
# from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
# class _MockAsyncStream:
|
||||
# """A minimal async-stream wrapper that mimics ``openai.AsyncStream``."""
|
||||
|
||||
# def __init__(self, chunks):
|
||||
# self._chunks = chunks
|
||||
|
||||
# def __aiter__(self):
|
||||
# self._idx = 0
|
||||
# return self
|
||||
|
||||
# async def __anext__(self):
|
||||
# if self._idx >= len(self._chunks):
|
||||
# raise StopAsyncIteration
|
||||
# item = self._chunks[self._idx]
|
||||
# self._idx += 1
|
||||
# await asyncio.sleep(0) # Yield control
|
||||
# return item
|
||||
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # Factories for mock chunks
|
||||
# # ------------------------------------------------------------------
|
||||
|
||||
|
||||
# def _make_tool_call(tool_name: str, args_json: str, *, idx: int = 0):
|
||||
# function = SimpleNamespace(name=tool_name, arguments=args_json)
|
||||
# return SimpleNamespace(index=idx, id=f"call-{idx}", function=function)
|
||||
|
||||
|
||||
# def _make_chunk(*, content: str | None = None, tool_calls=None, usage=None):
|
||||
# delta = SimpleNamespace()
|
||||
# # When we are asked to simulate multiple tool calls in parallel, OpenAI
|
||||
# # sends *separate* chunks for every tool-call index. To mimic that behaviour
|
||||
# # in tests we split a list of tool calls (>1) into individual chunks – one
|
||||
# # for each tool call – while keeping the original single-chunk behaviour
|
||||
# # when zero or one tool calls are supplied. This enables us to write
|
||||
# # concise tests such as ``_make_chunk(tool_calls=[call_1, call_2])`` that
|
||||
# # accurately reflect the streaming protocol.
|
||||
|
||||
# # No special handling needed if there is textual content or 0/1 tool calls.
|
||||
# if content is not None or tool_calls is None or len(tool_calls) <= 1:
|
||||
# if content is not None:
|
||||
# delta.content = content
|
||||
# # Always set tool_calls so downstream code can safely access it
|
||||
# delta.tool_calls = tool_calls if tool_calls is not None else None
|
||||
# return SimpleNamespace(choices=[SimpleNamespace(delta=delta)], usage=usage)
|
||||
|
||||
# # --- Multiple tool calls (len(tool_calls) > 1) ---
|
||||
# # Create a list of chunks, each containing a single tool call. This is the
|
||||
# # format produced by the OpenAI client when several tools are invoked in a
|
||||
# # single assistant response.
|
||||
# chunks = []
|
||||
# for tc in tool_calls:
|
||||
# delta_tc = SimpleNamespace(tool_calls=[tc])
|
||||
# chunks.append(SimpleNamespace(choices=[SimpleNamespace(delta=delta_tc)], usage=usage))
|
||||
|
||||
# return chunks
|
||||
|
||||
|
||||
# class TestBaseOpenAILLMService(unittest.IsolatedAsyncioTestCase):
|
||||
# async def test_process_context_with_patch(self):
|
||||
# streamed_text = "Hello from OpenAI!"
|
||||
# tool_name = "echo"
|
||||
# tool_name_2 = "echo_2"
|
||||
# tool_args = {"text": "hello"}
|
||||
# tool_args_2 = {"text": "hello_2"}
|
||||
|
||||
# # Build mocked stream (tool call first, then text)
|
||||
# chunks = [
|
||||
# _make_chunk(content=streamed_text),
|
||||
# _make_chunk(tool_calls=[_make_tool_call(tool_name, json.dumps(tool_args))]),
|
||||
# _make_chunk(tool_calls=[_make_tool_call(tool_name_2, json.dumps(tool_args_2), idx=1)]),
|
||||
# ]
|
||||
|
||||
# # Instantiate real OpenAILLMService (no need for actual API key)
|
||||
# llm = OpenAILLMService(model="gpt-4o-mini", api_key="test")
|
||||
|
||||
# # Patch get_chat_completions to return our mocked async stream
|
||||
# async def fake_get_chat_completions(self, context, messages): # noqa: D401
|
||||
# return _MockAsyncStream(chunks)
|
||||
|
||||
# with mock.patch.object(llm.__class__, "get_chat_completions", fake_get_chat_completions):
|
||||
# # Register echo tool
|
||||
# executed = False
|
||||
|
||||
# async def echo_handler(params: FunctionCallParams):
|
||||
# nonlocal executed
|
||||
# executed = True
|
||||
# # sleep for 1 second
|
||||
# logger.info("echo_handler: sleeping for 5 second")
|
||||
# await asyncio.sleep(5)
|
||||
# await params.result_callback(
|
||||
# {"ok": True},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# async def echo_2_handler(params: FunctionCallParams):
|
||||
# nonlocal executed
|
||||
# executed = True
|
||||
# # sleep for 1 second
|
||||
# logger.info("echo_2_handler: sleeping for 5 second")
|
||||
# await asyncio.sleep(5)
|
||||
# await params.result_callback(
|
||||
# {"ok": True},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# llm.register_function(tool_name, echo_handler)
|
||||
# llm.register_function(tool_name_2, echo_2_handler)
|
||||
|
||||
# # Prepare context and send
|
||||
# context = OpenAILLMContext()
|
||||
# context.add_message({"role": "user", "content": "Hi"})
|
||||
# frames_to_send = [OpenAILLMContextFrame(context)]
|
||||
|
||||
# expected_down_frames = [
|
||||
# LLMFullResponseStartFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# LLMGeneratedTextFrame,
|
||||
# LLMTextFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# ]
|
||||
|
||||
# context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# pipeline = Pipeline([llm, context_aggregator.assistant()])
|
||||
|
||||
# down_frames, _ = await run_test(
|
||||
# pipeline,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_down_frames,
|
||||
# send_end_frame=False,
|
||||
# )
|
||||
|
||||
# # Assertions
|
||||
# self.assertTrue(executed)
|
||||
# for fr in down_frames:
|
||||
# if isinstance(fr, FunctionCallResultFrame):
|
||||
# self.assertTrue(fr.run_llm)
|
||||
# if isinstance(fr, LLMTextFrame):
|
||||
# self.assertEqual(fr.text, streamed_text)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# unittest.main()
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Test script to verify that LLMGeneratedTextFrame signaling works correctly
|
||||
with the new local variable approach.
|
||||
"""
|
||||
|
||||
|
||||
def test_local_variable_logic():
|
||||
"""Test the core logic using the same pattern as the implementation"""
|
||||
|
||||
print("=== Testing Local Variable Logic ===")
|
||||
|
||||
# Simulate the logic from _process_context
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks with text content
|
||||
chunks_with_content = ["Hello", " world", "!"]
|
||||
|
||||
for content in chunks_with_content:
|
||||
# This is the exact logic from our implementation
|
||||
if content: # equivalent to chunk.choices[0].delta.content
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({content})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
# Verify behavior
|
||||
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
|
||||
text_frames = [f for f in frames_sent if f.startswith("LLMTextFrame")]
|
||||
|
||||
assert len(generated_signals) == 1, (
|
||||
f"Expected 1 signal, got {len(generated_signals)}"
|
||||
)
|
||||
assert len(text_frames) == 3, f"Expected 3 text frames, got {len(text_frames)}"
|
||||
assert frames_sent[0] == "LLMGeneratedTextFrame", "Signal should be first"
|
||||
|
||||
print("✅ Local variable logic works correctly")
|
||||
return True
|
||||
|
||||
|
||||
def test_no_text_logic():
|
||||
"""Test that no signal is sent when there's no text"""
|
||||
|
||||
print("\n=== Testing No Text Logic ===")
|
||||
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks with no text content (function calls only)
|
||||
chunks_with_content = [None, None, None] # No text content
|
||||
|
||||
for content in chunks_with_content:
|
||||
if content: # This will be False for all chunks
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({content})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
assert len(frames_sent) == 0, f"Expected no frames, got {frames_sent}"
|
||||
|
||||
print("✅ No signal sent when no text content")
|
||||
return True
|
||||
|
||||
|
||||
def test_mixed_content_logic():
|
||||
"""Test behavior with mixed function calls and text"""
|
||||
|
||||
print("\n=== Testing Mixed Content Logic ===")
|
||||
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks: function call, text, function call, text
|
||||
chunks = [
|
||||
{"type": "function", "content": None},
|
||||
{"type": "text", "content": "Hello"},
|
||||
{"type": "function", "content": None},
|
||||
{"type": "text", "content": " world"},
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
if chunk["type"] == "function":
|
||||
frames_sent.append("FunctionCallFrame")
|
||||
elif chunk["content"]: # text content
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({chunk['content']})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
|
||||
|
||||
assert len(generated_signals) == 1, (
|
||||
f"Expected 1 signal, got {len(generated_signals)}"
|
||||
)
|
||||
# Signal should come before first text frame but after any function frames
|
||||
signal_index = frames_sent.index("LLMGeneratedTextFrame")
|
||||
first_text_index = next(
|
||||
i for i, f in enumerate(frames_sent) if f.startswith("LLMTextFrame")
|
||||
)
|
||||
assert signal_index == first_text_index - 1, (
|
||||
"Signal should come right before first text"
|
||||
)
|
||||
|
||||
print("✅ Mixed content logic works correctly")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
test1_result = test_local_variable_logic()
|
||||
test2_result = test_no_text_logic()
|
||||
test3_result = test_mixed_content_logic()
|
||||
|
||||
print(f"\n=== Test Results ===")
|
||||
print(f"Local variable test: {'✅ PASS' if test1_result else '❌ FAIL'}")
|
||||
print(f"No text test: {'✅ PASS' if test2_result else '❌ FAIL'}")
|
||||
print(f"Mixed content test: {'✅ PASS' if test3_result else '❌ FAIL'}")
|
||||
|
||||
if test1_result and test2_result and test3_result:
|
||||
print("\n🎉 All LLMGeneratedTextFrame signaling logic tests passed!")
|
||||
print(
|
||||
"✅ Implementation correctly signals text generation once, as early as possible"
|
||||
)
|
||||
else:
|
||||
print("\n❌ Some tests failed.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,536 +0,0 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import Edge, Node, WorkflowGraph
|
||||
|
||||
|
||||
class TestPipecatEngineSetNode:
|
||||
"""Test cases for PipecatEngine.set_node method refactoring."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow(self):
|
||||
"""Create a mock workflow with various node types."""
|
||||
workflow = Mock(spec=WorkflowGraph)
|
||||
workflow.nodes = {}
|
||||
workflow.start_node_id = "start_node"
|
||||
workflow.global_node_id = None
|
||||
return workflow
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self, mock_workflow):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
task = AsyncMock()
|
||||
task.queue_frames = AsyncMock()
|
||||
task.queue_frame = AsyncMock()
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.register_function = Mock()
|
||||
llm.push_frame = AsyncMock()
|
||||
|
||||
context = Mock(spec=OpenAILLMContext)
|
||||
context.set_node_name = Mock()
|
||||
|
||||
return {
|
||||
"task": task,
|
||||
"llm": llm,
|
||||
"context": context,
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": mock_workflow,
|
||||
"call_context_vars": {"test_var": "test_value"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance."""
|
||||
# Add audio_buffer and workflow_run_id to dependencies
|
||||
mock_dependencies["audio_buffer"] = None
|
||||
mock_dependencies["workflow_run_id"] = 123
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
# Mock the builtin function registration
|
||||
engine._register_builtin_functions = AsyncMock()
|
||||
return engine
|
||||
|
||||
def create_node(self, node_id, **kwargs):
|
||||
"""Helper to create a node with default values."""
|
||||
defaults = {
|
||||
"name": f"Node {node_id}",
|
||||
"prompt": f"Prompt for {node_id}",
|
||||
"is_static": False,
|
||||
"is_start": False,
|
||||
"is_end": False,
|
||||
"allow_interrupt": True,
|
||||
"extraction_enabled": False,
|
||||
"extraction_prompt": "",
|
||||
"extraction_variables": [],
|
||||
"add_global_prompt": True,
|
||||
"wait_for_user_response": False,
|
||||
"detect_voicemail": False,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
|
||||
data = Mock(spec=NodeDataDTO)
|
||||
for key, value in defaults.items():
|
||||
setattr(data, key, value)
|
||||
|
||||
node = Mock(spec=Node)
|
||||
node.id = node_id
|
||||
node.data = data
|
||||
node.out_edges = []
|
||||
|
||||
# Copy attributes from data to node
|
||||
for key, value in defaults.items():
|
||||
setattr(node, key, value)
|
||||
|
||||
return node
|
||||
|
||||
def create_edge(
|
||||
self, source, target, label="Continue", condition="Always continue"
|
||||
):
|
||||
"""Helper to create an edge."""
|
||||
data = Mock(spec=EdgeDataDTO)
|
||||
data.label = label
|
||||
data.condition = condition
|
||||
|
||||
edge = Mock(spec=Edge)
|
||||
edge.source = source
|
||||
edge.target = target
|
||||
edge.data = data
|
||||
edge.get_function_name = Mock(return_value=label.lower().replace(" ", "_"))
|
||||
|
||||
return edge
|
||||
|
||||
# ===== START NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_static_immediate_execution(self, engine, mock_workflow):
|
||||
"""Test: Basic static start node executes immediately."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
prompt="Welcome to our service!",
|
||||
)
|
||||
next_node = self.create_node("next_node", is_static=False)
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert len(frames) == 3
|
||||
assert isinstance(frames[0], LLMFullResponseStartFrame)
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Welcome to our service!"
|
||||
assert isinstance(frames[2], LLMFullResponseEndFrame)
|
||||
|
||||
# Static start nodes now set pending transition after context push
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Should not have set detect_voicemail for static start without it
|
||||
assert not engine._detect_voicemail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_with_detect_voicemail_no_audio_buffer(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Start node with voicemail detection but no audio buffer logs warning."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
detect_voicemail=True,
|
||||
prompt="Hello, this is a business call.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Engine has no audio buffer (None)
|
||||
assert engine._audio_buffer is None
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should NOT set voicemail detection flag since no audio buffer
|
||||
assert engine._detect_voicemail is False
|
||||
assert engine._voicemail_detector is None
|
||||
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Hello, this is a business call."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_non_static_with_detect_voicemail(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Non-static start node with voicemail detection without audio buffer."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=False, # Non-static
|
||||
detect_voicemail=True,
|
||||
prompt="You are an AI assistant. Start the conversation.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Mock the context update method
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test prompt"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should NOT set voicemail detection flags (no audio buffer)
|
||||
assert engine._detect_voicemail is False
|
||||
assert engine._voicemail_detector is None
|
||||
|
||||
# Should update LLM context for non-static node
|
||||
engine._update_llm_context.assert_called_once()
|
||||
|
||||
# Should queue context frame
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(frame, OpenAILLMContextFrame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_static_with_wait_for_user_response(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Static start node with wait_for_user_response."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
wait_for_user_response=True,
|
||||
prompt="Please tell me your name.",
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
|
||||
# Should have a pending control transition that will start the timer
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Timer task should not exist yet
|
||||
assert (
|
||||
not hasattr(engine, "_user_response_timeout_task")
|
||||
or engine._user_response_timeout_task is None
|
||||
)
|
||||
|
||||
# Simulate context push to start the timer
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Now the timeout task should be created
|
||||
assert engine._user_response_timeout_task is not None
|
||||
assert not engine._user_response_timeout_task.done()
|
||||
|
||||
# Clean up the task
|
||||
engine._user_response_timeout_task.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_non_static(self, engine, mock_workflow):
|
||||
"""Test: Non-static start node sends context to LLM."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=False,
|
||||
prompt="You are a helpful assistant. Greet the user.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Mock the context update method
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test prompt"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should set context name
|
||||
engine.context.set_node_name.assert_called_once_with("Node start_node")
|
||||
|
||||
# Should update LLM context
|
||||
engine._update_llm_context.assert_called_once()
|
||||
|
||||
# Should queue context frame
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(frame, OpenAILLMContextFrame)
|
||||
|
||||
# ===== AGENT NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_static(self, engine, mock_workflow):
|
||||
"""Test: Static agent node plays TTS and transitions."""
|
||||
# Setup
|
||||
agent_node = self.create_node(
|
||||
"agent_node", is_static=True, prompt="Processing your request..."
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("agent_node", "next_node")
|
||||
agent_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"agent_node": agent_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("agent_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Processing your request..."
|
||||
|
||||
# Should have pending transition
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_non_static(self, engine, mock_workflow):
|
||||
"""Test: Non-static agent node sends context to LLM."""
|
||||
# Setup
|
||||
agent_node = self.create_node(
|
||||
"agent_node",
|
||||
is_static=False,
|
||||
prompt="Analyze the user's request and respond appropriately.",
|
||||
)
|
||||
decision_node = self.create_node("decision_node")
|
||||
|
||||
edge = self.create_edge("agent_node", "decision_node", "analyze_complete")
|
||||
agent_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"agent_node": agent_node, "decision_node": decision_node}
|
||||
|
||||
# Mock methods
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=(
|
||||
{"role": "system", "content": "Test"},
|
||||
[{"name": "test_func"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("agent_node")
|
||||
|
||||
# Verify
|
||||
# Should register transition function
|
||||
engine.llm.register_function.assert_called_once()
|
||||
call_args = engine.llm.register_function.call_args
|
||||
assert call_args[0][0] == "analyze_complete"
|
||||
assert callable(call_args[0][1]) # Check it's a function
|
||||
assert call_args[1]["cancel_on_interruption"] is True
|
||||
|
||||
# Should update context and send frame
|
||||
engine._update_llm_context.assert_called_once()
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_with_interruption_control(self, engine, mock_workflow):
|
||||
"""Test: Agent node respects allow_interrupt flag."""
|
||||
# Setup
|
||||
no_interrupt_node = self.create_node(
|
||||
"no_interrupt",
|
||||
is_static=True,
|
||||
allow_interrupt=False,
|
||||
prompt="Please wait while I process...",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"no_interrupt": no_interrupt_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("no_interrupt")
|
||||
|
||||
# Verify current node is set (for STT mute callback)
|
||||
assert engine._current_node == no_interrupt_node
|
||||
assert engine._current_node.allow_interrupt is False
|
||||
|
||||
# ===== END NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_node_static(self, engine, mock_workflow):
|
||||
"""Test: Static end node plays final message and schedules end task."""
|
||||
# Setup
|
||||
end_node = self.create_node(
|
||||
"end_node",
|
||||
is_static=True,
|
||||
is_end=True,
|
||||
prompt="Thank you for calling. Goodbye!",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"end_node": end_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("end_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert frames[1].text == "Thank you for calling. Goodbye!"
|
||||
|
||||
# Should have pending end task
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Execute the pending transition
|
||||
await engine._pending_control_transition_after_context_push()
|
||||
|
||||
# Should have sent EndFrame via task.queue_frame
|
||||
# The second call should be the EndFrame (first was TTS frames)
|
||||
assert engine.task.queue_frame.call_count >= 1
|
||||
end_frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(end_frame, EndFrame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_node_with_extraction(self, engine, mock_workflow):
|
||||
"""Test: End node with variable extraction."""
|
||||
# Setup
|
||||
end_node = self.create_node(
|
||||
"end_node",
|
||||
is_end=True,
|
||||
is_static=False,
|
||||
extraction_enabled=True,
|
||||
extraction_variables=["user_name", "satisfaction_level"],
|
||||
extraction_prompt="Extract user name and satisfaction",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"end_node": end_node}
|
||||
|
||||
# Mock the extraction manager
|
||||
engine._variable_extraction_manager = Mock()
|
||||
engine._perform_variable_extraction_if_needed = AsyncMock()
|
||||
|
||||
# Mock context update and composition methods
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("end_node")
|
||||
|
||||
# Verify
|
||||
# Should trigger extraction
|
||||
engine._perform_variable_extraction_if_needed.assert_called_once_with(end_node)
|
||||
|
||||
# Should have pending end task
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# ===== CALLBACK INTEGRATION TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_stopped_speaking_during_response_wait(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: User stops speaking triggers transition during wait_for_response."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node", is_start=True, is_static=True, wait_for_user_response=True
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Set current node to start node
|
||||
engine._current_node = start_node
|
||||
engine._user_response_timeout_task = asyncio.create_task(asyncio.sleep(3))
|
||||
|
||||
# Create callback and execute
|
||||
callback = engine.create_user_stopped_speaking_callback()
|
||||
|
||||
# Mock set_node to avoid recursion
|
||||
with patch.object(engine, "set_node", new=AsyncMock()) as mock_set_node:
|
||||
await callback()
|
||||
|
||||
# Verify
|
||||
mock_set_node.assert_called_once_with("next_node")
|
||||
assert engine._queue_context_frame is False # Should be set to False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_push_callback_executes_pending_transitions(self, engine):
|
||||
"""Test: flush_pending_transitions executes deferred transitions."""
|
||||
# Setup pending transitions
|
||||
mock_generated_transition = AsyncMock()
|
||||
mock_control_transition = AsyncMock()
|
||||
|
||||
engine._pending_generated_transition_after_context_push = (
|
||||
mock_generated_transition
|
||||
)
|
||||
engine._pending_control_transition_after_context_push = mock_control_transition
|
||||
|
||||
# Execute
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Verify both transitions were executed
|
||||
mock_generated_transition.assert_called_once()
|
||||
mock_control_transition.assert_called_once()
|
||||
|
||||
# Verify they were cleared
|
||||
assert engine._pending_generated_transition_after_context_push is None
|
||||
assert engine._pending_control_transition_after_context_push is None
|
||||
|
||||
# ===== COMPLEX SCENARIO TESTS =====
|
||||
|
||||
|
||||
# Add helper for testing with real async behavior
|
||||
def ANY(cls=None):
|
||||
"""Helper for matching any argument in mock calls."""
|
||||
|
||||
class AnyMatcher:
|
||||
def __init__(self, cls):
|
||||
self.cls = cls
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.cls:
|
||||
return isinstance(other, self.cls)
|
||||
return True
|
||||
|
||||
return AnyMatcher(cls)
|
||||
Loading…
Add table
Add a link
Reference in a new issue