From 033fde8946881fa81e90ffb37e7242f1c30421cb Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 27 Jan 2026 18:20:23 +0530 Subject: [PATCH] chore: refactor and add tests (#130) * chore: add tests for end call * Update pipecat module * fix: allow interruptions from deepgram flux * Add VadUserTurnStrategy * chore: add test for voicemail detection --- api/services/pipecat/event_handlers.py | 24 +- api/services/pipecat/pipeline_builder.py | 4 + api/services/pipecat/run_pipeline.py | 19 +- api/services/workflow/pipecat_engine.py | 241 ++-- .../workflow/pipecat_engine_callbacks.py | 33 +- .../workflow/pipecat_engine_custom_tools.py | 18 +- api/tests/conftest.py | 307 +++-- .../test_pipecat_engine_context_update.py | 129 +- api/tests/test_pipecat_engine_end_call.py | 1097 +++++++++++++++++ ...cat_engine_node_switch_with_user_speech.py | 393 ++++++ api/tests/test_pipecat_engine_tool_calls.py | 12 +- ...test_pipecat_engine_variable_extraction.py | 116 +- api/tests/test_user_idle_handler.py | 15 +- api/tests/test_voicemail_detector.py | 238 ++++ pipecat | 2 +- 15 files changed, 2106 insertions(+), 542 deletions(-) create mode 100644 api/tests/test_pipecat_engine_end_call.py create mode 100644 api/tests/test_pipecat_engine_node_switch_with_user_speech.py create mode 100644 api/tests/test_voicemail_detector.py diff --git a/api/services/pipecat/event_handlers.py b/api/services/pipecat/event_handlers.py index 1e68109..a898639 100644 --- a/api/services/pipecat/event_handlers.py +++ b/api/services/pipecat/event_handlers.py @@ -10,16 +10,13 @@ from api.services.pipecat.in_memory_buffers import ( InMemoryTranscriptBuffer, ) from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator -from api.services.workflow.disposition_mapper import ( - apply_disposition_mapping, - get_organization_id_from_workflow_run, -) from api.services.workflow.pipecat_engine import PipecatEngine from api.tasks.arq import enqueue_job from api.tasks.function_names import 
FunctionNames from pipecat.frames.frames import Frame, LLMContextFrame from pipecat.pipeline.task import PipelineTask from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor +from pipecat.utils.enums import EndTaskReason def register_event_handlers( @@ -83,18 +80,17 @@ def register_event_handlers( @transport.event_handler("on_client_disconnected") async def on_client_disconnected(_transport, _participant): call_disposed = engine.is_call_disposed() - logger.debug( f"In on_client_disconnected callback handler. Call disposed: {call_disposed}" ) - engine.handle_client_disconnected() # Stop recordings await audio_buffer.stop_recording() - # Only cancel the task if the call is not already disposed by the engine - if not call_disposed: - await task.cancel() + # End the call + await engine.end_call_with_reason( + EndTaskReason.USER_HANGUP.value, abort_immediately=True + ) @task.event_handler("on_pipeline_started") async def on_pipeline_started(_task: PipelineTask, _frame: Frame): @@ -114,9 +110,6 @@ def register_event_handlers( # Stop recordings await audio_buffer.stop_recording() - call_disposition = await engine.get_call_disposition() - logger.debug(f"call disposition in on_pipeline_finished: {call_disposition}") - gathered_context = await engine.get_gathered_context() # Add trace URL if available (must be done before conversation tracing ends) @@ -129,13 +122,6 @@ def register_event_handlers( # also consider existing gathered context in workflow_run gathered_context = {**gathered_context, **workflow_run.gathered_context} - organization_id = await get_organization_id_from_workflow_run(workflow_run_id) - mapped_call_disposition = await apply_disposition_mapping( - call_disposition, organization_id - ) - - gathered_context.update({"mapped_call_disposition": mapped_call_disposition}) - # Set user_speech call tag if in_memory_transcript_buffer: call_tags = gathered_context.get("call_tags", []) diff --git a/api/services/pipecat/pipeline_builder.py 
b/api/services/pipecat/pipeline_builder.py index 0f6de5d..1706a2e 100644 --- a/api/services/pipecat/pipeline_builder.py +++ b/api/services/pipecat/pipeline_builder.py @@ -57,6 +57,10 @@ def build_pipeline( # Insert voicemail detector after STT if enabled # Note: We intentionally do NOT use voicemail_detector.gate() to allow TTS # frames to continue flowing during classification (non-blocking detection) + + # Note: We must keep user_context_aggregator after voicemail_detector + # or else, LLMContextFrames generated from user_context_aggregator will + # start generating LLM Completion from Voicemail Classifier if voicemail_detector: logger.info("Adding native voicemail detector to pipeline") processors.append(voicemail_detector.detector()) diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 91b1e18..cf5ea26 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -52,9 +52,11 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMUserAggregatorParams, ) from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection -from pipecat.turns.user_mute import MuteUntilFirstBotCompleteUserMuteStrategy +from pipecat.turns.user_mute import ( + CallbackUserMuteStrategy, + MuteUntilFirstBotCompleteUserMuteStrategy, +) from pipecat.turns.user_start import ( - ExternalUserTurnStartStrategy, TranscriptionUserTurnStartStrategy, ) from pipecat.turns.user_start.vad_user_turn_start_strategy import ( @@ -547,7 +549,7 @@ async def _run_pipeline( if is_deepgram_flux: user_turn_strategies = UserTurnStrategies( - start=[VADUserTurnStartStrategy(), ExternalUserTurnStartStrategy()], + start=[VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()], stop=[ExternalUserTurnStopStrategy()], ) else: @@ -556,9 +558,16 @@ async def _run_pipeline( stop=[TranscriptionUserTurnStopStrategy()], ) + # Create user mute strategies + # - CallbackUserMuteStrategy: mutes based on engine's 
_mute_pipeline state + user_mute_strategies = [ + MuteUntilFirstBotCompleteUserMuteStrategy(), + CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user), + ] + user_params = LLMUserAggregatorParams( user_turn_strategies=user_turn_strategies, - user_mute_strategies=[MuteUntilFirstBotCompleteUserMuteStrategy()], + user_mute_strategies=user_mute_strategies, user_idle_timeout=max_user_idle_timeout, ) context_aggregator = LLMContextAggregatorPair( @@ -606,7 +615,7 @@ async def _run_pipeline( @voicemail_detector.event_handler("on_voicemail_detected") async def _on_voicemail_detected(_processor): logger.info(f"Voicemail detected for workflow run {workflow_run_id}") - await engine.send_end_task_frame( + await engine.end_call_with_reason( reason=EndTaskReason.VOICEMAIL_DETECTED.value, abort_immediately=True, ) diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index 8a810dc..08c6baa 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -9,7 +9,6 @@ from pipecat.frames.frames import ( CancelFrame, EndFrame, FunctionCallResultProperties, - TTSSpeakFrame, ) from pipecat.pipeline.task import PipelineTask from pipecat.processors.aggregators.llm_context import LLMContext @@ -19,6 +18,7 @@ from pipecat.utils.enums import EndTaskReason if TYPE_CHECKING: from api.services.telephony.stasis_rtp_connection import StasisRTPConnection + from pipecat.frames.frames import Frame from pipecat.services.anthropic.llm import AnthropicLLMService from pipecat.services.google.llm import GoogleLLMService from pipecat.services.openai.llm import OpenAILLMService @@ -49,7 +49,6 @@ from api.services.workflow.tools.timezone import ( get_current_time, get_time_tools, ) -from pipecat.processors.filters.stt_mute_filter import STTMuteFilter from pipecat.utils.tracing.context_registry import get_current_turn_context @@ -79,12 +78,10 @@ class PipecatEngine: self._workflow_run_id = workflow_run_id 
self._node_transition_callback = node_transition_callback self._initialized = False - self._client_disconnected = False self._call_disposed = False self._current_node: Optional[Node] = None self._gathered_context: dict = {} self._user_response_timeout_task: Optional[asyncio.Task] = None - self._call_disposition: Optional[str] = None # Stasis connection for immediate transfers self._stasis_connection: Optional["StasisRTPConnection"] = None @@ -99,6 +96,9 @@ class PipecatEngine: # Track current LLM reference text for TTS aggregation correction self._current_llm_generation_reference_text: str = "" + # Controls whether user input should be muted + self._mute_pipeline: bool = False + # Custom tool manager (initialized in initialize()) self._custom_tool_manager: Optional[CustomToolManager] = None @@ -215,9 +215,14 @@ class PipecatEngine: This way, when we do set_node from within this function, and go for LLM completion with updated system prompts, the context is updated with function call result. """ + # FIXME: There is a potential race condition, when we generate LLM Completion from UserContextAggregator + # with FunctionCallResultFrame and we call end_call_with_reason where we queue EndFrame or CancelFrame. 
+ # If EndFrame reaches the LLM Processor before the ContextFrame, we might never run generation which + # might be intended + # Queue EndFrame if we just transitioned to EndNode if self._current_node.is_end: - await self.send_end_task_frame( + await self.end_call_with_reason( EndTaskReason.USER_QUALIFIED.value ) @@ -356,44 +361,52 @@ class PipecatEngine: self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func) async def _perform_variable_extraction_if_needed( - self, previous_node: Optional[Node] + self, node: Optional[Node], run_in_background: bool = True ) -> None: - """Perform variable extraction if the previous node had extraction enabled.""" - if ( - previous_node - and previous_node.extraction_enabled - and previous_node.extraction_variables - ): + """Perform variable extraction if the node has extraction enabled. + + Args: + node: The node to extract variables from. + run_in_background: If True, runs extraction as a fire-and-forget task. + If False, awaits the extraction synchronously. + """ + if not (node and node.extraction_enabled and node.extraction_variables): + return + + # Capture the current turn context for otel tracing + # before creating the background task + parent_context = get_current_turn_context() + + extraction_prompt = self._format_prompt(node.extraction_prompt) + extraction_variables = node.extraction_variables + + async def _do_extraction(): + try: + extracted_data = ( + await self._variable_extraction_manager._perform_extraction( + extraction_variables, parent_context, extraction_prompt + ) + ) + self._gathered_context.update(extracted_data) + logger.debug( + f"Variable extraction completed. 
Extracted: {extracted_data}" + ) + except Exception as e: + logger.error(f"Error during variable extraction: {str(e)}") + + if run_in_background: logger.debug( - f"Scheduling background variable extraction for node: {previous_node.name}" + f"Scheduling background variable extraction for node: {node.name}" ) + asyncio.create_task(_do_extraction()) + else: + logger.debug( + f"Performing synchronous variable extraction for node: {node.name}" + ) + await _do_extraction() - # Capture the current turn context before creating the background task - parent_context = get_current_turn_context() - extraction_prompt = self._format_prompt(previous_node.extraction_prompt) - extraction_variables = previous_node.extraction_variables - - async def _background_extraction(): - try: - extracted_data = ( - await self._variable_extraction_manager._perform_extraction( - extraction_variables, parent_context, extraction_prompt - ) - ) - self._gathered_context.update(extracted_data) - logger.debug( - f"Background variable extraction completed. 
Extracted: {extracted_data}" - ) - except Exception as e: - logger.error( - f"Error during background variable extraction: {str(e)}" - ) - - # Fire and forget - extraction happens in background without blocking - asyncio.create_task(_background_extraction()) - - async def _setup_llm_context_and_start_generation(self, node: Node) -> None: - """Common method to set up LLM context and queue context frame for non-static nodes.""" + async def _setup_llm_context(self, node: Node) -> None: + """Common method to set up LLM context""" # Set node name for tracing try: self.context.set_node_name(node.name) @@ -470,61 +483,54 @@ class PipecatEngine: if node.is_static: raise ValueError("Static nodes are not supported!") else: - # Start generation for non-static start node - await self._setup_llm_context_and_start_generation(node) + # Setup LLM Context with Prompts and Functions + await self._setup_llm_context(node) async def _handle_end_node(self, node: Node) -> None: """Handle end node execution.""" if node.is_static: raise ValueError("Static nodes are not supported!") else: - await self._setup_llm_context_and_start_generation(node) - - # If this end node has extraction enabled, perform extraction immediately - if node.extraction_enabled and node.extraction_variables: - await self._perform_variable_extraction_if_needed(node) + # Setup LLM Context with Prompts and Functions + await self._setup_llm_context(node) async def _handle_agent_node(self, node: Node) -> None: """Handle agent node execution.""" if node.is_static: raise ValueError("Static nodes are not supported!") else: - # Set context and functions for non-static agent node - await self._setup_llm_context_and_start_generation(node) + # Setup LLM Context with Prompts and Functions + await self._setup_llm_context(node) - async def send_end_task_frame( + async def end_call_with_reason( self, reason: str, abort_immediately: bool = False, ): """ - Centralized method to send EndTaskFrame with metadata including - 
call_transfer_context and call_context_vars + Centralized method to end the call with disposition mapping """ - if self._call_disposed or self._client_disconnected: - # Call is already disposed and client disconnected - logger.debug( - f"Not sending EndFrame since call is already disposed: Call Disposed: {self._call_disposed} Client Disconnected: {self._client_disconnected}" - ) + if self._call_disposed: + logger.debug(f"Call already Disposed: {self._call_disposed}") return self._call_disposed = True - frame_to_push = CancelFrame() if abort_immediately else EndFrame() + # Mute the pipeline + self._mute_pipeline = True - # Customer disposition code using their mapping - mapped_disposition = "" + # Perform final variable extraction synchronously before ending + await self._perform_variable_extraction_if_needed( + self._current_node, run_in_background=False + ) + + frame_to_push = CancelFrame() if abort_immediately else EndFrame() # Apply disposition mapping - first try call_disposition if it is, # extracted from the call conversation then fall back to reason call_disposition = self._gathered_context.get("call_disposition", "") organization_id = await self._get_organization_id() - # If client is disconnected before we get a chance to disconnect from - # the bot, lets consider that as final disposition - if self._client_disconnected: - call_disposition = EndTaskReason.USER_HANGUP.value - if call_disposition: # If call_disposition exists, map it mapped_disposition = await apply_disposition_mapping( @@ -532,90 +538,16 @@ class PipecatEngine: ) # Store the original and mapped values self._gathered_context["extracted_call_disposition"] = call_disposition - self._gathered_context["call_disposition"] = mapped_disposition + self._gathered_context["call_disposition"] = call_disposition + self._gathered_context["mapped_call_disposition"] = mapped_disposition else: # Otherwise, map the disconnect reason mapped_disposition = await apply_disposition_mapping( reason, 
organization_id ) # Store the mapped disconnect reason - self._gathered_context["call_disposition"] = mapped_disposition - - # TODO: Generalise this - self._gathered_context["address"] = ", ".join( - [ - self._call_context_vars.get("address1", ""), - self._call_context_vars.get("address2", ""), - self._call_context_vars.get("address3", ""), - self._call_context_vars.get("city", ""), - self._call_context_vars.get("state", ""), - self._call_context_vars.get("province", ""), - self._call_context_vars.get("postal_code", ""), - ] - ) - self._gathered_context["full_name"] = " ".join( - [ - self._call_context_vars.get("first_name", ""), - self._call_context_vars.get("middle_initial", ""), - self._call_context_vars.get("last_name", ""), - ] - ) - self._gathered_context["agent_name"] = "Alex" - self._gathered_context["customer_phone_number"] = self._call_context_vars.get( - "phone", "" - ) - self._gathered_context["timezone"] = self._call_context_vars.get("province", "") - self._gathered_context["vendor_id"] = self._call_context_vars.get( - "vendor_lead_code", "" - ) - - decision_maker = self._gathered_context.get("primary_cardholder", False) - employment_status = self._gathered_context.get("employment_status", "N/A") - call_transfer_context = { - "first_name": self._call_context_vars.get("first_name", ""), - "full_name": self._gathered_context.get("full_name", ""), - "phone": self._call_context_vars.get("phone", ""), - "lead_id": self._call_context_vars.get("lead_id"), - "disposition": mapped_disposition, - "agent_name": self._gathered_context.get("agent_name", "Alex"), - "decision_maker": str(decision_maker), - "employment": employment_status.title() if employment_status else "N/A", - "debts": self._gathered_context.get("total_debt", "N/A"), - "number_of_credit_cards": self._gathered_context.get( - "number_of_credit_cards", "N/A" - ), - "time": self._gathered_context.get("time"), - } - - logger.debug( - f"gathered_context: {self._gathered_context} call_transfer_context: 
{call_transfer_context}" - ) - - # Initiate immediate transfer for Stasis connections when user is qualified - if ( - reason == EndTaskReason.USER_QUALIFIED.value - and self._stasis_connection is not None - and not abort_immediately - ): - try: - logger.info( - f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}" - ) - await self._stasis_connection.transfer(call_transfer_context) - logger.info("Immediate transfer initiated successfully") - except Exception as e: - logger.error(f"Failed to initiate immediate transfer: {e}") - # Continue with normal flow even if immediate transfer fails - - if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value: - await self.task.queue_frame( - TTSSpeakFrame( - "Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!" - ) - ) - - # Store the original reason for later retrieval in event handler - self._call_disposition = mapped_disposition + self._gathered_context["call_disposition"] = reason + self._gathered_context["mapped_call_disposition"] = mapped_disposition logger.debug( f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}" @@ -678,11 +610,14 @@ class PipecatEngine: return system_message, functions - def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]: + async def should_mute_user(self, frame: "Frame") -> bool: """ - This callback is called by STTMuteFilter to determine if the STT should be muted. + Callback for CallbackUserMuteStrategy to determine if the user should be muted. + + Returns: + True if the user should be muted, False otherwise. 
""" - return engine_callbacks.create_should_mute_callback(self) + return self._mute_pipeline def create_user_idle_handler(self): """ @@ -746,26 +681,10 @@ class PipecatEngine: """Accumulate LLM text frames to build reference text.""" self._current_llm_generation_reference_text += text - def handle_client_disconnected(self): - """Handle client disconnected event.""" - self._client_disconnected = True - def is_call_disposed(self): """Check whether a call has been disposed by the engine""" return self._call_disposed - async def get_call_disposition(self) -> Optional[str]: - """Get the disconnect reason set by the engine.""" - if self._call_disposition: - # We would have a _call_disposition variable set if we have initiated - # a disconnect from the bot, i.e we have called send_end_task_frame. - return self._call_disposition - - if self._client_disconnected: - return EndTaskReason.USER_HANGUP.value - else: - return EndTaskReason.UNKNOWN.value - async def get_gathered_context(self) -> dict: """Get the gathered context including extracted variables.""" return self._gathered_context.copy() diff --git a/api/services/workflow/pipecat_engine_callbacks.py b/api/services/workflow/pipecat_engine_callbacks.py index a6ac51c..a422734 100644 --- a/api/services/workflow/pipecat_engine_callbacks.py +++ b/api/services/workflow/pipecat_engine_callbacks.py @@ -11,46 +11,19 @@ unit-testing. 
""" import re -from typing import TYPE_CHECKING, Awaitable, Callable +from typing import TYPE_CHECKING from loguru import logger from pipecat.frames.frames import ( LLMMessagesAppendFrame, ) -from pipecat.processors.filters.stt_mute_filter import STTMuteFilter from pipecat.utils.enums import EndTaskReason if TYPE_CHECKING: from api.services.workflow.pipecat_engine import PipecatEngine -# --------------------------------------------------------------------------- -# STT mute handling -# --------------------------------------------------------------------------- - - -def create_should_mute_callback( - engine: "PipecatEngine", -) -> Callable[[STTMuteFilter], Awaitable[bool]]: - """Return a callback indicating whether STT should be muted. - - STT is muted when *interruptions are **not*** allowed on the current node. - """ - - async def callback(_: STTMuteFilter) -> bool: # noqa: D401 - if engine._current_node is None: - # Default to not muting if we have no active node yet. - return False - - logger.debug( - f"STT mute callback: allow_interrupt={engine._current_node.allow_interrupt}" - ) - return not engine._current_node.allow_interrupt - - return callback - - # --------------------------------------------------------------------------- # User-idle handling # --------------------------------------------------------------------------- @@ -85,7 +58,7 @@ class UserIdleHandler: "content": "The user has been quiet. We will be disconnecting the call now. Wish them a good day in the language that the user has been speaking so far.", } await aggregator.push_frame(LLMMessagesAppendFrame([message], run_llm=True)) - await self._engine.send_end_task_frame( + await self._engine.end_call_with_reason( EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value ) @@ -105,7 +78,7 @@ def create_max_duration_callback(engine: "PipecatEngine"): async def handle_max_duration(): logger.debug("Max call duration exceeded. 
Terminating call") - await engine.send_end_task_frame(EndTaskReason.CALL_DURATION_EXCEEDED.value) + await engine.end_call_with_reason(EndTaskReason.CALL_DURATION_EXCEEDED.value) return handle_max_duration diff --git a/api/services/workflow/pipecat_engine_custom_tools.py b/api/services/workflow/pipecat_engine_custom_tools.py index af01510..b60ea79 100644 --- a/api/services/workflow/pipecat_engine_custom_tools.py +++ b/api/services/workflow/pipecat_engine_custom_tools.py @@ -23,15 +23,12 @@ from api.services.workflow.tools.custom_tool import ( from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.frames.frames import FunctionCallResultProperties, TTSSpeakFrame from pipecat.services.llm_service import FunctionCallParams +from pipecat.utils.enums import EndTaskReason if TYPE_CHECKING: from api.services.workflow.pipecat_engine import PipecatEngine -# End task reason for end call tool -END_CALL_TOOL_REASON = "end_call_tool" - - class CustomToolManager: """Manager for custom tool registration and execution. 
@@ -214,21 +211,22 @@ class CustomToolManager: logger.info(f"Playing custom goodbye message: {custom_message}") await self._engine.task.queue_frame(TTSSpeakFrame(custom_message)) # End the call after the message (not immediately) - await self._engine.send_end_task_frame( - END_CALL_TOOL_REASON, abort_immediately=False + await self._engine.end_call_with_reason( + EndTaskReason.END_CALL_TOOL_REASON.value, + abort_immediately=False, ) else: # No message - end call immediately logger.info("Ending call immediately (no goodbye message)") - await self._engine.send_end_task_frame( - END_CALL_TOOL_REASON, abort_immediately=True + await self._engine.end_call_with_reason( + EndTaskReason.END_CALL_TOOL_REASON.value, abort_immediately=True ) except Exception as e: logger.error(f"End call tool '{function_name}' execution failed: {e}") # Still try to end the call even if there's an error - await self._engine.send_end_task_frame( - END_CALL_TOOL_REASON, abort_immediately=True + await self._engine.end_call_with_reason( + EndTaskReason.UNEXPECTED_ERROR.value, abort_immediately=True ) return end_call_handler diff --git a/api/tests/conftest.py b/api/tests/conftest.py index a7fe452..9b0484b 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -6,27 +6,20 @@ import pytest from api.services.workflow.dto import ( EdgeDataDTO, + ExtractionVariableDTO, NodeDataDTO, NodeType, Position, ReactFlowDTO, RFEdgeDTO, RFNodeDTO, + VariableType, ) from api.services.workflow.workflow import WorkflowGraph -from pipecat.frames.frames import ( - BotSpeakingFrame, - BotStartedSpeakingFrame, - BotStoppedSpeakingFrame, - Frame, - TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, -) -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -START_CALL_SYSTEM_PROMPT = "start_call_system_prompt" -END_CALL_SYSTEM_PROMPT = "end_call_system_prompt" +START_CALL_SYSTEM_PROMPT = "Start Call System Prompt" +AGENT_SYSTEM_PROMPT = "Agent Node System Prompt" 
+END_CALL_SYSTEM_PROMPT = "End Call System Prompt" # Default workflow definition for mocking database WorkflowModel DEFAULT_WORKFLOW_DEFINITION = { @@ -110,57 +103,6 @@ class MockUserConfig: embeddings: Optional[Any] = None -class MockTransportProcessor(FrameProcessor): - """ - Mocks the transport behavior by emitting Bot speaking frames - when it encounters TTS frames. - - This simulates what a real transport would do when the bot is speaking: - - TTSStartedFrame -> BotStartedSpeakingFrame - - TTSAudioRawFrame -> BotSpeakingFrame - - TTSStoppedFrame -> BotStoppedSpeakingFrame - - Args: - emit_bot_speaking: If True, also emits BotSpeakingFrame on TTSAudioRawFrame - which is needed for user idle tracking to start conversation tracking. Default True. - """ - - def __init__( - self, - *, - emit_bot_speaking: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self._emit_bot_speaking = emit_bot_speaking - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - if isinstance(frame, TTSStartedFrame): - # Emit BotStartedSpeakingFrame to indicate bot started speaking - await self.push_frame(BotStartedSpeakingFrame()) - await self.push_frame( - BotStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM - ) - elif isinstance(frame, TTSAudioRawFrame): - # Emit BotSpeakingFrame - this is what triggers user idle tracking - # to start conversation tracking - if self._emit_bot_speaking: - await self.push_frame(BotSpeakingFrame()) - await self.push_frame( - BotSpeakingFrame(), direction=FrameDirection.UPSTREAM - ) - elif isinstance(frame, TTSStoppedFrame): - # Emit BotStoppedSpeakingFrame to indicate bot stopped speaking - await self.push_frame(BotStoppedSpeakingFrame()) - await self.push_frame( - BotStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM - ) - - await self.push_frame(frame, direction) - - @dataclass class MockToolModel: """Mock tool model for testing.""" @@ -299,14 +241,14 @@ def 
simple_workflow() -> WorkflowGraph: """Create a simple two-node workflow for testing. The workflow has: - - Start node with a prompt + - Start node with extraction enabled (extracts user_intent) - End node with a prompt - One edge connecting them with label "End Call" """ dto = ReactFlowDTO( nodes=[ RFNodeDTO( - id="1", + id="start", type=NodeType.startNode, position=Position(x=0, y=0), data=NodeDataDTO( @@ -315,10 +257,19 @@ def simple_workflow() -> WorkflowGraph: is_start=True, allow_interrupt=False, add_global_prompt=False, + extraction_enabled=True, + extraction_prompt="Extract user information from the conversation.", + extraction_variables=[ + ExtractionVariableDTO( + name="user_intent", + type=VariableType.string, + prompt="The user's intent or reason for calling", + ), + ], ), ), RFNodeDTO( - id="2", + id="end", type=NodeType.endNode, position=Position(x=0, y=200), data=NodeDataDTO( @@ -327,14 +278,15 @@ def simple_workflow() -> WorkflowGraph: is_end=True, allow_interrupt=False, add_global_prompt=False, + extraction_enabled=False, ), ), ], edges=[ RFEdgeDTO( - id="1-2", - source="1", - target="2", + id="start-end", + source="start", + target="end", data=EdgeDataDTO( label="End Call", condition="When the user says to end the call, end the call", @@ -350,37 +302,59 @@ def three_node_workflow() -> WorkflowGraph: """Create a three-node workflow for testing with an intermediate agent node. 
The workflow has: - - Start node - - Agent node (for collecting information) - - End node + - Start node with extraction enabled (extracts greeting_type) + - Agent node with extraction enabled (extracts user_name) + - End node (no extraction) + + Edges: + - Start -> Agent (label: "Collect Info") + - Agent -> End (label: "End Call") """ dto = ReactFlowDTO( nodes=[ RFNodeDTO( - id="1", + id="start", type=NodeType.startNode, position=Position(x=0, y=0), data=NodeDataDTO( name="Start Call", prompt=START_CALL_SYSTEM_PROMPT, is_start=True, - allow_interrupt=True, + allow_interrupt=False, add_global_prompt=False, + extraction_enabled=True, + extraction_prompt="Extract greeting information from the conversation.", + extraction_variables=[ + ExtractionVariableDTO( + name="greeting_type", + type=VariableType.string, + prompt="The type of greeting used", + ), + ], ), ), RFNodeDTO( - id="2", + id="agent", type=NodeType.agentNode, position=Position(x=0, y=200), data=NodeDataDTO( name="Collect Info", - prompt="Help the user with their request. 
Ask clarifying questions if needed.", - allow_interrupt=True, + prompt=AGENT_SYSTEM_PROMPT, + allow_interrupt=False, add_global_prompt=False, + extraction_enabled=True, + extraction_prompt="Extract user details from the conversation.", + extraction_variables=[ + ExtractionVariableDTO( + name="user_name", + type=VariableType.string, + prompt="The user's name", + ), + ], ), ), RFNodeDTO( - id="3", + id="end", type=NodeType.endNode, position=Position(x=0, y=400), data=NodeDataDTO( @@ -389,26 +363,187 @@ def three_node_workflow() -> WorkflowGraph: is_end=True, allow_interrupt=False, add_global_prompt=False, + extraction_enabled=False, ), ), ], edges=[ RFEdgeDTO( - id="1-2", - source="1", - target="2", + id="start-agent", + source="start", + target="agent", data=EdgeDataDTO( label="Collect Info", - condition="When the user wants help, collect their information", + condition="When user has been greeted, proceed to collect information", ), ), RFEdgeDTO( - id="2-3", - source="2", - target="3", + id="agent-end", + source="agent", + target="end", data=EdgeDataDTO( label="End Call", - condition="When the user is done or wants to end the call", + condition="When information collection is complete, end the call", + ), + ), + ], + ) + return WorkflowGraph(dto) + + +@pytest.fixture +def three_node_workflow_extraction_start_only() -> WorkflowGraph: + """Create a three-node workflow with extraction enabled ONLY on start node. + + This fixture is specifically for testing that variable extraction is triggered + for the correct node during transitions. The agent node has extraction disabled + to verify extraction happens for the SOURCE node, not the TARGET node. 
+ + The workflow has: + - Start node with extraction enabled (extracts user_name) + - Agent node with extraction DISABLED + - End node (no extraction) + """ + dto = ReactFlowDTO( + nodes=[ + RFNodeDTO( + id="start", + type=NodeType.startNode, + position=Position(x=0, y=0), + data=NodeDataDTO( + name="Start Call", + prompt=START_CALL_SYSTEM_PROMPT, + is_start=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=True, + extraction_prompt="Extract the user's name from the conversation.", + extraction_variables=[ + ExtractionVariableDTO( + name="user_name", + type=VariableType.string, + prompt="The name the user provided", + ), + ], + ), + ), + RFNodeDTO( + id="agent", + type=NodeType.agentNode, + position=Position(x=0, y=200), + data=NodeDataDTO( + name="Collect Info", + prompt=AGENT_SYSTEM_PROMPT, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, # Explicitly disabled for testing + ), + ), + RFNodeDTO( + id="end", + type=NodeType.endNode, + position=Position(x=0, y=400), + data=NodeDataDTO( + name="End Call", + prompt=END_CALL_SYSTEM_PROMPT, + is_end=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, + ), + ), + ], + edges=[ + RFEdgeDTO( + id="start-agent", + source="start", + target="agent", + data=EdgeDataDTO( + label="Collect Info", + condition="When user has been greeted, proceed to collect information", + ), + ), + RFEdgeDTO( + id="agent-end", + source="agent", + target="end", + data=EdgeDataDTO( + label="End Call", + condition="When information collection is complete, end the call", + ), + ), + ], + ) + return WorkflowGraph(dto) + + +@pytest.fixture +def three_node_workflow_no_variable_extraction() -> WorkflowGraph: + """Create a three-node workflow without variable extraction + + The workflow has: + - Start node with extraction DISABLED + - Agent node with extraction DISABLED + - End node (no extraction) + """ + dto = ReactFlowDTO( + nodes=[ + RFNodeDTO( + id="start", + 
type=NodeType.startNode, + position=Position(x=0, y=0), + data=NodeDataDTO( + name="Start Call", + prompt=START_CALL_SYSTEM_PROMPT, + is_start=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, + ), + ), + RFNodeDTO( + id="agent", + type=NodeType.agentNode, + position=Position(x=0, y=200), + data=NodeDataDTO( + name="Collect Info", + prompt=AGENT_SYSTEM_PROMPT, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, # Explicitly disabled for testing + ), + ), + RFNodeDTO( + id="end", + type=NodeType.endNode, + position=Position(x=0, y=400), + data=NodeDataDTO( + name="End Call", + prompt=END_CALL_SYSTEM_PROMPT, + is_end=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, + ), + ), + ], + edges=[ + RFEdgeDTO( + id="start-agent", + source="start", + target="agent", + data=EdgeDataDTO( + label="Collect Info", + condition="When user has been greeted, proceed to collect information", + ), + ), + RFEdgeDTO( + id="agent-end", + source="agent", + target="end", + data=EdgeDataDTO( + label="End Call", + condition="When information collection is complete, end the call", ), ), ], diff --git a/api/tests/test_pipecat_engine_context_update.py b/api/tests/test_pipecat_engine_context_update.py index e20a7ab..17bc95f 100644 --- a/api/tests/test_pipecat_engine_context_update.py +++ b/api/tests/test_pipecat_engine_context_update.py @@ -19,18 +19,13 @@ from unittest.mock import AsyncMock, patch import pytest -from api.services.workflow.dto import ( - EdgeDataDTO, - NodeDataDTO, - NodeType, - Position, - ReactFlowDTO, - RFEdgeDTO, - RFNodeDTO, -) from api.services.workflow.pipecat_engine import PipecatEngine from api.services.workflow.workflow import WorkflowGraph -from api.tests.conftest import MockTransportProcessor +from api.tests.conftest import ( + AGENT_SYSTEM_PROMPT, + END_CALL_SYSTEM_PROMPT, + START_CALL_SYSTEM_PROMPT, +) from pipecat.frames.frames import LLMContextFrame from 
pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -41,86 +36,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMContextAggregatorPair, ) from pipecat.tests import MockLLMService, MockTTSService - -# Define prompts for test nodes -START_NODE_PROMPT = "Start Node System Prompt" -AGENT_NODE_PROMPT = "Agent Node System Prompt" -END_NODE_PROMPT = "End Node System Prompt" - - -@pytest.fixture -def three_node_workflow_for_context_test() -> WorkflowGraph: - """Create a three-node workflow for testing context updates during transitions. - - The workflow has: - - Start node with prompt to greet user - - Agent node with prompt to collect information - - End node with prompt to say goodbye - - Edges: - - Start -> Agent (label: "Collect Info") - - Agent -> End (label: "End Call") - """ - dto = ReactFlowDTO( - nodes=[ - RFNodeDTO( - id="start", - type=NodeType.startNode, - position=Position(x=0, y=0), - data=NodeDataDTO( - name="Start Call", - prompt=START_NODE_PROMPT, - is_start=True, - allow_interrupt=False, - add_global_prompt=False, - ), - ), - RFNodeDTO( - id="agent", - type=NodeType.agentNode, - position=Position(x=0, y=200), - data=NodeDataDTO( - name="Collect Info", - prompt=AGENT_NODE_PROMPT, - allow_interrupt=False, - add_global_prompt=False, - ), - ), - RFNodeDTO( - id="end", - type=NodeType.endNode, - position=Position(x=0, y=400), - data=NodeDataDTO( - name="End Call", - prompt=END_NODE_PROMPT, - is_end=True, - allow_interrupt=False, - add_global_prompt=False, - ), - ), - ], - edges=[ - RFEdgeDTO( - id="start-agent", - source="start", - target="agent", - data=EdgeDataDTO( - label="Collect Info", - condition="When user has been greeted, proceed to collect information", - ), - ), - RFEdgeDTO( - id="agent-end", - source="agent", - target="end", - data=EdgeDataDTO( - label="End Call", - condition="When information collection is complete, end the call", - ), - ), - ], - ) - return WorkflowGraph(dto) 
+from pipecat.tests.mock_transport import MockTransport class ContextCapturingMockLLM(MockLLMService): @@ -215,7 +131,8 @@ async def run_pipeline_and_capture_context( # Create MockTTSService tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0) - mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False) + # Create MockTransport for simulating transport behavior + mock_transport = MockTransport(emit_bot_speaking=False) # Create LLM context context = LLMContext() @@ -251,15 +168,13 @@ async def run_pipeline_and_capture_context( [ llm, tts, - mock_transport_emulator, + mock_transport.output(), assistant_context_aggregator, ] ) # Create pipeline task - task = PipelineTask( - pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False - ) + task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) engine.set_task(task) @@ -294,7 +209,7 @@ class TestContextUpdateBeforeNextCompletion: @pytest.mark.asyncio async def test_single_transition_updates_context_before_next_completion( - self, three_node_workflow_for_context_test: WorkflowGraph + self, three_node_workflow: WorkflowGraph ): """Test that a single transition function call updates context before next LLM generation. @@ -329,7 +244,7 @@ class TestContextUpdateBeforeNextCompletion: mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks] llm, _ = await run_pipeline_and_capture_context( - workflow=three_node_workflow_for_context_test, + workflow=three_node_workflow, mock_steps=mock_steps, set_node_delay=0.05, # Introduce 50ms delay in set_node ) @@ -341,17 +256,17 @@ class TestContextUpdateBeforeNextCompletion: # Verify step 0 (start node) had start node's system prompt step_0_prompt = llm.get_system_prompt_at_step(0) - assert START_NODE_PROMPT in step_0_prompt, ( + assert START_CALL_SYSTEM_PROMPT in step_0_prompt, ( f"Step 0 should have start node prompt, got: {step_0_prompt[:100]}" ) # Verify step 1 (agent node) had: # 1. 
The agent node's system prompt (not start node's) step_1_prompt = llm.get_system_prompt_at_step(1) - assert AGENT_NODE_PROMPT in step_1_prompt, ( + assert AGENT_SYSTEM_PROMPT in step_1_prompt, ( f"Step 1 should have agent node prompt, got: {step_1_prompt[:100]}" ) - assert START_NODE_PROMPT not in step_1_prompt, ( + assert START_CALL_SYSTEM_PROMPT not in step_1_prompt, ( "Step 1 should NOT have start node prompt anymore" ) @@ -371,7 +286,7 @@ class TestContextUpdateBeforeNextCompletion: @pytest.mark.asyncio async def test_sequential_transitions_maintain_correct_context( - self, three_node_workflow_for_context_test: WorkflowGraph + self, three_node_workflow: WorkflowGraph ): """Test that sequential transitions maintain correct context at each step. @@ -403,7 +318,7 @@ class TestContextUpdateBeforeNextCompletion: mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks] llm, _ = await run_pipeline_and_capture_context( - workflow=three_node_workflow_for_context_test, + workflow=three_node_workflow, mock_steps=mock_steps, set_node_delay=0.05, # Introduce 50ms delay in set_node ) @@ -414,13 +329,13 @@ class TestContextUpdateBeforeNextCompletion: ) # Step 0: Start node - should have start prompt - assert START_NODE_PROMPT in llm.get_system_prompt_at_step(0) + assert START_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(0) # Step 1: Agent node - should have agent prompt - assert AGENT_NODE_PROMPT in llm.get_system_prompt_at_step(1) + assert AGENT_SYSTEM_PROMPT in llm.get_system_prompt_at_step(1) # Step 2: End node - should have end prompt - assert END_NODE_PROMPT in llm.get_system_prompt_at_step(2) + assert END_CALL_SYSTEM_PROMPT in llm.get_system_prompt_at_step(2) # Verify each subsequent step has the previous tool results step_1_ctx = llm.get_context_at_step(1) @@ -445,7 +360,7 @@ class TestContextUpdateBeforeNextCompletion: @pytest.mark.asyncio async def test_context_messages_preserve_conversation_history( - self, three_node_workflow_for_context_test: 
WorkflowGraph + self, three_node_workflow: WorkflowGraph ): """Test that conversation history is preserved across node transitions. @@ -474,7 +389,7 @@ class TestContextUpdateBeforeNextCompletion: mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks] llm, _ = await run_pipeline_and_capture_context( - workflow=three_node_workflow_for_context_test, + workflow=three_node_workflow, mock_steps=mock_steps, ) diff --git a/api/tests/test_pipecat_engine_end_call.py b/api/tests/test_pipecat_engine_end_call.py new file mode 100644 index 0000000..389047b --- /dev/null +++ b/api/tests/test_pipecat_engine_end_call.py @@ -0,0 +1,1097 @@ +"""Tests for verifying end_call_with_reason behavior in different scenarios. + +This module tests the end call flow in PipecatEngine with multiple scenarios: +1. Normal end call when transitioning to end node +2. End call triggered by custom end_call tool +3. End call triggered by on_client_disconnected +4. Race condition between end_call tool and client disconnect + +For all scenarios, we verify: +- Pipeline muting (_mute_pipeline is set to True) +- Variable extraction is called (_perform_variable_extraction_if_needed) +- Call disposition flag is set (_call_disposed is True) +- User audio muting via CallbackUserMuteStrategy (should_mute_user returns True) + +The tests use MockTransport with audio generation to simulate real pipeline scenarios +where InputAudioRawFrame frames are continuously generated. The pipeline includes +LLMUserAggregatorParams with user mute strategies (MuteUntilFirstBotCompleteUserMuteStrategy +and CallbackUserMuteStrategy) matching the production run_pipeline.py configuration. 
+""" + +import asyncio +from typing import Any, Dict, List +from unittest.mock import AsyncMock, patch + +import pytest + +from api.enums import ToolCategory +from api.services.workflow.dto import ( + EdgeDataDTO, + NodeDataDTO, + NodeType, + Position, + ReactFlowDTO, + RFEdgeDTO, + RFNodeDTO, +) +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager +from api.services.workflow.pipecat_engine_variable_extractor import ( + VariableExtractionManager, +) +from api.services.workflow.workflow import WorkflowGraph +from api.tests.conftest import END_CALL_SYSTEM_PROMPT, START_CALL_SYSTEM_PROMPT +from pipecat.frames.frames import Frame, LLMContextFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.tests import MockLLMService, MockTTSService +from pipecat.tests.mock_transport import MockTransport +from pipecat.turns.user_mute import ( + CallbackUserMuteStrategy, + MuteUntilFirstBotCompleteUserMuteStrategy, +) +from pipecat.utils.enums import EndTaskReason + + +class EndCallTestHelper: + """Helper class to track end call related state during tests.""" + + def __init__(self): + self.extraction_calls: List[Dict[str, Any]] = [] + self.mute_pipeline_state: List[bool] = [] + self.call_disposed_state: List[bool] = [] + self.end_call_reasons: List[str] = [] + self.frames_queued: List[Any] = [] + self.should_mute_user_calls: List[bool] = [] + + def reset(self): + """Reset all tracked state.""" + self.extraction_calls.clear() + self.mute_pipeline_state.clear() + 
self.call_disposed_state.clear() + self.end_call_reasons.clear() + self.frames_queued.clear() + self.should_mute_user_calls.clear() + + +class MockEndCallToolModel: + """Mock end call tool model for testing.""" + + def __init__( + self, + tool_uuid: str = "end-call-uuid", + name: str = "End Call", + description: str = "End the current call", + message_type: str = "none", + custom_message: str = "", + ): + self.tool_uuid = tool_uuid + self.name = name + self.description = description + self.category = ToolCategory.END_CALL.value + self.definition = { + "schema_version": 1, + "type": "end_call", + "config": { + "messageType": message_type, + "customMessage": custom_message, + }, + } + + +async def create_engine_with_tracking( + workflow: WorkflowGraph, + mock_llm: MockLLMService, + test_helper: EndCallTestHelper, + generate_audio: bool = True, +) -> tuple[PipecatEngine, MockTTSService, MockTransport, PipelineTask]: + """Create a PipecatEngine with tracking for end call behavior. + + Args: + workflow: The workflow graph to use. + mock_llm: The mock LLM service. + test_helper: Helper to track test state. + generate_audio: If True, the mock transport generates InputAudioRawFrame + every 20ms to simulate real audio input. 
+ + Returns: + Tuple of (engine, tts, transport, task) + """ + # Create MockTTSService + tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0) + + # Create MockTransport with audio generation to simulate real pipeline + mock_transport = MockTransport( + generate_audio=generate_audio, + audio_interval_ms=20, + audio_sample_rate=16000, + audio_num_channels=1, + emit_bot_speaking=True, + ) + + # Create LLM context + context = LLMContext() + + # Create PipecatEngine with the workflow (before context aggregator so we can use its callback) + engine = PipecatEngine( + llm=mock_llm, + context=context, + workflow=workflow, + call_context_vars={"customer_name": "Test User"}, + workflow_run_id=1, + ) + + # Track variable extraction calls + original_perform_extraction = engine._perform_variable_extraction_if_needed + + async def tracked_perform_extraction(node, run_in_background: bool = True): + test_helper.extraction_calls.append( + { + "node_id": node.id if node else None, + "node_name": node.name if node else None, + "extraction_enabled": node.extraction_enabled if node else None, + "run_in_background": run_in_background, + } + ) + await original_perform_extraction(node, run_in_background=run_in_background) + + engine._perform_variable_extraction_if_needed = tracked_perform_extraction + + # Track end_call_with_reason calls + original_end_call = engine.end_call_with_reason + + async def tracked_end_call(reason: str, abort_immediately: bool = False): + # Record state before end_call_with_reason modifies it + test_helper.end_call_reasons.append(reason) + await original_end_call(reason, abort_immediately) + # Record state after end_call_with_reason + test_helper.mute_pipeline_state.append(engine._mute_pipeline) + test_helper.call_disposed_state.append(engine._call_disposed) + + engine.end_call_with_reason = tracked_end_call + + # Create context aggregator with user mute strategies (after engine so we can use its callback) + assistant_params = 
LLMAssistantAggregatorParams(expect_stripped_words=True) + + # Wrap should_mute_user to track calls + original_should_mute_user = engine.should_mute_user + + def tracked_should_mute_user(frame: Frame) -> bool: + result = original_should_mute_user(frame) + test_helper.should_mute_user_calls.append(result) + return result + + # Create user mute strategies matching run_pipeline.py + # - MuteUntilFirstBotCompleteUserMuteStrategy: mutes until first bot response completes + # - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state + user_mute_strategies = [ + MuteUntilFirstBotCompleteUserMuteStrategy(), + CallbackUserMuteStrategy(should_mute_callback=tracked_should_mute_user), + ] + + user_params = LLMUserAggregatorParams( + user_mute_strategies=user_mute_strategies, + ) + + context_aggregator = LLMContextAggregatorPair( + context, assistant_params=assistant_params, user_params=user_params + ) + user_context_aggregator = context_aggregator.user() + assistant_context_aggregator = context_aggregator.assistant() + + # Create the pipeline with transport input -> user aggregator -> LLM -> TTS -> transport output -> assistant aggregator + pipeline = Pipeline( + [ + mock_transport.input(), + user_context_aggregator, + mock_llm, + tts, + mock_transport.output(), + assistant_context_aggregator, + ] + ) + + # Create pipeline task + task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) + + engine.set_task(task) + + return engine, tts, mock_transport, task + + +class TestEndCallViaNodeTransition: + """Test end call behavior when transitioning to an end node.""" + + @pytest.mark.asyncio + async def test_end_call_via_transition_mutes_pipeline_and_extracts_variables( + self, simple_workflow: WorkflowGraph + ): + """Test that transitioning to end node mutes pipeline and extracts variables. + + Scenario: + 1. Start node has extraction_enabled=True + 2. LLM calls transition function to end node + 3. VERIFY: Pipeline is muted + 4. 
VERIFY: Variable extraction is called with run_in_background=False + 5. VERIFY: Call is disposed + """ + test_helper = EndCallTestHelper() + + # Step 0 (Start node): greet user then call end_call to transition to end + step_0_chunks = MockLLMService.create_mixed_chunks( + text="Hello! Thank you for calling. Goodbye!", + function_name="end_call", + arguments={}, + tool_call_id="call_end_1", + ) + + mock_steps = [step_0_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + simple_workflow, llm, test_helper + ) + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "end call"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.gather(run_pipeline(), initialize_engine()) + + # Verify end_call_with_reason was called + assert len(test_helper.end_call_reasons) >= 1, ( + "end_call_with_reason should have been called" + ) + assert EndTaskReason.USER_QUALIFIED.value in test_helper.end_call_reasons + + # Verify pipeline was muted + assert any(test_helper.mute_pipeline_state), "Pipeline should be muted" + + # Verify call was disposed + assert any(test_helper.call_disposed_state), "Call should be disposed" + + # Verify variable extraction was called + # Should have extraction calls - at least one for the transition + # and one synchronous call in end_call_with_reason + 
sync_extraction_calls = [ + call + for call in test_helper.extraction_calls + if not call["run_in_background"] + ] + assert len(sync_extraction_calls) >= 1, ( + f"Expected at least 1 synchronous extraction call, got {len(sync_extraction_calls)}. " + f"All calls: {test_helper.extraction_calls}" + ) + + # Verify user muting behavior via CallbackUserMuteStrategy + # After end_call_with_reason, should_mute_user should return True + # which causes CallbackUserMuteStrategy to mute user audio + assert len(test_helper.should_mute_user_calls) > 0, ( + "should_mute_user callback should have been called during pipeline execution" + ) + # The last calls should return True (after _mute_pipeline is set) + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after end_call_with_reason sets _mute_pipeline" + ) + + @pytest.mark.asyncio + async def test_multi_node_transition_to_end_extracts_from_correct_nodes( + self, three_node_workflow: WorkflowGraph + ): + """Test that multi-node workflow extracts variables from correct nodes. + + Scenario: + 1. Start -> Agent -> End transitions + 2. Both start and agent nodes have extraction enabled + 3. VERIFY: Extraction is called for start node during first transition + 4. VERIFY: Extraction is called for agent node during second transition + 5. VERIFY: Final synchronous extraction is called in end_call_with_reason + """ + test_helper = EndCallTestHelper() + + # Step 0 (Start node): greet user then call collect_info to transition to agent + step_0_chunks = MockLLMService.create_mixed_chunks( + text="Hello! Welcome to our service. Let me collect some information.", + function_name="collect_info", + arguments={}, + tool_call_id="call_transition_1", + ) + + # Step 1 (Agent node): acknowledge then call end_call to transition to end + step_1_chunks = MockLLMService.create_mixed_chunks( + text="Thank you for providing that information. 
Goodbye!", + function_name="end_call", + arguments={}, + tool_call_id="call_transition_2", + ) + + mock_steps = [step_0_chunks, step_1_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + three_node_workflow, llm, test_helper + ) + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"greeting_type": "formal", "user_name": "John"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.gather(run_pipeline(), initialize_engine()) + + # Should have 3 LLM generations + assert llm.get_current_step() == 3 + + # Verify end_call_with_reason was called + assert len(test_helper.end_call_reasons) >= 1 + assert EndTaskReason.USER_QUALIFIED.value in test_helper.end_call_reasons + + # Verify pipeline was muted and call disposed + assert any(test_helper.mute_pipeline_state), "Pipeline should be muted" + assert any(test_helper.call_disposed_state), "Call should be disposed" + + # Verify extraction was called multiple times + # Background extractions during transitions + synchronous in end_call + assert len(test_helper.extraction_calls) >= 2, ( + f"Expected at least 2 extraction calls, got {len(test_helper.extraction_calls)}" + ) + + # Verify user muting is active after call ends + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after end call" + ) + + +class 
TestEndCallViaCustomTool: + """Test end call behavior when using custom end_call tool.""" + + @pytest.mark.asyncio + async def test_end_call_tool_without_message_ends_immediately( + self, simple_workflow: WorkflowGraph + ): + """Test that end_call tool without custom message ends call immediately. + + Scenario: + 1. LLM calls a custom end_call tool (no message configured) + 2. VERIFY: Pipeline is muted + 3. VERIFY: Variable extraction is called + 4. VERIFY: Call ends with abort_immediately=True + """ + test_helper = EndCallTestHelper() + + # Step 0: call end_call tool + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call_tool", + arguments={}, + tool_call_id="call_end_tool_1", + ) + + mock_steps = [step_0_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + simple_workflow, llm, test_helper + ) + + # Create end call tool + end_call_tool = MockEndCallToolModel( + message_type="none", # No message, immediate end + ) + + # Create CustomToolManager and register the end call handler + custom_tool_manager = CustomToolManager(engine) + engine._custom_tool_manager = custom_tool_manager + + # Manually register the end call handler + handler = custom_tool_manager._create_end_call_handler( + end_call_tool, "end_call_tool" + ) + llm.register_function("end_call_tool", handler) + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="end_call_tool", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "end"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def 
initialize_engine(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.gather(run_pipeline(), initialize_engine()) + + # Verify end_call_with_reason was called with END_CALL_TOOL_REASON + assert len(test_helper.end_call_reasons) >= 1, ( + "end_call_with_reason should have been called" + ) + assert EndTaskReason.END_CALL_TOOL_REASON.value in test_helper.end_call_reasons + + # Verify pipeline was muted + assert any(test_helper.mute_pipeline_state), "Pipeline should be muted" + + # Verify call was disposed + assert any(test_helper.call_disposed_state), "Call should be disposed" + + # Verify user muting is active via CallbackUserMuteStrategy + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after end_call_tool" + ) + + @pytest.mark.asyncio + async def test_end_call_tool_with_custom_message_speaks_before_ending( + self, simple_workflow: WorkflowGraph + ): + """Test that end_call tool with custom message speaks before ending. + + Scenario: + 1. LLM calls a custom end_call tool with custom message + 2. VERIFY: TTS speaks the goodbye message + 3. VERIFY: Pipeline is muted + 4. VERIFY: Variable extraction is called + """ + test_helper = EndCallTestHelper() + + # Step 0: call end_call tool + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call_with_message", + arguments={}, + tool_call_id="call_end_tool_1", + ) + + mock_steps = [step_0_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + simple_workflow, llm, test_helper + ) + + # Create end call tool with custom message + end_call_tool = MockEndCallToolModel( + name="End Call With Message", + message_type="custom", + custom_message="Thank you for calling. 
Goodbye!", + ) + + # Create CustomToolManager and register the end call handler + custom_tool_manager = CustomToolManager(engine) + engine._custom_tool_manager = custom_tool_manager + + # Manually register the end call handler + handler = custom_tool_manager._create_end_call_handler( + end_call_tool, "end_call_with_message" + ) + llm.register_function("end_call_with_message", handler) + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="end_call_tool", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "end"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.gather(run_pipeline(), initialize_engine()) + + # Verify end_call_with_reason was called + assert len(test_helper.end_call_reasons) >= 1, ( + "end_call_with_reason should have been called" + ) + assert EndTaskReason.END_CALL_TOOL_REASON.value in test_helper.end_call_reasons + + # Verify pipeline was muted + assert any(test_helper.mute_pipeline_state), "Pipeline should be muted" + + # Verify call was disposed + assert any(test_helper.call_disposed_state), "Call should be disposed" + + # Verify user muting is active via CallbackUserMuteStrategy + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after end_call_with_message" + ) + + +class TestEndCallViaClientDisconnect: + """Test end call behavior when client disconnects.""" + + @pytest.mark.asyncio + async def test_client_disconnect_ends_call_with_user_hangup( + self, 
simple_workflow: WorkflowGraph + ): + """Test that client disconnect triggers end_call_with_reason. + + Scenario: + 1. Pipeline is running + 2. Client disconnects (simulated via direct call to end_call_with_reason) + 3. VERIFY: Pipeline is muted + 4. VERIFY: Variable extraction is called + 5. VERIFY: Reason is USER_HANGUP + """ + test_helper = EndCallTestHelper() + + # Create a simple text response + step_0_chunks = MockLLMService.create_text_chunks( + "Hello! How can I help you today?" + ) + + mock_steps = [step_0_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + simple_workflow, llm, test_helper + ) + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="user_hangup", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "disconnected"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_and_disconnect(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + # Wait for initial generation to complete + await asyncio.sleep(0.1) + + # Simulate client disconnect by calling end_call_with_reason directly + # This is what on_client_disconnected does + await engine.end_call_with_reason( + EndTaskReason.USER_HANGUP.value, abort_immediately=True + ) + + await asyncio.gather(run_pipeline(), initialize_and_disconnect()) + + # Verify end_call_with_reason was called with USER_HANGUP + assert EndTaskReason.USER_HANGUP.value in test_helper.end_call_reasons, ( + f"Expected USER_HANGUP in reasons, got: 
{test_helper.end_call_reasons}" + ) + + # Verify pipeline was muted + assert any(test_helper.mute_pipeline_state), "Pipeline should be muted" + + # Verify call was disposed + assert any(test_helper.call_disposed_state), "Call should be disposed" + + # Verify synchronous extraction was called (run_in_background=False) + sync_extraction_calls = [ + call + for call in test_helper.extraction_calls + if not call["run_in_background"] + ] + assert len(sync_extraction_calls) >= 1, ( + f"Expected at least 1 synchronous extraction call during disconnect. " + f"All calls: {test_helper.extraction_calls}" + ) + + # Verify user muting is active via CallbackUserMuteStrategy + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after client disconnect" + ) + + +class TestEndCallRaceConditions: + """Test race conditions between different end call triggers.""" + + @pytest.mark.asyncio + async def test_only_first_end_call_succeeds(self, simple_workflow: WorkflowGraph): + """Test that only the first end_call_with_reason call succeeds. + + Scenario: + 1. Multiple end_call_with_reason calls are made concurrently + 2. VERIFY: Only the first one sets _call_disposed + 3. 
VERIFY: Subsequent calls return early without doing work + """ + test_helper = EndCallTestHelper() + + # Create a simple text response + step_0_chunks = MockLLMService.create_text_chunks("Hello!") + + mock_steps = [step_0_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + simple_workflow, llm, test_helper + ) + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="first_reason", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "end"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_and_race(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + # Wait for initial generation + await asyncio.sleep(0.1) + + # Try to end call multiple times concurrently + await asyncio.gather( + engine.end_call_with_reason( + EndTaskReason.USER_HANGUP.value, abort_immediately=True + ), + engine.end_call_with_reason( + EndTaskReason.END_CALL_TOOL_REASON.value, + abort_immediately=True, + ), + engine.end_call_with_reason( + EndTaskReason.USER_QUALIFIED.value, + abort_immediately=False, + ), + ) + + await asyncio.gather(run_pipeline(), initialize_and_race()) + + # Due to the _call_disposed guard, only one end_call should fully execute + # The tracked end_call_reasons will show all attempted calls + # but only the first one should modify state + assert len(test_helper.end_call_reasons) == 3, ( + f"Expected 3 end_call attempts, got {len(test_helper.end_call_reasons)}" + ) + + # Only one should have actually set the 
mute_pipeline and call_disposed + # (the others return early due to _call_disposed check) + # Since we track state AFTER end_call_with_reason, we should see True values + # only from the first successful call + assert any(test_helper.mute_pipeline_state), "Pipeline should be muted" + assert any(test_helper.call_disposed_state), "Call should be disposed" + + # Verify user muting is active via CallbackUserMuteStrategy + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after race condition end call" + ) + + @pytest.mark.asyncio + async def test_end_call_tool_and_disconnect_race( + self, simple_workflow: WorkflowGraph + ): + """Test race between end_call tool and client disconnect. + + Scenario: + 1. LLM calls end_call tool + 2. Client disconnects at nearly the same time + 3. VERIFY: Only one end call succeeds + 4. VERIFY: Call is properly disposed + """ + test_helper = EndCallTestHelper() + + # Step 0: Text response + step_0_chunks = MockLLMService.create_text_chunks("Hello!") + + # Step 1: call end_call tool + step_1_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call_tool", + arguments={}, + tool_call_id="call_end_tool_1", + ) + + mock_steps = [step_0_chunks, step_1_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + simple_workflow, llm, test_helper + ) + + # Create end call tool + end_call_tool = MockEndCallToolModel(message_type="none") + + # Create CustomToolManager and register the end call handler + custom_tool_manager = CustomToolManager(engine) + engine._custom_tool_manager = custom_tool_manager + + handler = custom_tool_manager._create_end_call_handler( + end_call_tool, "end_call_tool" + ) + llm.register_function("end_call_tool", handler) + + disconnect_called = False + + # Patch DB calls and extraction manager + with patch( + 
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="end_reason", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "end"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_and_race_disconnect(): + nonlocal disconnect_called + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + # Wait for the end_call tool to be called + await asyncio.sleep(0.15) + + # Simulate client disconnect racing with end_call tool + disconnect_called = True + await engine.end_call_with_reason( + EndTaskReason.USER_HANGUP.value, abort_immediately=True + ) + + await asyncio.gather( + run_pipeline(), initialize_and_race_disconnect() + ) + + # Verify disconnect was attempted + assert disconnect_called, "Disconnect should have been called" + + # Verify at least one end call reason was recorded + assert len(test_helper.end_call_reasons) >= 1, ( + "At least one end_call should have been attempted" + ) + + # Verify call was properly disposed + assert engine._call_disposed, "Call should be disposed" + + # Verify pipeline was muted + assert engine._mute_pipeline, "Pipeline should be muted" + + # Verify user muting is active via CallbackUserMuteStrategy + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after end call" + ) + + +class TestEndCallExtractionBehavior: + """Test variable extraction behavior during end call.""" + + @pytest.mark.asyncio + async def test_synchronous_extraction_in_end_call( + self, simple_workflow: WorkflowGraph + ): + """Test that end_call_with_reason performs synchronous extraction. + + Scenario: + 1. 
End call is triggered + 2. VERIFY: Variable extraction is called with run_in_background=False + 3. VERIFY: Extraction completes before call ends + """ + test_helper = EndCallTestHelper() + extraction_completed = asyncio.Event() + + # Create a simple text response + step_0_chunks = MockLLMService.create_text_chunks("Hello!") + + mock_steps = [step_0_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, tts, transport, task = await create_engine_with_tracking( + simple_workflow, llm, test_helper + ) + + # Create a custom extraction mock that signals when called + async def mock_extraction(*args, **kwargs): + # Simulate some extraction work + await asyncio.sleep(0.05) + extraction_completed.set() + return {"user_intent": "extracted"} + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + side_effect=mock_extraction, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_and_end(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + # Wait for initial generation + await asyncio.sleep(0.1) + + # End the call + await engine.end_call_with_reason( + EndTaskReason.USER_HANGUP.value, abort_immediately=True + ) + + # Verify extraction was awaited (synchronous) + assert extraction_completed.is_set(), ( + "Extraction should have completed before end_call returned" + ) + + await asyncio.gather(run_pipeline(), initialize_and_end()) + + # Verify synchronous extraction was used + sync_extractions = [ + call + for call in test_helper.extraction_calls + if not 
call["run_in_background"] + ] + assert len(sync_extractions) >= 1, ( + f"Expected synchronous extraction, got: {test_helper.extraction_calls}" + ) + + # Verify user muting is active via CallbackUserMuteStrategy + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after end call" + ) + + @pytest.mark.asyncio + async def test_extraction_skipped_for_node_without_extraction( + self, simple_workflow: WorkflowGraph + ): + """Test that extraction is skipped when current node has no extraction. + + Scenario: + 1. Engine is on a node with extraction_enabled=False + 2. End call is triggered + 3. VERIFY: Extraction is attempted but skips due to node config + """ + test_helper = EndCallTestHelper() + + # Create a simple text response + step_0_chunks = MockLLMService.create_text_chunks("Hello!") + + mock_steps = [step_0_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + # Create a workflow where start node has NO extraction + dto = ReactFlowDTO( + nodes=[ + RFNodeDTO( + id="start", + type=NodeType.startNode, + position=Position(x=0, y=0), + data=NodeDataDTO( + name="Start Call", + prompt=START_CALL_SYSTEM_PROMPT, + is_start=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, # No extraction + ), + ), + RFNodeDTO( + id="end", + type=NodeType.endNode, + position=Position(x=0, y=200), + data=NodeDataDTO( + name="End Call", + prompt=END_CALL_SYSTEM_PROMPT, + is_end=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, + ), + ), + ], + edges=[ + RFEdgeDTO( + id="start-end", + source="start", + target="end", + data=EdgeDataDTO( + label="End Call", + condition="When ready to end the call", + ), + ), + ], + ) + workflow_no_extraction = WorkflowGraph(dto) + + engine, tts, transport, task = await create_engine_with_tracking( + workflow_no_extraction, llm, test_helper + ) + + extraction_mock = AsyncMock(return_value={}) + + # Patch DB calls and extraction 
manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + extraction_mock, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_and_end(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + # Wait for initial generation + await asyncio.sleep(0.1) + + # End the call + await engine.end_call_with_reason( + EndTaskReason.USER_HANGUP.value, abort_immediately=True + ) + + await asyncio.gather(run_pipeline(), initialize_and_end()) + + # Extraction should have been called but the inner _perform_extraction + # should not have been called because extraction_enabled=False + # Our tracked_perform_extraction still records the call attempt + # but VariableExtractionManager._perform_extraction should not be called + extraction_mock.assert_not_called() + + # Even without extraction, user muting should still be active + assert any(test_helper.should_mute_user_calls), ( + "should_mute_user should return True after end call (even without extraction)" + ) diff --git a/api/tests/test_pipecat_engine_node_switch_with_user_speech.py b/api/tests/test_pipecat_engine_node_switch_with_user_speech.py new file mode 100644 index 0000000..55e6ac7 --- /dev/null +++ b/api/tests/test_pipecat_engine_node_switch_with_user_speech.py @@ -0,0 +1,393 @@ +"""Tests for verifying behavior when node switch and user speech happen simultaneously. + +This module tests the interaction between node transitions and user speaking events +in the PipecatEngine. The key scenario being tested: + +1. LLM calls a transition function to move from one node to another +2. 
At the same time, user starts and stops speaking (triggered by FunctionCallResultFrame) +3. The pipeline should handle both events correctly + +The tests use a custom input transport that injects UserStartedSpeakingFrame and +UserStoppedSpeakingFrame when triggered by a FunctionCallResultFrame observer. +""" + +import asyncio +from typing import Optional +from unittest.mock import AsyncMock, patch + +import pytest + +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.workflow import WorkflowGraph +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + Frame, + FunctionCallResultFrame, + InputAudioRawFrame, + LLMContextFrame, + StartFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.tests import MockLLMService, MockTTSService +from pipecat.tests.mock_transport import MockOutputTransport +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.turns.user_mute import ( + CallbackUserMuteStrategy, + MuteUntilFirstBotCompleteUserMuteStrategy, +) +from pipecat.turns.user_start import ( + TranscriptionUserTurnStartStrategy, +) +from pipecat.turns.user_stop import ( + ExternalUserTurnStopStrategy, +) +from pipecat.turns.user_turn_strategies import UserTurnStrategies +from pipecat.utils.time import time_now_iso8601 + + +class UserSpeechInjectingInputTransport(FrameProcessor): + """Mock input transport that injects user speaking 
frames on FunctionCallResultFrame. + + This transport generates audio frames and automatically injects UserStartedSpeakingFrame + and UserStoppedSpeakingFrame when it sees the first FunctionCallResultFrame flowing + upstream through the pipeline. + """ + + def __init__( + self, + params: Optional[TransportParams] = None, + *, + generate_audio: bool = False, + audio_interval_ms: int = 20, + sample_rate: int = 16000, + num_channels: int = 1, + user_speech_initial_delay: float = 0.01, + **kwargs, + ): + super().__init__(**kwargs) + self._params = params or TransportParams() + self._generate_audio = generate_audio + self._audio_interval_ms = audio_interval_ms + self._sample_rate = sample_rate + self._num_channels = num_channels + self._user_speech_initial_delay = user_speech_initial_delay + self._audio_task: Optional[asyncio.Task] = None + self._running = False + self._function_call_result_count = 0 + + async def _generate_audio_frames(self): + """Generate audio frames at regular intervals.""" + samples_per_frame = int(self._sample_rate * self._audio_interval_ms / 1000) + bytes_per_frame = samples_per_frame * self._num_channels * 2 + silence_audio = bytes(bytes_per_frame) + + while self._running: + try: + frame = InputAudioRawFrame( + audio=silence_audio, + sample_rate=self._sample_rate, + num_channels=self._num_channels, + ) + await self.push_frame(frame) + await asyncio.sleep(self._audio_interval_ms / 1000) + except asyncio.CancelledError: + break + + def _start_tasks(self): + """Start audio generation task.""" + if not self._running: + self._running = True + if self._generate_audio: + self._audio_task = asyncio.create_task(self._generate_audio_frames()) + + def _stop_tasks(self): + """Stop all background tasks.""" + self._running = False + if self._audio_task and not self._audio_task.done(): + self._audio_task.cancel() + self._audio_task = None + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) 
+ + if isinstance(frame, StartFrame): + self._start_tasks() + elif isinstance(frame, (EndFrame, CancelFrame)): + self._stop_tasks() + elif isinstance(frame, FunctionCallResultFrame): + # When we see FunctionCallResultFrame #1 flowing upstream, + # inject user speaking frames downstream + self._function_call_result_count += 1 + if self._function_call_result_count == 1: + # Simulate first race condition to generate + # LLM call close enough to the LLM call from + # function call + await asyncio.sleep(self._user_speech_initial_delay) + await self.push_frame(UserStartedSpeakingFrame()) + + await asyncio.sleep(0) + await self.push_frame( + TranscriptionFrame("First User Speech", "abc", time_now_iso8601()) + ) + + await asyncio.sleep(0) + await self.push_frame(UserStoppedSpeakingFrame()) + + # Generate second llm call + await asyncio.sleep(0.1) + await self.push_frame(UserStartedSpeakingFrame()) + + await asyncio.sleep(0) + await self.push_frame( + TranscriptionFrame("Second User Speech", "abc", time_now_iso8601()) + ) + + await asyncio.sleep(0) + await self.push_frame(UserStoppedSpeakingFrame()) + + await self.push_frame(frame, direction) + + async def cleanup(self): + self._stop_tasks() + await super().cleanup() + + +class UserSpeechInjectingTransport(BaseTransport): + """Transport that injects user speaking frames on first FunctionCallResultFrame.""" + + def __init__( + self, + params: Optional[TransportParams] = None, + *, + input_name: Optional[str] = None, + output_name: Optional[str] = None, + emit_bot_speaking: bool = True, + generate_audio: bool = False, + audio_interval_ms: int = 20, + audio_sample_rate: int = 16000, + audio_num_channels: int = 1, + user_speech_initial_delay: float = 0.01, + ): + super().__init__(input_name=input_name, output_name=output_name) + self._params = params or TransportParams() + self._input = UserSpeechInjectingInputTransport( + self._params, + name=self._input_name, + generate_audio=generate_audio, + 
audio_interval_ms=audio_interval_ms, + sample_rate=audio_sample_rate, + num_channels=audio_num_channels, + user_speech_initial_delay=user_speech_initial_delay, + ) + self._output = MockOutputTransport( + self._params, + emit_bot_speaking=emit_bot_speaking, + name=self._output_name, + ) + + def input(self) -> UserSpeechInjectingInputTransport: + return self._input + + def output(self) -> FrameProcessor: + return self._output + + +async def create_test_pipeline( + workflow: WorkflowGraph, + mock_llm: MockLLMService, + generate_audio: bool = True, + user_speech_initial_delay: float = 0.01, +) -> tuple[PipecatEngine, UserSpeechInjectingTransport, PipelineTask]: + """Create a PipecatEngine with full pipeline for testing node switch scenarios. + + The transport's input automatically injects UserStartedSpeakingFrame and + UserStoppedSpeakingFrame when it sees the first FunctionCallResultFrame + flowing upstream through the pipeline. + + Args: + workflow: The workflow graph to use. + mock_llm: The mock LLM service. + generate_audio: If True, the mock transport generates InputAudioRawFrame + every 20ms to simulate real audio input. + user_speech_initial_delay: Delay in seconds before injecting + UserStartedSpeakingFrame after seeing FunctionCallResultFrame. 
+ + Returns: + Tuple of (engine, transport, task) + """ + # Create MockTTSService + tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0) + + # Create custom transport that injects user speaking frames on FunctionCallResultFrame #1 + transport = UserSpeechInjectingTransport( + generate_audio=generate_audio, + audio_interval_ms=20, + audio_sample_rate=16000, + audio_num_channels=1, + emit_bot_speaking=True, + user_speech_initial_delay=user_speech_initial_delay, + ) + + # Create LLM context + context = LLMContext() + + # Create PipecatEngine with the workflow + engine = PipecatEngine( + llm=mock_llm, + context=context, + workflow=workflow, + call_context_vars={"customer_name": "Test User"}, + workflow_run_id=1, + ) + + # Create user turn strategies matching run_pipeline.py + user_turn_strategies = UserTurnStrategies( + start=[TranscriptionUserTurnStartStrategy()], + stop=[ExternalUserTurnStopStrategy()], + ) + + # Create user mute strategies matching run_pipeline.py + user_mute_strategies = [ + MuteUntilFirstBotCompleteUserMuteStrategy(), + CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user), + ] + + user_params = LLMUserAggregatorParams( + user_turn_strategies=user_turn_strategies, + user_mute_strategies=user_mute_strategies, + ) + + # Create context aggregator with user and assistant params + assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + + context_aggregator = LLMContextAggregatorPair( + context, assistant_params=assistant_params, user_params=user_params + ) + user_context_aggregator = context_aggregator.user() + assistant_context_aggregator = context_aggregator.assistant() + + # Create the pipeline: + # transport.input() -> user_aggregator -> LLM -> TTS -> transport.output() -> assistant_aggregator + # The transport input watches for FunctionCallResultFrame flowing upstream + # and injects user speaking frames when it sees the first one + pipeline = Pipeline( + [ + transport.input(), + 
user_context_aggregator, + mock_llm, + tts, + transport.output(), + assistant_context_aggregator, + ] + ) + + # Create pipeline task + task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) + + engine.set_task(task) + + return engine, transport, task + + +class TestNodeSwitchWithUserSpeech: + """Test scenarios where node switch and user speech happen simultaneously.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "user_speech_initial_delay,scenario_name", + [ + (0.01, "delayed"), + (0, "immediate"), + ], + ids=["delayed_user_speech", "immediate_user_speech"], + ) + async def test_node_switch_with_concurrent_user_speech( + self, + three_node_workflow_no_variable_extraction: WorkflowGraph, + user_speech_initial_delay: float, + scenario_name: str, + ): + """Test scenario: node transition happens while user is speaking. + + This test creates the scenario where: + 1. LLM generates text and calls collect_info to transition from start to agent + 2. When FunctionCallResultFrame #1 is seen, UserStartedSpeakingFrame and + UserStoppedSpeakingFrame are automatically injected from the pipeline source + 3. The pipeline processes both events concurrently + + The FunctionCallResultObserver in the pipeline detects the first function call + result and triggers the transport to inject user speaking frames. + + This test is parameterized with two scenarios: + - delayed_user_speech: 10ms delay before UserStartedSpeakingFrame (user_speech_initial_delay=0.01) + - immediate_user_speech: No delay before UserStartedSpeakingFrame (user_speech_initial_delay=0) + + This is a scenario creation test - no specific assertions yet. 
+ """ + # Step 0 (Start node): greet user then call collect_info to transition to agent + step_0_chunks = MockLLMService.create_mixed_chunks( + text="Hello!", + function_name="collect_info", + arguments={}, + tool_call_id="call_transition_1", + ) + + step_1_chunks = MockLLMService.create_text_chunks( + text="Step 1 with some longer text that should cause multiple chunks to be created." + ) + + step_2_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_transition_2", + ) + + mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks] + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + engine, _transport, task = await create_test_pipeline( + three_node_workflow_no_variable_extraction, + llm, + user_speech_initial_delay=user_speech_initial_delay, + ) + + # Patch DB calls + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + await asyncio.sleep(0.01) + await engine.initialize() + # Start the LLM generation - user speech will be injected + # automatically when FunctionCallResultFrame #1 is seen + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.gather(run_pipeline(), initialize_engine()) + + # Total 4 generations out of which 1 was cancelled due to interruption + assert llm.get_current_step() == 4 diff --git a/api/tests/test_pipecat_engine_tool_calls.py b/api/tests/test_pipecat_engine_tool_calls.py index 4fe97b6..b8487b6 100644 --- a/api/tests/test_pipecat_engine_tool_calls.py +++ b/api/tests/test_pipecat_engine_tool_calls.py @@ -12,7 +12,7 @@ import pytest from api.services.workflow.pipecat_engine import 
PipecatEngine from api.services.workflow.workflow import WorkflowGraph -from api.tests.conftest import END_CALL_SYSTEM_PROMPT, MockTransportProcessor +from api.tests.conftest import END_CALL_SYSTEM_PROMPT from pipecat.frames.frames import LLMContextFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -23,6 +23,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMContextAggregatorPair, ) from pipecat.tests import MockLLMService, MockTTSService +from pipecat.tests.mock_transport import MockTransport async def run_pipeline_with_tool_calls( @@ -65,7 +66,8 @@ async def run_pipeline_with_tool_calls( # Create MockTTSService to generate TTS frames tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0) - mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False) + # Create MockTransport for simulating transport behavior + mock_transport = MockTransport(emit_bot_speaking=False) # Create LLM context context = LLMContext() @@ -91,15 +93,13 @@ async def run_pipeline_with_tool_calls( [ llm, tts, - mock_transport_emulator, + mock_transport.output(), assistant_context_aggregator, ] ) # Create a real pipeline task - task = PipelineTask( - pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False - ) + task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) engine.set_task(task) diff --git a/api/tests/test_pipecat_engine_variable_extraction.py b/api/tests/test_pipecat_engine_variable_extraction.py index 88b5e22..1ec7104 100644 --- a/api/tests/test_pipecat_engine_variable_extraction.py +++ b/api/tests/test_pipecat_engine_variable_extraction.py @@ -17,23 +17,11 @@ from unittest.mock import AsyncMock, patch import pytest -from api.services.workflow.dto import ( - EdgeDataDTO, - ExtractionVariableDTO, - NodeDataDTO, - NodeType, - Position, - ReactFlowDTO, - RFEdgeDTO, - RFNodeDTO, - VariableType, -) from api.services.workflow.pipecat_engine import 
PipecatEngine from api.services.workflow.pipecat_engine_variable_extractor import ( VariableExtractionManager, ) from api.services.workflow.workflow import WorkflowGraph -from api.tests.conftest import MockTransportProcessor from pipecat.frames.frames import LLMContextFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -44,96 +32,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMContextAggregatorPair, ) from pipecat.tests import MockLLMService, MockTTSService - -# Define prompts for test nodes -START_NODE_PROMPT = "Start Node System Prompt" -AGENT_NODE_PROMPT = "Agent Node System Prompt" -END_NODE_PROMPT = "End Node System Prompt" - - -@pytest.fixture -def three_node_workflow_with_extraction_on_start() -> WorkflowGraph: - """Create a three-node workflow where only the start node has extraction enabled. - - The workflow has: - - Start node with extraction_enabled=True and extraction_variables set - - Agent node with extraction_enabled=False (default) - - End node with extraction_enabled=False (default) - - This is used to test that variable extraction is triggered for the correct node - during transitions. 
- """ - dto = ReactFlowDTO( - nodes=[ - RFNodeDTO( - id="start", - type=NodeType.startNode, - position=Position(x=0, y=0), - data=NodeDataDTO( - name="Start Call", - prompt=START_NODE_PROMPT, - is_start=True, - allow_interrupt=False, - add_global_prompt=False, - extraction_enabled=True, - extraction_prompt="Extract the user's name from the conversation.", - extraction_variables=[ - ExtractionVariableDTO( - name="user_name", - type=VariableType.string, - prompt="The name the user provided", - ), - ], - ), - ), - RFNodeDTO( - id="agent", - type=NodeType.agentNode, - position=Position(x=0, y=200), - data=NodeDataDTO( - name="Collect Info", - prompt=AGENT_NODE_PROMPT, - allow_interrupt=False, - add_global_prompt=False, - extraction_enabled=False, # Explicitly disabled - ), - ), - RFNodeDTO( - id="end", - type=NodeType.endNode, - position=Position(x=0, y=400), - data=NodeDataDTO( - name="End Call", - prompt=END_NODE_PROMPT, - is_end=True, - allow_interrupt=False, - add_global_prompt=False, - extraction_enabled=False, # Explicitly disabled - ), - ), - ], - edges=[ - RFEdgeDTO( - id="start-agent", - source="start", - target="agent", - data=EdgeDataDTO( - label="Collect Info", - condition="When user has been greeted, proceed to collect information", - ), - ), - RFEdgeDTO( - id="agent-end", - source="agent", - target="end", - data=EdgeDataDTO( - label="End Call", - condition="When information collection is complete, end the call", - ), - ), - ], - ) - return WorkflowGraph(dto) +from pipecat.tests.mock_transport import MockTransport class TestVariableExtractionDuringTransitions: @@ -141,7 +40,7 @@ class TestVariableExtractionDuringTransitions: @pytest.mark.asyncio async def test_extraction_called_for_source_node_not_target_node( - self, three_node_workflow_with_extraction_on_start: WorkflowGraph + self, three_node_workflow_extraction_start_only: WorkflowGraph ): """Test that when transitioning from START to AGENT, extraction is called for START node. 
@@ -183,7 +82,8 @@ class TestVariableExtractionDuringTransitions: # Create MockTTSService tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0) - mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False) + # Create MockTransport for simulating transport behavior + mock_transport = MockTransport(emit_bot_speaking=False) # Create LLM context context = LLMContext() @@ -195,7 +95,7 @@ class TestVariableExtractionDuringTransitions: ) assistant_context_aggregator = context_aggregator.assistant() - workflow = three_node_workflow_with_extraction_on_start + workflow = three_node_workflow_extraction_start_only # Create PipecatEngine with the workflow engine = PipecatEngine( @@ -209,7 +109,7 @@ class TestVariableExtractionDuringTransitions: # Patch _perform_variable_extraction_if_needed to track calls original_perform_extraction = engine._perform_variable_extraction_if_needed - async def tracked_perform_extraction(node): + async def tracked_perform_extraction(node, run_in_background=True): extraction_calls.append( { "node_id": node.id if node else None, @@ -228,7 +128,7 @@ class TestVariableExtractionDuringTransitions: [ llm, tts, - mock_transport_emulator, + mock_transport.output(), assistant_context_aggregator, ] ) @@ -236,7 +136,7 @@ class TestVariableExtractionDuringTransitions: # Create pipeline task task = PipelineTask( pipeline, - params=PipelineParams(allow_interruptions=False), + params=PipelineParams(), enable_rtvi=False, ) diff --git a/api/tests/test_user_idle_handler.py b/api/tests/test_user_idle_handler.py index 73e9cba..2f2a35c 100644 --- a/api/tests/test_user_idle_handler.py +++ b/api/tests/test_user_idle_handler.py @@ -14,7 +14,6 @@ import pytest from api.services.workflow.pipecat_engine import PipecatEngine from api.services.workflow.workflow import WorkflowGraph -from api.tests.conftest import MockTransportProcessor from pipecat.frames.frames import LLMContextFrame from pipecat.pipeline.pipeline import Pipeline from 
pipecat.pipeline.runner import PipelineRunner @@ -26,6 +25,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMUserAggregatorParams, ) from pipecat.tests import MockLLMService, MockTTSService +from pipecat.tests.mock_transport import MockTransport async def run_pipeline_with_user_idle( @@ -52,14 +52,15 @@ async def run_pipeline_with_user_idle( if mock_steps is None: mock_steps = MockLLMService.create_multi_step_responses( MockLLMService.create_text_chunks("Hello, how can I help you today?"), - num_text_steps=3, # Initial + 2 idle responses + num_text_steps=4, # Initial + 2 idle responses + 1 variable extraction step_prefix="Response", ) llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) tts = MockTTSService(mock_audio_duration_ms=10) - mock_transport = MockTransportProcessor() + # Create MockTransport for simulating transport behavior + mock_transport = MockTransport(emit_bot_speaking=True) # Create LLM context context = LLMContext() @@ -99,15 +100,13 @@ async def run_pipeline_with_user_idle( user_context_aggregator, llm, tts, - mock_transport, + mock_transport.output(), assistant_context_aggregator, ] ) # Create pipeline task - task = PipelineTask( - pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False - ) + task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) engine.set_task(task) @@ -145,8 +144,6 @@ async def run_pipeline_with_user_idle( async def wait_for_idle_to_trigger(): # Wait long enough for idle timeouts to trigger await asyncio.sleep(total_wait_time) - # Cancel the task if it's still running - await task.cancel() # Run all concurrently await asyncio.gather( diff --git a/api/tests/test_voicemail_detector.py b/api/tests/test_voicemail_detector.py new file mode 100644 index 0000000..d0681c2 --- /dev/null +++ b/api/tests/test_voicemail_detector.py @@ -0,0 +1,238 @@ +"""Tests for understanding voicemail detector behavior with user aggregator and LLM. 
+ +This module tests the interaction between the voicemail detector, user aggregator, +and LLM in a pipeline. It demonstrates how the voicemail detector classifies +incoming speech as CONVERSATION or VOICEMAIL and how the main LLM responds. +""" + +import asyncio + +import pytest + +from pipecat.extensions.voicemail.voicemail_detector import VoicemailDetector +from pipecat.frames.frames import ( + EndTaskFrame, + Frame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.tests import MockLLMService +from pipecat.turns.user_start import ( + TranscriptionUserTurnStartStrategy, + VADUserTurnStartStrategy, +) +from pipecat.turns.user_stop import ( + ExternalUserTurnStopStrategy, +) +from pipecat.turns.user_turn_strategies import UserTurnStrategies +from pipecat.utils.time import time_now_iso8601 + + +class FrameInjector(FrameProcessor): + """Simple processor that can inject frames into the pipeline.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._frames_to_inject: list[Frame] = [] + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + await self.push_frame(frame, direction) + + async def inject_frame( + self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM + ): + """Inject a frame into the pipeline.""" + await self.push_frame(frame, direction) + + +class FrameCounter: + """Helper to count specific 
frame types seen by a processor.""" + + def __init__(self): + self.user_stopped_speaking_count = 0 + self.user_started_speaking_count = 0 + + def wrap_process_frame(self, original_process_frame): + """Wrap a process_frame method to count UserStartedSpeakingFrame and UserStoppedSpeakingFrame.""" + + async def wrapped(frame: Frame, direction: FrameDirection): + if isinstance(frame, UserStoppedSpeakingFrame): + self.user_stopped_speaking_count += 1 + elif isinstance(frame, UserStartedSpeakingFrame): + self.user_started_speaking_count += 1 + return await original_process_frame(frame, direction) + + return wrapped + + +class TestVoicemailDetectorWithUserAggregator: + """Test scenarios with voicemail detector and user aggregator.""" + + @pytest.mark.asyncio + async def test_voicemail_detector_conversation_flow(self): + """Test: Voicemail detector classifies as CONVERSATION and main LLM responds. + + This test bench shows the flow: + 1. User starts speaking, sends transcription, stops speaking + 2. Voicemail detector's internal LLM classifies as "CONVERSATION" + 3. Main LLM generates response text + 4. Second user turn with transcription + 5.
An EndTaskFrame is injected upstream to end the pipeline + + Pipeline structure is modeled on run_pipeline.py (without voicemail_detector.gate()): + injector -> voicemail_detector.detector() -> user_aggregator -> main_llm + -> assistant_aggregator + """ + context = LLMContext() + + # Create user turn strategies + user_turn_strategies = UserTurnStrategies( + start=[ + VADUserTurnStartStrategy(), + TranscriptionUserTurnStartStrategy(), + ], + stop=[ExternalUserTurnStopStrategy()], + ) + + user_params = LLMUserAggregatorParams( + user_turn_strategies=user_turn_strategies, + ) + + assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + + context_aggregator = LLMContextAggregatorPair( + context, assistant_params=assistant_params, user_params=user_params + ) + user_context_aggregator = context_aggregator.user() + assistant_context_aggregator = context_aggregator.assistant() + + # Create mock LLM for main conversation + # Only one mock step is configured: the text response generated after the + # CONVERSATION classification. The pipeline is ended by injecting an + # EndTaskFrame, not by an end_call function call. + main_llm_steps = [ + MockLLMService.create_text_chunks(text="Hello!
I'm here to help you today.") + ] + main_llm = MockLLMService(mock_steps=main_llm_steps, chunk_delay=0.001) + + # Create mock LLM for voicemail classification + # First response: "CONVERSATION" to close the voicemail gate + voicemail_classification_steps = [ + MockLLMService.create_text_chunks(text="CONVERSATION"), + ] + voicemail_llm = MockLLMService( + mock_steps=voicemail_classification_steps, chunk_delay=0.001 + ) + + # Create voicemail detector with the classification LLM + voicemail_detector = VoicemailDetector( + llm=voicemail_llm, + voicemail_response_delay=0, + ) + + # Set up frame counter to track UserStoppedSpeakingFrame in voicemail detector's user aggregator + voicemail_user_aggregator = voicemail_detector._context_aggregator.user() + frame_counter = FrameCounter() + original_process_frame = voicemail_user_aggregator.process_frame + voicemail_user_aggregator.process_frame = frame_counter.wrap_process_frame( + original_process_frame + ) + + # Build pipeline similar to run_pipeline.py structure + injector = FrameInjector() + pipeline = Pipeline( + [ + injector, + voicemail_detector.detector(), # Classification parallel pipeline + user_context_aggregator, + main_llm, + assistant_context_aggregator, + ] + ) + + task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def inject_frames(): + await asyncio.sleep(0.05) + + # === First user turn === + # This triggers voicemail classification AND main LLM response + await injector.inject_frame(UserStartedSpeakingFrame()) + await asyncio.sleep(0) + await injector.inject_frame( + TranscriptionFrame("First User Speech", "user-123", time_now_iso8601()) + ) + await asyncio.sleep(0.05) + await injector.inject_frame(UserStoppedSpeakingFrame()) + + # Wait for voicemail classification and main LLM response + await asyncio.sleep(0.2) + + # === Second user turn === + await 
injector.inject_frame(UserStartedSpeakingFrame()) + + await asyncio.sleep(0) + await injector.inject_frame( + TranscriptionFrame( + "Second User Speech", + "user-123", + time_now_iso8601(), + ) + ) + + await asyncio.sleep(0.05) + await injector.inject_frame(UserStoppedSpeakingFrame()) + + await asyncio.sleep(0.05) + await injector.inject_frame( + EndTaskFrame(), direction=FrameDirection.UPSTREAM + ) + + await asyncio.gather(run_pipeline(), inject_frames()) + + # Assert voicemail LLM was called once for classification + assert voicemail_llm.get_current_step() == 1 + + # Assert main LLM was called twice (once per user turn) + assert main_llm.get_current_step() == 2 + + # Assert voicemail detector's user aggregator saw UserStoppedSpeakingFrame only once + # (because the classifier gate closes after CONVERSATION classification, + # blocking subsequent frames from reaching the voicemail branch) + assert frame_counter.user_stopped_speaking_count == 1, ( + f"Expected voicemail detector's user aggregator to see UserStoppedSpeakingFrame once, " + f"but saw it {frame_counter.user_stopped_speaking_count} times" + ) + + # We should see no more than 2 UserStartedSpeakingFrame frames.
One from downstream FrameInjector + # and one from upstream main pipeline's LLMUserAggregator + assert frame_counter.user_started_speaking_count <= 2, ( + f"Expected voicemail detector's user aggregator to see UserStartedSpeakingFrame at most twice, " + f"but saw it {frame_counter.user_started_speaking_count} times" + ) + + # Assert the classifier gate is closed after classification + assert voicemail_detector._classifier_gate._gate_opened is False, ( + "Expected classifier gate to be closed after CONVERSATION classification" + ) + + # Assert the classifier upstream gate is also closed after classification + assert voicemail_detector._classifier_upstream_gate._gate_open is False, ( + "Expected classifier upstream gate to be closed after CONVERSATION classification" + ) diff --git a/pipecat b/pipecat index f11fad8..df1432e 160000 --- a/pipecat +++ b/pipecat @@ -1 +1 @@ -Subproject commit f11fad8f3e90e06b1625b9dc49c13e26f3c9e716 +Subproject commit df1432e168570661ae418500fb04e8c62ba1335b