mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
chore: refactor and add tests (#130)
* chore: add tests for end call * Update pipecat module * fix: allow interruptions from deepgram flux * Add VadUserTurnStrategy * chore: add test for voicemail detection
This commit is contained in:
parent
2aedb839ff
commit
033fde8946
15 changed files with 2106 additions and 542 deletions
|
|
@ -10,16 +10,13 @@ from api.services.pipecat.in_memory_buffers import (
|
|||
InMemoryTranscriptBuffer,
|
||||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
|
||||
def register_event_handlers(
|
||||
|
|
@ -83,18 +80,17 @@ def register_event_handlers(
|
|||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(_transport, _participant):
|
||||
call_disposed = engine.is_call_disposed()
|
||||
|
||||
logger.debug(
|
||||
f"In on_client_disconnected callback handler. Call disposed: {call_disposed}"
|
||||
)
|
||||
engine.handle_client_disconnected()
|
||||
|
||||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
|
||||
# Only cancel the task if the call is not already disposed by the engine
|
||||
if not call_disposed:
|
||||
await task.cancel()
|
||||
# End the call
|
||||
await engine.end_call_with_reason(
|
||||
EndTaskReason.USER_HANGUP.value, abort_immediately=True
|
||||
)
|
||||
|
||||
@task.event_handler("on_pipeline_started")
|
||||
async def on_pipeline_started(_task: PipelineTask, _frame: Frame):
|
||||
|
|
@ -114,9 +110,6 @@ def register_event_handlers(
|
|||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
|
||||
call_disposition = await engine.get_call_disposition()
|
||||
logger.debug(f"call disposition in on_pipeline_finished: {call_disposition}")
|
||||
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
|
||||
# Add trace URL if available (must be done before conversation tracing ends)
|
||||
|
|
@ -129,13 +122,6 @@ def register_event_handlers(
|
|||
# also consider existing gathered context in workflow_run
|
||||
gathered_context = {**gathered_context, **workflow_run.gathered_context}
|
||||
|
||||
organization_id = await get_organization_id_from_workflow_run(workflow_run_id)
|
||||
mapped_call_disposition = await apply_disposition_mapping(
|
||||
call_disposition, organization_id
|
||||
)
|
||||
|
||||
gathered_context.update({"mapped_call_disposition": mapped_call_disposition})
|
||||
|
||||
# Set user_speech call tag
|
||||
if in_memory_transcript_buffer:
|
||||
call_tags = gathered_context.get("call_tags", [])
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ def build_pipeline(
|
|||
# Insert voicemail detector after STT if enabled
|
||||
# Note: We intentionally do NOT use voicemail_detector.gate() to allow TTS
|
||||
# frames to continue flowing during classification (non-blocking detection)
|
||||
|
||||
# Note: We must keep user_context_aggregator after voicemail_detector
|
||||
# or else, LLMContextFrames generated from user_context_aggregator will
|
||||
# start generating LLM Completion from Voicemail Classifier
|
||||
if voicemail_detector:
|
||||
logger.info("Adding native voicemail detector to pipeline")
|
||||
processors.append(voicemail_detector.detector())
|
||||
|
|
|
|||
|
|
@ -52,9 +52,11 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.turns.user_mute import MuteUntilFirstBotCompleteUserMuteStrategy
|
||||
from pipecat.turns.user_mute import (
|
||||
CallbackUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
)
|
||||
from pipecat.turns.user_start import (
|
||||
ExternalUserTurnStartStrategy,
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_start.vad_user_turn_start_strategy import (
|
||||
|
|
@ -547,7 +549,7 @@ async def _run_pipeline(
|
|||
|
||||
if is_deepgram_flux:
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy(), ExternalUserTurnStartStrategy()],
|
||||
start=[VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
else:
|
||||
|
|
@ -556,9 +558,16 @@ async def _run_pipeline(
|
|||
stop=[TranscriptionUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
# Create user mute strategies
|
||||
# - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state
|
||||
user_mute_strategies = [
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
user_mute_strategies=[MuteUntilFirstBotCompleteUserMuteStrategy()],
|
||||
user_mute_strategies=user_mute_strategies,
|
||||
user_idle_timeout=max_user_idle_timeout,
|
||||
)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
|
|
@ -606,7 +615,7 @@ async def _run_pipeline(
|
|||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.send_end_task_frame(
|
||||
await engine.end_call_with_reason(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
|
@ -19,6 +18,7 @@ from pipecat.utils.enums import EndTaskReason
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
|
@ -49,7 +49,6 @@ from api.services.workflow.tools.timezone import (
|
|||
get_current_time,
|
||||
get_time_tools,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
||||
from pipecat.utils.tracing.context_registry import get_current_turn_context
|
||||
|
||||
|
||||
|
|
@ -79,12 +78,10 @@ class PipecatEngine:
|
|||
self._workflow_run_id = workflow_run_id
|
||||
self._node_transition_callback = node_transition_callback
|
||||
self._initialized = False
|
||||
self._client_disconnected = False
|
||||
self._call_disposed = False
|
||||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
self._user_response_timeout_task: Optional[asyncio.Task] = None
|
||||
self._call_disposition: Optional[str] = None
|
||||
|
||||
# Stasis connection for immediate transfers
|
||||
self._stasis_connection: Optional["StasisRTPConnection"] = None
|
||||
|
|
@ -99,6 +96,9 @@ class PipecatEngine:
|
|||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_generation_reference_text: str = ""
|
||||
|
||||
# Controls whether user input should be muted
|
||||
self._mute_pipeline: bool = False
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
||||
|
|
@ -215,9 +215,14 @@ class PipecatEngine:
|
|||
This way, when we do set_node from within this function, and go for LLM completion with updated
|
||||
system prompts, the context is updated with function call result.
|
||||
"""
|
||||
# FIXME: There is a potential race condition, when we generate LLM Completion from UserContextAggregator
|
||||
# with FunctionCallResultFrame and we call end_call_with_reason where we queue EndFrame or CancelFrame.
|
||||
# If EndFrame reaches the LLM Processor before the ContextFrame, we might never run generation which
|
||||
# might be intended
|
||||
|
||||
# Queue EndFrame if we just transitioned to EndNode
|
||||
if self._current_node.is_end:
|
||||
await self.send_end_task_frame(
|
||||
await self.end_call_with_reason(
|
||||
EndTaskReason.USER_QUALIFIED.value
|
||||
)
|
||||
|
||||
|
|
@ -356,44 +361,52 @@ class PipecatEngine:
|
|||
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
self, node: Optional[Node], run_in_background: bool = True
|
||||
) -> None:
|
||||
"""Perform variable extraction if the previous node had extraction enabled."""
|
||||
if (
|
||||
previous_node
|
||||
and previous_node.extraction_enabled
|
||||
and previous_node.extraction_variables
|
||||
):
|
||||
"""Perform variable extraction if the node has extraction enabled.
|
||||
|
||||
Args:
|
||||
node: The node to extract variables from.
|
||||
run_in_background: If True, runs extraction as a fire-and-forget task.
|
||||
If False, awaits the extraction synchronously.
|
||||
"""
|
||||
if not (node and node.extraction_enabled and node.extraction_variables):
|
||||
return
|
||||
|
||||
# Capture the current turn context for otel tracing
|
||||
# before creating the background task
|
||||
parent_context = get_current_turn_context()
|
||||
|
||||
extraction_prompt = self._format_prompt(node.extraction_prompt)
|
||||
extraction_variables = node.extraction_variables
|
||||
|
||||
async def _do_extraction():
|
||||
try:
|
||||
extracted_data = (
|
||||
await self._variable_extraction_manager._perform_extraction(
|
||||
extraction_variables, parent_context, extraction_prompt
|
||||
)
|
||||
)
|
||||
self._gathered_context.update(extracted_data)
|
||||
logger.debug(
|
||||
f"Variable extraction completed. Extracted: {extracted_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during variable extraction: {str(e)}")
|
||||
|
||||
if run_in_background:
|
||||
logger.debug(
|
||||
f"Scheduling background variable extraction for node: {previous_node.name}"
|
||||
f"Scheduling background variable extraction for node: {node.name}"
|
||||
)
|
||||
asyncio.create_task(_do_extraction())
|
||||
else:
|
||||
logger.debug(
|
||||
f"Performing synchronous variable extraction for node: {node.name}"
|
||||
)
|
||||
await _do_extraction()
|
||||
|
||||
# Capture the current turn context before creating the background task
|
||||
parent_context = get_current_turn_context()
|
||||
extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
|
||||
extraction_variables = previous_node.extraction_variables
|
||||
|
||||
async def _background_extraction():
|
||||
try:
|
||||
extracted_data = (
|
||||
await self._variable_extraction_manager._perform_extraction(
|
||||
extraction_variables, parent_context, extraction_prompt
|
||||
)
|
||||
)
|
||||
self._gathered_context.update(extracted_data)
|
||||
logger.debug(
|
||||
f"Background variable extraction completed. Extracted: {extracted_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during background variable extraction: {str(e)}"
|
||||
)
|
||||
|
||||
# Fire and forget - extraction happens in background without blocking
|
||||
asyncio.create_task(_background_extraction())
|
||||
|
||||
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
|
||||
"""Common method to set up LLM context and queue context frame for non-static nodes."""
|
||||
async def _setup_llm_context(self, node: Node) -> None:
|
||||
"""Common method to set up LLM context"""
|
||||
# Set node name for tracing
|
||||
try:
|
||||
self.context.set_node_name(node.name)
|
||||
|
|
@ -470,61 +483,54 @@ class PipecatEngine:
|
|||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Start generation for non-static start node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def _handle_end_node(self, node: Node) -> None:
|
||||
"""Handle end node execution."""
|
||||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
# If this end node has extraction enabled, perform extraction immediately
|
||||
if node.extraction_enabled and node.extraction_variables:
|
||||
await self._perform_variable_extraction_if_needed(node)
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def _handle_agent_node(self, node: Node) -> None:
|
||||
"""Handle agent node execution."""
|
||||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Set context and functions for non-static agent node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def send_end_task_frame(
|
||||
async def end_call_with_reason(
|
||||
self,
|
||||
reason: str,
|
||||
abort_immediately: bool = False,
|
||||
):
|
||||
"""
|
||||
Centralized method to send EndTaskFrame with metadata including
|
||||
call_transfer_context and call_context_vars
|
||||
Centralized method to end the call with disposition mapping
|
||||
"""
|
||||
if self._call_disposed or self._client_disconnected:
|
||||
# Call is already disposed and client disconnected
|
||||
logger.debug(
|
||||
f"Not sending EndFrame since call is already disposed: Call Disposed: {self._call_disposed} Client Disconnected: {self._client_disconnected}"
|
||||
)
|
||||
if self._call_disposed:
|
||||
logger.debug(f"Call already Disposed: {self._call_disposed}")
|
||||
return
|
||||
|
||||
self._call_disposed = True
|
||||
|
||||
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
|
||||
# Mute the pipeline
|
||||
self._mute_pipeline = True
|
||||
|
||||
# Customer disposition code using their mapping
|
||||
mapped_disposition = ""
|
||||
# Perform final variable extraction synchronously before ending
|
||||
await self._perform_variable_extraction_if_needed(
|
||||
self._current_node, run_in_background=False
|
||||
)
|
||||
|
||||
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
|
||||
|
||||
# Apply disposition mapping - first try call_disposition if it is,
|
||||
# extracted from the call conversation then fall back to reason
|
||||
call_disposition = self._gathered_context.get("call_disposition", "")
|
||||
organization_id = await self._get_organization_id()
|
||||
|
||||
# If client is disconnected before we get a chance to disconnect from
|
||||
# the bot, lets consider that as final disposition
|
||||
if self._client_disconnected:
|
||||
call_disposition = EndTaskReason.USER_HANGUP.value
|
||||
|
||||
if call_disposition:
|
||||
# If call_disposition exists, map it
|
||||
mapped_disposition = await apply_disposition_mapping(
|
||||
|
|
@ -532,90 +538,16 @@ class PipecatEngine:
|
|||
)
|
||||
# Store the original and mapped values
|
||||
self._gathered_context["extracted_call_disposition"] = call_disposition
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
self._gathered_context["call_disposition"] = call_disposition
|
||||
self._gathered_context["mapped_call_disposition"] = mapped_disposition
|
||||
else:
|
||||
# Otherwise, map the disconnect reason
|
||||
mapped_disposition = await apply_disposition_mapping(
|
||||
reason, organization_id
|
||||
)
|
||||
# Store the mapped disconnect reason
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
|
||||
# TODO: Generalise this
|
||||
self._gathered_context["address"] = ", ".join(
|
||||
[
|
||||
self._call_context_vars.get("address1", ""),
|
||||
self._call_context_vars.get("address2", ""),
|
||||
self._call_context_vars.get("address3", ""),
|
||||
self._call_context_vars.get("city", ""),
|
||||
self._call_context_vars.get("state", ""),
|
||||
self._call_context_vars.get("province", ""),
|
||||
self._call_context_vars.get("postal_code", ""),
|
||||
]
|
||||
)
|
||||
self._gathered_context["full_name"] = " ".join(
|
||||
[
|
||||
self._call_context_vars.get("first_name", ""),
|
||||
self._call_context_vars.get("middle_initial", ""),
|
||||
self._call_context_vars.get("last_name", ""),
|
||||
]
|
||||
)
|
||||
self._gathered_context["agent_name"] = "Alex"
|
||||
self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
|
||||
"phone", ""
|
||||
)
|
||||
self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
|
||||
self._gathered_context["vendor_id"] = self._call_context_vars.get(
|
||||
"vendor_lead_code", ""
|
||||
)
|
||||
|
||||
decision_maker = self._gathered_context.get("primary_cardholder", False)
|
||||
employment_status = self._gathered_context.get("employment_status", "N/A")
|
||||
call_transfer_context = {
|
||||
"first_name": self._call_context_vars.get("first_name", ""),
|
||||
"full_name": self._gathered_context.get("full_name", ""),
|
||||
"phone": self._call_context_vars.get("phone", ""),
|
||||
"lead_id": self._call_context_vars.get("lead_id"),
|
||||
"disposition": mapped_disposition,
|
||||
"agent_name": self._gathered_context.get("agent_name", "Alex"),
|
||||
"decision_maker": str(decision_maker),
|
||||
"employment": employment_status.title() if employment_status else "N/A",
|
||||
"debts": self._gathered_context.get("total_debt", "N/A"),
|
||||
"number_of_credit_cards": self._gathered_context.get(
|
||||
"number_of_credit_cards", "N/A"
|
||||
),
|
||||
"time": self._gathered_context.get("time"),
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
|
||||
)
|
||||
|
||||
# Initiate immediate transfer for Stasis connections when user is qualified
|
||||
if (
|
||||
reason == EndTaskReason.USER_QUALIFIED.value
|
||||
and self._stasis_connection is not None
|
||||
and not abort_immediately
|
||||
):
|
||||
try:
|
||||
logger.info(
|
||||
f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
|
||||
)
|
||||
await self._stasis_connection.transfer(call_transfer_context)
|
||||
logger.info("Immediate transfer initiated successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate immediate transfer: {e}")
|
||||
# Continue with normal flow even if immediate transfer fails
|
||||
|
||||
if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
|
||||
await self.task.queue_frame(
|
||||
TTSSpeakFrame(
|
||||
"Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
|
||||
)
|
||||
)
|
||||
|
||||
# Store the original reason for later retrieval in event handler
|
||||
self._call_disposition = mapped_disposition
|
||||
self._gathered_context["call_disposition"] = reason
|
||||
self._gathered_context["mapped_call_disposition"] = mapped_disposition
|
||||
|
||||
logger.debug(
|
||||
f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
|
||||
|
|
@ -678,11 +610,14 @@ class PipecatEngine:
|
|||
|
||||
return system_message, functions
|
||||
|
||||
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
|
||||
async def should_mute_user(self, frame: "Frame") -> bool:
|
||||
"""
|
||||
This callback is called by STTMuteFilter to determine if the STT should be muted.
|
||||
Callback for CallbackUserMuteStrategy to determine if the user should be muted.
|
||||
|
||||
Returns:
|
||||
True if the user should be muted, False otherwise.
|
||||
"""
|
||||
return engine_callbacks.create_should_mute_callback(self)
|
||||
return self._mute_pipeline
|
||||
|
||||
def create_user_idle_handler(self):
|
||||
"""
|
||||
|
|
@ -746,26 +681,10 @@ class PipecatEngine:
|
|||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_generation_reference_text += text
|
||||
|
||||
def handle_client_disconnected(self):
|
||||
"""Handle client disconnected event."""
|
||||
self._client_disconnected = True
|
||||
|
||||
def is_call_disposed(self):
|
||||
"""Check whether a call has been disposed by the engine"""
|
||||
return self._call_disposed
|
||||
|
||||
async def get_call_disposition(self) -> Optional[str]:
|
||||
"""Get the disconnect reason set by the engine."""
|
||||
if self._call_disposition:
|
||||
# We would have a _call_disposition variable set if we have initiated
|
||||
# a disconnect from the bot, i.e we have called send_end_task_frame.
|
||||
return self._call_disposition
|
||||
|
||||
if self._client_disconnected:
|
||||
return EndTaskReason.USER_HANGUP.value
|
||||
else:
|
||||
return EndTaskReason.UNKNOWN.value
|
||||
|
||||
async def get_gathered_context(self) -> dict:
|
||||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
|
|
|||
|
|
@ -11,46 +11,19 @@ unit-testing.
|
|||
"""
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
LLMMessagesAppendFrame,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# STT mute handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_should_mute_callback(
|
||||
engine: "PipecatEngine",
|
||||
) -> Callable[[STTMuteFilter], Awaitable[bool]]:
|
||||
"""Return a callback indicating whether STT should be muted.
|
||||
|
||||
STT is muted when *interruptions are **not*** allowed on the current node.
|
||||
"""
|
||||
|
||||
async def callback(_: STTMuteFilter) -> bool: # noqa: D401
|
||||
if engine._current_node is None:
|
||||
# Default to not muting if we have no active node yet.
|
||||
return False
|
||||
|
||||
logger.debug(
|
||||
f"STT mute callback: allow_interrupt={engine._current_node.allow_interrupt}"
|
||||
)
|
||||
return not engine._current_node.allow_interrupt
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-idle handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -85,7 +58,7 @@ class UserIdleHandler:
|
|||
"content": "The user has been quiet. We will be disconnecting the call now. Wish them a good day in the language that the user has been speaking so far.",
|
||||
}
|
||||
await aggregator.push_frame(LLMMessagesAppendFrame([message], run_llm=True))
|
||||
await self._engine.send_end_task_frame(
|
||||
await self._engine.end_call_with_reason(
|
||||
EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value
|
||||
)
|
||||
|
||||
|
|
@ -105,7 +78,7 @@ def create_max_duration_callback(engine: "PipecatEngine"):
|
|||
|
||||
async def handle_max_duration():
|
||||
logger.debug("Max call duration exceeded. Terminating call")
|
||||
await engine.send_end_task_frame(EndTaskReason.CALL_DURATION_EXCEEDED.value)
|
||||
await engine.end_call_with_reason(EndTaskReason.CALL_DURATION_EXCEEDED.value)
|
||||
|
||||
return handle_max_duration
|
||||
|
||||
|
|
|
|||
|
|
@ -23,15 +23,12 @@ from api.services.workflow.tools.custom_tool import (
|
|||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.frames.frames import FunctionCallResultProperties, TTSSpeakFrame
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
# End task reason for end call tool
|
||||
END_CALL_TOOL_REASON = "end_call_tool"
|
||||
|
||||
|
||||
class CustomToolManager:
|
||||
"""Manager for custom tool registration and execution.
|
||||
|
||||
|
|
@ -214,21 +211,22 @@ class CustomToolManager:
|
|||
logger.info(f"Playing custom goodbye message: {custom_message}")
|
||||
await self._engine.task.queue_frame(TTSSpeakFrame(custom_message))
|
||||
# End the call after the message (not immediately)
|
||||
await self._engine.send_end_task_frame(
|
||||
END_CALL_TOOL_REASON, abort_immediately=False
|
||||
await self._engine.end_call_with_reason(
|
||||
EndTaskReason.END_CALL_TOOL_REASON.value,
|
||||
abort_immediately=False,
|
||||
)
|
||||
else:
|
||||
# No message - end call immediately
|
||||
logger.info("Ending call immediately (no goodbye message)")
|
||||
await self._engine.send_end_task_frame(
|
||||
END_CALL_TOOL_REASON, abort_immediately=True
|
||||
await self._engine.end_call_with_reason(
|
||||
EndTaskReason.END_CALL_TOOL_REASON.value, abort_immediately=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"End call tool '{function_name}' execution failed: {e}")
|
||||
# Still try to end the call even if there's an error
|
||||
await self._engine.send_end_task_frame(
|
||||
END_CALL_TOOL_REASON, abort_immediately=True
|
||||
await self._engine.end_call_with_reason(
|
||||
EndTaskReason.UNEXPECTED_ERROR.value, abort_immediately=True
|
||||
)
|
||||
|
||||
return end_call_handler
|
||||
|
|
|
|||
|
|
@ -6,27 +6,20 @@ import pytest
|
|||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
START_CALL_SYSTEM_PROMPT = "start_call_system_prompt"
|
||||
END_CALL_SYSTEM_PROMPT = "end_call_system_prompt"
|
||||
START_CALL_SYSTEM_PROMPT = "Start Call System Prompt"
|
||||
AGENT_SYSTEM_PROMPT = "Agent Node System Prompt"
|
||||
END_CALL_SYSTEM_PROMPT = "End Call System Prompt"
|
||||
|
||||
# Default workflow definition for mocking database WorkflowModel
|
||||
DEFAULT_WORKFLOW_DEFINITION = {
|
||||
|
|
@ -110,57 +103,6 @@ class MockUserConfig:
|
|||
embeddings: Optional[Any] = None
|
||||
|
||||
|
||||
class MockTransportProcessor(FrameProcessor):
|
||||
"""
|
||||
Mocks the transport behavior by emitting Bot speaking frames
|
||||
when it encounters TTS frames.
|
||||
|
||||
This simulates what a real transport would do when the bot is speaking:
|
||||
- TTSStartedFrame -> BotStartedSpeakingFrame
|
||||
- TTSAudioRawFrame -> BotSpeakingFrame
|
||||
- TTSStoppedFrame -> BotStoppedSpeakingFrame
|
||||
|
||||
Args:
|
||||
emit_bot_speaking: If True, also emits BotSpeakingFrame on TTSAudioRawFrame
|
||||
which is needed for user idle tracking to start conversation tracking. Default True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
emit_bot_speaking: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._emit_bot_speaking = emit_bot_speaking
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
# Emit BotStartedSpeakingFrame to indicate bot started speaking
|
||||
await self.push_frame(BotStartedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
elif isinstance(frame, TTSAudioRawFrame):
|
||||
# Emit BotSpeakingFrame - this is what triggers user idle tracking
|
||||
# to start conversation tracking
|
||||
if self._emit_bot_speaking:
|
||||
await self.push_frame(BotSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
# Emit BotStoppedSpeakingFrame to indicate bot stopped speaking
|
||||
await self.push_frame(BotStoppedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
|
@ -299,14 +241,14 @@ def simple_workflow() -> WorkflowGraph:
|
|||
"""Create a simple two-node workflow for testing.
|
||||
|
||||
The workflow has:
|
||||
- Start node with a prompt
|
||||
- Start node with extraction enabled (extracts user_intent)
|
||||
- End node with a prompt
|
||||
- One edge connecting them with label "End Call"
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="1",
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
|
|
@ -315,10 +257,19 @@ def simple_workflow() -> WorkflowGraph:
|
|||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract user information from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_intent",
|
||||
type=VariableType.string,
|
||||
prompt="The user's intent or reason for calling",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="2",
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
|
|
@ -327,14 +278,15 @@ def simple_workflow() -> WorkflowGraph:
|
|||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="1-2",
|
||||
source="1",
|
||||
target="2",
|
||||
id="start-end",
|
||||
source="start",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When the user says to end the call, end the call",
|
||||
|
|
@ -350,37 +302,59 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
"""Create a three-node workflow for testing with an intermediate agent node.
|
||||
|
||||
The workflow has:
|
||||
- Start node
|
||||
- Agent node (for collecting information)
|
||||
- End node
|
||||
- Start node with extraction enabled (extracts greeting_type)
|
||||
- Agent node with extraction enabled (extracts user_name)
|
||||
- End node (no extraction)
|
||||
|
||||
Edges:
|
||||
- Start -> Agent (label: "Collect Info")
|
||||
- Agent -> End (label: "End Call")
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="1",
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract greeting information from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="greeting_type",
|
||||
type=VariableType.string,
|
||||
prompt="The type of greeting used",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="2",
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt="Help the user with their request. Ask clarifying questions if needed.",
|
||||
allow_interrupt=True,
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract user details from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_name",
|
||||
type=VariableType.string,
|
||||
prompt="The user's name",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="3",
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
|
|
@ -389,26 +363,187 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="1-2",
|
||||
source="1",
|
||||
target="2",
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When the user wants help, collect their information",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="2-3",
|
||||
source="2",
|
||||
target="3",
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When the user is done or wants to end the call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
||||
"""Create a three-node workflow with extraction enabled ONLY on start node.
|
||||
|
||||
This fixture is specifically for testing that variable extraction is triggered
|
||||
for the correct node during transitions. The agent node has extraction disabled
|
||||
to verify extraction happens for the SOURCE node, not the TARGET node.
|
||||
|
||||
The workflow has:
|
||||
- Start node with extraction enabled (extracts user_name)
|
||||
- Agent node with extraction DISABLED
|
||||
- End node (no extraction)
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract the user's name from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_name",
|
||||
type=VariableType.string,
|
||||
prompt="The name the user provided",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
||||
"""Create a three-node workflow without variable extraction
|
||||
|
||||
The workflow has:
|
||||
- Start node with extraction DISABLED
|
||||
- Agent node with extraction DISABLED
|
||||
- End node (no extraction)
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
|
|
|
|||
|
|
@ -19,18 +19,13 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from api.tests.conftest import (
|
||||
AGENT_SYSTEM_PROMPT,
|
||||
END_CALL_SYSTEM_PROMPT,
|
||||
START_CALL_SYSTEM_PROMPT,
|
||||
)
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -41,86 +36,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
# Define prompts for test nodes
|
||||
START_NODE_PROMPT = "Start Node System Prompt"
|
||||
AGENT_NODE_PROMPT = "Agent Node System Prompt"
|
||||
END_NODE_PROMPT = "End Node System Prompt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_for_context_test() -> WorkflowGraph:
|
||||
"""Create a three-node workflow for testing context updates during transitions.
|
||||
|
||||
The workflow has:
|
||||
- Start node with prompt to greet user
|
||||
- Agent node with prompt to collect information
|
||||
- End node with prompt to say goodbye
|
||||
|
||||
Edges:
|
||||
- Start -> Agent (label: "Collect Info")
|
||||
- Agent -> End (label: "End Call")
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_NODE_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_NODE_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_NODE_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
class ContextCapturingMockLLM(MockLLMService):
|
||||
|
|
@ -215,7 +131,8 @@ async def run_pipeline_and_capture_context(
|
|||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -251,15 +168,13 @@ async def run_pipeline_and_capture_context(
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
|
||||
)
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -294,7 +209,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_transition_updates_context_before_next_completion(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that a single transition function call updates context before next LLM generation.
|
||||
|
||||
|
|
@ -329,7 +244,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
set_node_delay=0.05, # Introduce 50ms delay in set_node
|
||||
)
|
||||
|
|
@ -341,17 +256,17 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
# Verify step 0 (start node) had start node's system prompt
|
||||
step_0_prompt = llm.get_system_prompt_at_step(0)
|
||||
assert START_NODE_PROMPT in step_0_prompt, (
|
||||
assert START_CALL_SYSTEM_PROMPT in step_0_prompt, (
|
||||
f"Step 0 should have start node prompt, got: {step_0_prompt[:100]}"
|
||||
)
|
||||
|
||||
# Verify step 1 (agent node) had:
|
||||
# 1. The agent node's system prompt (not start node's)
|
||||
step_1_prompt = llm.get_system_prompt_at_step(1)
|
||||
assert AGENT_NODE_PROMPT in step_1_prompt, (
|
||||
assert AGENT_SYSTEM_PROMPT in step_1_prompt, (
|
||||
f"Step 1 should have agent node prompt, got: {step_1_prompt[:100]}"
|
||||
)
|
||||
assert START_NODE_PROMPT not in step_1_prompt, (
|
||||
assert START_CALL_SYSTEM_PROMPT not in step_1_prompt, (
|
||||
"Step 1 should NOT have start node prompt anymore"
|
||||
)
|
||||
|
||||
|
|
@ -371,7 +286,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_transitions_maintain_correct_context(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that sequential transitions maintain correct context at each step.
|
||||
|
||||
|
|
@ -403,7 +318,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
set_node_delay=0.05, # Introduce 50ms delay in set_node
|
||||
)
|
||||
|
|
@ -414,13 +329,13 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
)
|
||||
|
||||
# Step 0: Start node - should have start prompt
|
||||
assert START_NODE_PROMPT in llm.get_system_prompt_at_step(0)
|
||||
assert START_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(0)
|
||||
|
||||
# Step 1: Agent node - should have agent prompt
|
||||
assert AGENT_NODE_PROMPT in llm.get_system_prompt_at_step(1)
|
||||
assert AGENT_SYSTEM_PROMPT in llm.get_system_prompt_at_step(1)
|
||||
|
||||
# Step 2: End node - should have end prompt
|
||||
assert END_NODE_PROMPT in llm.get_system_prompt_at_step(2)
|
||||
assert END_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(2)
|
||||
|
||||
# Verify each subsequent step has the previous tool results
|
||||
step_1_ctx = llm.get_context_at_step(1)
|
||||
|
|
@ -445,7 +360,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_messages_preserve_conversation_history(
|
||||
self, three_node_workflow_for_context_test: WorkflowGraph
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test that conversation history is preserved across node transitions.
|
||||
|
||||
|
|
@ -474,7 +389,7 @@ class TestContextUpdateBeforeNextCompletion:
|
|||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
|
||||
llm, _ = await run_pipeline_and_capture_context(
|
||||
workflow=three_node_workflow_for_context_test,
|
||||
workflow=three_node_workflow,
|
||||
mock_steps=mock_steps,
|
||||
)
|
||||
|
||||
|
|
|
|||
1097
api/tests/test_pipecat_engine_end_call.py
Normal file
1097
api/tests/test_pipecat_engine_end_call.py
Normal file
File diff suppressed because it is too large
Load diff
393
api/tests/test_pipecat_engine_node_switch_with_user_speech.py
Normal file
393
api/tests/test_pipecat_engine_node_switch_with_user_speech.py
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
"""Tests for verifying behavior when node switch and user speech happen simultaneously.
|
||||
|
||||
This module tests the interaction between node transitions and user speaking events
|
||||
in the PipecatEngine. The key scenario being tested:
|
||||
|
||||
1. LLM calls a transition function to move from one node to another
|
||||
2. At the same time, user starts and stops speaking (triggered by FunctionCallResultFrame)
|
||||
3. The pipeline should handle both events correctly
|
||||
|
||||
The tests use a custom input transport that injects UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame when triggered by a FunctionCallResultFrame observer.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
LLMContextFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.turns.user_mute import (
|
||||
CallbackUserMuteStrategy,
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
)
|
||||
from pipecat.turns.user_start import (
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import (
|
||||
ExternalUserTurnStopStrategy,
|
||||
)
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class UserSpeechInjectingInputTransport(FrameProcessor):
|
||||
"""Mock input transport that injects user speaking frames on FunctionCallResultFrame.
|
||||
|
||||
This transport generates audio frames and automatically injects UserStartedSpeakingFrame
|
||||
and UserStoppedSpeakingFrame when it sees the first FunctionCallResultFrame flowing
|
||||
upstream through the pipeline.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Optional[TransportParams] = None,
|
||||
*,
|
||||
generate_audio: bool = False,
|
||||
audio_interval_ms: int = 20,
|
||||
sample_rate: int = 16000,
|
||||
num_channels: int = 1,
|
||||
user_speech_initial_delay: float = 0.01,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._params = params or TransportParams()
|
||||
self._generate_audio = generate_audio
|
||||
self._audio_interval_ms = audio_interval_ms
|
||||
self._sample_rate = sample_rate
|
||||
self._num_channels = num_channels
|
||||
self._user_speech_initial_delay = user_speech_initial_delay
|
||||
self._audio_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
self._function_call_result_count = 0
|
||||
|
||||
async def _generate_audio_frames(self):
|
||||
"""Generate audio frames at regular intervals."""
|
||||
samples_per_frame = int(self._sample_rate * self._audio_interval_ms / 1000)
|
||||
bytes_per_frame = samples_per_frame * self._num_channels * 2
|
||||
silence_audio = bytes(bytes_per_frame)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
frame = InputAudioRawFrame(
|
||||
audio=silence_audio,
|
||||
sample_rate=self._sample_rate,
|
||||
num_channels=self._num_channels,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await asyncio.sleep(self._audio_interval_ms / 1000)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def _start_tasks(self):
|
||||
"""Start audio generation task."""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
if self._generate_audio:
|
||||
self._audio_task = asyncio.create_task(self._generate_audio_frames())
|
||||
|
||||
def _stop_tasks(self):
|
||||
"""Stop all background tasks."""
|
||||
self._running = False
|
||||
if self._audio_task and not self._audio_task.done():
|
||||
self._audio_task.cancel()
|
||||
self._audio_task = None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
self._start_tasks()
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
self._stop_tasks()
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
# When we see FunctionCallResultFrame #1 flowing upstream,
|
||||
# inject user speaking frames downstream
|
||||
self._function_call_result_count += 1
|
||||
if self._function_call_result_count == 1:
|
||||
# Simulate first race condition to generate
|
||||
# LLM call close enough to the LLM call from
|
||||
# function call
|
||||
await asyncio.sleep(self._user_speech_initial_delay)
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await self.push_frame(
|
||||
TranscriptionFrame("First User Speech", "abc", time_now_iso8601())
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
# Generate second llm call
|
||||
await asyncio.sleep(0.1)
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await self.push_frame(
|
||||
TranscriptionFrame("Second User Speech", "abc", time_now_iso8601())
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def cleanup(self):
|
||||
self._stop_tasks()
|
||||
await super().cleanup()
|
||||
|
||||
|
||||
class UserSpeechInjectingTransport(BaseTransport):
|
||||
"""Transport that injects user speaking frames on first FunctionCallResultFrame."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Optional[TransportParams] = None,
|
||||
*,
|
||||
input_name: Optional[str] = None,
|
||||
output_name: Optional[str] = None,
|
||||
emit_bot_speaking: bool = True,
|
||||
generate_audio: bool = False,
|
||||
audio_interval_ms: int = 20,
|
||||
audio_sample_rate: int = 16000,
|
||||
audio_num_channels: int = 1,
|
||||
user_speech_initial_delay: float = 0.01,
|
||||
):
|
||||
super().__init__(input_name=input_name, output_name=output_name)
|
||||
self._params = params or TransportParams()
|
||||
self._input = UserSpeechInjectingInputTransport(
|
||||
self._params,
|
||||
name=self._input_name,
|
||||
generate_audio=generate_audio,
|
||||
audio_interval_ms=audio_interval_ms,
|
||||
sample_rate=audio_sample_rate,
|
||||
num_channels=audio_num_channels,
|
||||
user_speech_initial_delay=user_speech_initial_delay,
|
||||
)
|
||||
self._output = MockOutputTransport(
|
||||
self._params,
|
||||
emit_bot_speaking=emit_bot_speaking,
|
||||
name=self._output_name,
|
||||
)
|
||||
|
||||
def input(self) -> UserSpeechInjectingInputTransport:
|
||||
return self._input
|
||||
|
||||
def output(self) -> FrameProcessor:
|
||||
return self._output
|
||||
|
||||
|
||||
async def create_test_pipeline(
|
||||
workflow: WorkflowGraph,
|
||||
mock_llm: MockLLMService,
|
||||
generate_audio: bool = True,
|
||||
user_speech_initial_delay: float = 0.01,
|
||||
) -> tuple[PipecatEngine, UserSpeechInjectingTransport, PipelineTask]:
|
||||
"""Create a PipecatEngine with full pipeline for testing node switch scenarios.
|
||||
|
||||
The transport's input automatically injects UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame when it sees the first FunctionCallResultFrame
|
||||
flowing upstream through the pipeline.
|
||||
|
||||
Args:
|
||||
workflow: The workflow graph to use.
|
||||
mock_llm: The mock LLM service.
|
||||
generate_audio: If True, the mock transport generates InputAudioRawFrame
|
||||
every 20ms to simulate real audio input.
|
||||
user_speech_initial_delay: Delay in seconds before injecting
|
||||
UserStartedSpeakingFrame after seeing FunctionCallResultFrame.
|
||||
|
||||
Returns:
|
||||
Tuple of (engine, transport, task)
|
||||
"""
|
||||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
# Create custom transport that injects user speaking frames on FunctionCallResultFrame #1
|
||||
transport = UserSpeechInjectingTransport(
|
||||
generate_audio=generate_audio,
|
||||
audio_interval_ms=20,
|
||||
audio_sample_rate=16000,
|
||||
audio_num_channels=1,
|
||||
emit_bot_speaking=True,
|
||||
user_speech_initial_delay=user_speech_initial_delay,
|
||||
)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
||||
# Create PipecatEngine with the workflow
|
||||
engine = PipecatEngine(
|
||||
llm=mock_llm,
|
||||
context=context,
|
||||
workflow=workflow,
|
||||
call_context_vars={"customer_name": "Test User"},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Create user turn strategies matching run_pipeline.py
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[TranscriptionUserTurnStartStrategy()],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
# Create user mute strategies matching run_pipeline.py
|
||||
user_mute_strategies = [
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
user_mute_strategies=user_mute_strategies,
|
||||
)
|
||||
|
||||
# Create context aggregator with user and assistant params
|
||||
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
|
||||
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params, user_params=user_params
|
||||
)
|
||||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Create the pipeline:
|
||||
# transport.input() -> user_aggregator -> LLM -> TTS -> transport.output() -> assistant_aggregator
|
||||
# The transport input watches for FunctionCallResultFrame flowing upstream
|
||||
# and injects user speaking frames when it sees the first one
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
user_context_aggregator,
|
||||
mock_llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
return engine, transport, task
|
||||
|
||||
|
||||
class TestNodeSwitchWithUserSpeech:
|
||||
"""Test scenarios where node switch and user speech happen simultaneously."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"user_speech_initial_delay,scenario_name",
|
||||
[
|
||||
(0.01, "delayed"),
|
||||
(0, "immediate"),
|
||||
],
|
||||
ids=["delayed_user_speech", "immediate_user_speech"],
|
||||
)
|
||||
async def test_node_switch_with_concurrent_user_speech(
|
||||
self,
|
||||
three_node_workflow_no_variable_extraction: WorkflowGraph,
|
||||
user_speech_initial_delay: float,
|
||||
scenario_name: str,
|
||||
):
|
||||
"""Test scenario: node transition happens while user is speaking.
|
||||
|
||||
This test creates the scenario where:
|
||||
1. LLM generates text and calls collect_info to transition from start to agent
|
||||
2. When FunctionCallResultFrame #1 is seen, UserStartedSpeakingFrame and
|
||||
UserStoppedSpeakingFrame are automatically injected from the pipeline source
|
||||
3. The pipeline processes both events concurrently
|
||||
|
||||
The FunctionCallResultObserver in the pipeline detects the first function call
|
||||
result and triggers the transport to inject user speaking frames.
|
||||
|
||||
This test is parameterized with two scenarios:
|
||||
- delayed_user_speech: 10ms delay before UserStartedSpeakingFrame (user_speech_initial_delay=0.01)
|
||||
- immediate_user_speech: No delay before UserStartedSpeakingFrame (user_speech_initial_delay=0)
|
||||
|
||||
This is a scenario creation test - no specific assertions yet.
|
||||
"""
|
||||
# Step 0 (Start node): greet user then call collect_info to transition to agent
|
||||
step_0_chunks = MockLLMService.create_mixed_chunks(
|
||||
text="Hello!",
|
||||
function_name="collect_info",
|
||||
arguments={},
|
||||
tool_call_id="call_transition_1",
|
||||
)
|
||||
|
||||
step_1_chunks = MockLLMService.create_text_chunks(
|
||||
text="Step 1 with some longer text that should cause multiple chunks to be created."
|
||||
)
|
||||
|
||||
step_2_chunks = MockLLMService.create_function_call_chunks(
|
||||
function_name="end_call",
|
||||
arguments={},
|
||||
tool_call_id="call_transition_2",
|
||||
)
|
||||
|
||||
mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks]
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
|
||||
engine, _transport, task = await create_test_pipeline(
|
||||
three_node_workflow_no_variable_extraction,
|
||||
llm,
|
||||
user_speech_initial_delay=user_speech_initial_delay,
|
||||
)
|
||||
|
||||
# Patch DB calls
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=1,
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
# Start the LLM generation - user speech will be injected
|
||||
# automatically when FunctionCallResultFrame #1 is seen
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
await asyncio.gather(run_pipeline(), initialize_engine())
|
||||
|
||||
# Total 4 generations out of which 1 was cancelled due to interruption
|
||||
assert llm.get_current_step() == 4
|
||||
|
|
@ -12,7 +12,7 @@ import pytest
|
|||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT, MockTransportProcessor
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -23,6 +23,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
async def run_pipeline_with_tool_calls(
|
||||
|
|
@ -65,7 +66,8 @@ async def run_pipeline_with_tool_calls(
|
|||
# Create MockTTSService to generate TTS frames
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -91,15 +93,13 @@ async def run_pipeline_with_tool_calls(
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create a real pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
|
||||
)
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,23 +17,11 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -44,96 +32,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
# Define prompts for test nodes
|
||||
START_NODE_PROMPT = "Start Node System Prompt"
|
||||
AGENT_NODE_PROMPT = "Agent Node System Prompt"
|
||||
END_NODE_PROMPT = "End Node System Prompt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow_with_extraction_on_start() -> WorkflowGraph:
|
||||
"""Create a three-node workflow where only the start node has extraction enabled.
|
||||
|
||||
The workflow has:
|
||||
- Start node with extraction_enabled=True and extraction_variables set
|
||||
- Agent node with extraction_enabled=False (default)
|
||||
- End node with extraction_enabled=False (default)
|
||||
|
||||
This is used to test that variable extraction is triggered for the correct node
|
||||
during transitions.
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_NODE_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=True,
|
||||
extraction_prompt="Extract the user's name from the conversation.",
|
||||
extraction_variables=[
|
||||
ExtractionVariableDTO(
|
||||
name="user_name",
|
||||
type=VariableType.string,
|
||||
prompt="The name the user provided",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_NODE_PROMPT,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_NODE_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
extraction_enabled=False, # Explicitly disabled
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="start-agent",
|
||||
source="start",
|
||||
target="agent",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When user has been greeted, proceed to collect information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="agent-end",
|
||||
source="agent",
|
||||
target="end",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When information collection is complete, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
class TestVariableExtractionDuringTransitions:
|
||||
|
|
@ -141,7 +40,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extraction_called_for_source_node_not_target_node(
|
||||
self, three_node_workflow_with_extraction_on_start: WorkflowGraph
|
||||
self, three_node_workflow_extraction_start_only: WorkflowGraph
|
||||
):
|
||||
"""Test that when transitioning from START to AGENT, extraction is called for START node.
|
||||
|
||||
|
|
@ -183,7 +82,8 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Create MockTTSService
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -195,7 +95,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
)
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
workflow = three_node_workflow_with_extraction_on_start
|
||||
workflow = three_node_workflow_extraction_start_only
|
||||
|
||||
# Create PipecatEngine with the workflow
|
||||
engine = PipecatEngine(
|
||||
|
|
@ -209,7 +109,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Patch _perform_variable_extraction_if_needed to track calls
|
||||
original_perform_extraction = engine._perform_variable_extraction_if_needed
|
||||
|
||||
async def tracked_perform_extraction(node):
|
||||
async def tracked_perform_extraction(node, run_in_background=True):
|
||||
extraction_calls.append(
|
||||
{
|
||||
"node_id": node.id if node else None,
|
||||
|
|
@ -228,7 +128,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
|
@ -236,7 +136,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(allow_interruptions=False),
|
||||
params=PipelineParams(),
|
||||
enable_rtvi=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import pytest
|
|||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -26,6 +25,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
from pipecat.tests.mock_transport import MockTransport
|
||||
|
||||
|
||||
async def run_pipeline_with_user_idle(
|
||||
|
|
@ -52,14 +52,15 @@ async def run_pipeline_with_user_idle(
|
|||
if mock_steps is None:
|
||||
mock_steps = MockLLMService.create_multi_step_responses(
|
||||
MockLLMService.create_text_chunks("Hello, how can I help you today?"),
|
||||
num_text_steps=3, # Initial + 2 idle responses
|
||||
num_text_steps=4, # Initial + 2 idle responses + 1 variable extraction
|
||||
step_prefix="Response",
|
||||
)
|
||||
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
tts = MockTTSService(mock_audio_duration_ms=10)
|
||||
|
||||
mock_transport = MockTransportProcessor()
|
||||
# Create MockTransport for simulating transport behavior
|
||||
mock_transport = MockTransport(emit_bot_speaking=True)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -99,15 +100,13 @@ async def run_pipeline_with_user_idle(
|
|||
user_context_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
mock_transport,
|
||||
mock_transport.output(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
|
||||
)
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -145,8 +144,6 @@ async def run_pipeline_with_user_idle(
|
|||
async def wait_for_idle_to_trigger():
|
||||
# Wait long enough for idle timeouts to trigger
|
||||
await asyncio.sleep(total_wait_time)
|
||||
# Cancel the task if it's still running
|
||||
await task.cancel()
|
||||
|
||||
# Run all concurrently
|
||||
await asyncio.gather(
|
||||
|
|
|
|||
238
api/tests/test_voicemail_detector.py
Normal file
238
api/tests/test_voicemail_detector.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""Tests for understanding voicemail detector behavior with user aggregator and LLM.
|
||||
|
||||
This module tests the interaction between the voicemail detector, user aggregator,
|
||||
and LLM in a pipeline. It demonstrates how the voicemail detector classifies
|
||||
incoming speech as CONVERSATION or VOICEMAIL and how the main LLM responds.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from pipecat.extensions.voicemail.voicemail_detector import VoicemailDetector
|
||||
from pipecat.frames.frames import (
|
||||
EndTaskFrame,
|
||||
Frame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService
|
||||
from pipecat.turns.user_start import (
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
VADUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import (
|
||||
ExternalUserTurnStopStrategy,
|
||||
)
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class FrameInjector(FrameProcessor):
|
||||
"""Simple processor that can inject frames into the pipeline."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._frames_to_inject: list[Frame] = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def inject_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Inject a frame into the pipeline."""
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class FrameCounter:
|
||||
"""Helper to count specific frame types seen by a processor."""
|
||||
|
||||
def __init__(self):
|
||||
self.user_stopped_speaking_count = 0
|
||||
self.user_started_speaking_count = 0
|
||||
|
||||
def wrap_process_frame(self, original_process_frame):
|
||||
"""Wrap a process_frame method to count UserStoppedSpeakingFrame."""
|
||||
|
||||
async def wrapped(frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self.user_stopped_speaking_count += 1
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
self.user_started_speaking_count += 1
|
||||
return await original_process_frame(frame, direction)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
class TestVoicemailDetectorWithUserAggregator:
|
||||
"""Test scenarios with voicemail detector and user aggregator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voicemail_detector_conversation_flow(self):
|
||||
"""Test: Voicemail detector classifies as CONVERSATION and main LLM responds.
|
||||
|
||||
This test bench shows the flow:
|
||||
1. User starts speaking, sends transcription, stops speaking
|
||||
2. Voicemail detector's internal LLM classifies as "CONVERSATION"
|
||||
3. Main LLM generates response text
|
||||
4. Second user turn with transcription
|
||||
5. Main LLM generates end_call function to end pipeline
|
||||
|
||||
Pipeline structure mirrors run_pipeline.py:
|
||||
injector -> voicemail_detector.detector() -> user_aggregator -> main_llm
|
||||
-> voicemail_detector.gate() -> assistant_aggregator
|
||||
"""
|
||||
context = LLMContext()
|
||||
|
||||
# Create user turn strategies
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
)
|
||||
|
||||
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
|
||||
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params, user_params=user_params
|
||||
)
|
||||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Create mock LLM for main conversation
|
||||
# Step 0: First response after CONVERSATION classification
|
||||
# Step 1: Response to second user turn
|
||||
# Step 2: end_call function call to end pipeline
|
||||
main_llm_steps = [
|
||||
MockLLMService.create_text_chunks(text="Hello! I'm here to help you today.")
|
||||
]
|
||||
main_llm = MockLLMService(mock_steps=main_llm_steps, chunk_delay=0.001)
|
||||
|
||||
# Create mock LLM for voicemail classification
|
||||
# First response: "CONVERSATION" to close the voicemail gate
|
||||
voicemail_classification_steps = [
|
||||
MockLLMService.create_text_chunks(text="CONVERSATION"),
|
||||
]
|
||||
voicemail_llm = MockLLMService(
|
||||
mock_steps=voicemail_classification_steps, chunk_delay=0.001
|
||||
)
|
||||
|
||||
# Create voicemail detector with the classification LLM
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
voicemail_response_delay=0,
|
||||
)
|
||||
|
||||
# Set up frame counter to track UserStoppedSpeakingFrame in voicemail detector's user aggregator
|
||||
voicemail_user_aggregator = voicemail_detector._context_aggregator.user()
|
||||
frame_counter = FrameCounter()
|
||||
original_process_frame = voicemail_user_aggregator.process_frame
|
||||
voicemail_user_aggregator.process_frame = frame_counter.wrap_process_frame(
|
||||
original_process_frame
|
||||
)
|
||||
|
||||
# Build pipeline similar to run_pipeline.py structure
|
||||
injector = FrameInjector()
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
injector,
|
||||
voicemail_detector.detector(), # Classification parallel pipeline
|
||||
user_context_aggregator,
|
||||
main_llm,
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def inject_frames():
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# === First user turn ===
|
||||
# This triggers voicemail classification AND main LLM response
|
||||
await injector.inject_frame(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject_frame(
|
||||
TranscriptionFrame("First User Speech", "user-123", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
await injector.inject_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
# Wait for voicemail classification and main LLM response
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# === Second user turn ===
|
||||
await injector.inject_frame(UserStartedSpeakingFrame())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject_frame(
|
||||
TranscriptionFrame(
|
||||
"Second User Speech",
|
||||
"user-123",
|
||||
time_now_iso8601(),
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
await injector.inject_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
await injector.inject_frame(
|
||||
EndTaskFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await asyncio.gather(run_pipeline(), inject_frames())
|
||||
|
||||
# Assert voicemail LLM was called once for classification
|
||||
assert voicemail_llm.get_current_step() == 1
|
||||
|
||||
# Assert main LLM was called twice (once per user turn)
|
||||
assert main_llm.get_current_step() == 2
|
||||
|
||||
# Assert voicemail detector's user aggregator saw UserStoppedSpeakingFrame only once
|
||||
# (because the classifier gate closes after CONVERSATION classification,
|
||||
# blocking subsequent frames from reaching the voicemail branch)
|
||||
assert frame_counter.user_stopped_speaking_count == 1, (
|
||||
f"Expected voicemail detector's user aggregator to see UserStoppedSpeakingFrame once, "
|
||||
f"but saw it {frame_counter.user_stopped_speaking_count} times"
|
||||
)
|
||||
|
||||
# We should see no more than 2 user started speaking frame. One from downstream FrameInjector
|
||||
# and one from upstream main pipeline's LLMUserAggregator
|
||||
assert frame_counter.user_started_speaking_count <= 2, (
|
||||
f"Expected voicemail detector's user aggregator to see UserStartedSpeakingFrame at most twice, "
|
||||
f"but saw it {frame_counter.user_started_speaking_count} times"
|
||||
)
|
||||
|
||||
# Assert the classifier gate is closed after classification
|
||||
assert voicemail_detector._classifier_gate._gate_opened is False, (
|
||||
"Expected classifier gate to be closed after CONVERSATION classification"
|
||||
)
|
||||
|
||||
# Assert the classifier gate is closed after classification
|
||||
assert voicemail_detector._classifier_upstream_gate._gate_open is False, (
|
||||
"Expected classifier upstream gate to be closed after CONVERSATION classification"
|
||||
)
|
||||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit f11fad8f3e90e06b1625b9dc49c13e26f3c9e716
|
||||
Subproject commit df1432e168570661ae418500fb04e8c62ba1335b
|
||||
Loading…
Add table
Add a link
Reference in a new issue