chore: refactor and add tests (#130)

* chore: add tests for end call

* Update pipecat module

* fix: allow interruptions from deepgram flux

* Add VadUserTurnStrategy

* chore: add test for voicemail detection
This commit is contained in:
Abhishek 2026-01-27 18:20:23 +05:30 committed by GitHub
parent 2aedb839ff
commit 033fde8946
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 2106 additions and 542 deletions

View file

@ -9,7 +9,6 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
@ -19,6 +18,7 @@ from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
from pipecat.frames.frames import Frame
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
@ -49,7 +49,6 @@ from api.services.workflow.tools.timezone import (
get_current_time,
get_time_tools,
)
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.tracing.context_registry import get_current_turn_context
@ -79,12 +78,10 @@ class PipecatEngine:
self._workflow_run_id = workflow_run_id
self._node_transition_callback = node_transition_callback
self._initialized = False
self._client_disconnected = False
self._call_disposed = False
self._current_node: Optional[Node] = None
self._gathered_context: dict = {}
self._user_response_timeout_task: Optional[asyncio.Task] = None
self._call_disposition: Optional[str] = None
# Stasis connection for immediate transfers
self._stasis_connection: Optional["StasisRTPConnection"] = None
@ -99,6 +96,9 @@ class PipecatEngine:
# Track current LLM reference text for TTS aggregation correction
self._current_llm_generation_reference_text: str = ""
# Controls whether user input should be muted
self._mute_pipeline: bool = False
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
@ -215,9 +215,14 @@ class PipecatEngine:
This way, when we do set_node from within this function, and go for LLM completion with updated
system prompts, the context is updated with function call result.
"""
# FIXME: There is a potential race condition when we generate an LLM completion from
# UserContextAggregator with a FunctionCallResultFrame and also call end_call_with_reason,
# which queues an EndFrame or CancelFrame. If the EndFrame reaches the LLM processor before
# the ContextFrame, the generation may never run even though it might have been intended.
# Queue EndFrame if we just transitioned to EndNode
if self._current_node.is_end:
await self.send_end_task_frame(
await self.end_call_with_reason(
EndTaskReason.USER_QUALIFIED.value
)
@ -356,44 +361,52 @@ class PipecatEngine:
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
self, node: Optional[Node], run_in_background: bool = True
) -> None:
"""Perform variable extraction if the previous node had extraction enabled."""
if (
previous_node
and previous_node.extraction_enabled
and previous_node.extraction_variables
):
"""Perform variable extraction if the node has extraction enabled.
Args:
node: The node to extract variables from.
run_in_background: If True, runs extraction as a fire-and-forget task.
If False, awaits the extraction synchronously.
"""
if not (node and node.extraction_enabled and node.extraction_variables):
return
# Capture the current turn context for otel tracing
# before creating the background task
parent_context = get_current_turn_context()
extraction_prompt = self._format_prompt(node.extraction_prompt)
extraction_variables = node.extraction_variables
async def _do_extraction():
try:
extracted_data = (
await self._variable_extraction_manager._perform_extraction(
extraction_variables, parent_context, extraction_prompt
)
)
self._gathered_context.update(extracted_data)
logger.debug(
f"Variable extraction completed. Extracted: {extracted_data}"
)
except Exception as e:
logger.error(f"Error during variable extraction: {str(e)}")
if run_in_background:
logger.debug(
f"Scheduling background variable extraction for node: {previous_node.name}"
f"Scheduling background variable extraction for node: {node.name}"
)
asyncio.create_task(_do_extraction())
else:
logger.debug(
f"Performing synchronous variable extraction for node: {node.name}"
)
await _do_extraction()
# Capture the current turn context before creating the background task
parent_context = get_current_turn_context()
extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
extraction_variables = previous_node.extraction_variables
async def _background_extraction():
try:
extracted_data = (
await self._variable_extraction_manager._perform_extraction(
extraction_variables, parent_context, extraction_prompt
)
)
self._gathered_context.update(extracted_data)
logger.debug(
f"Background variable extraction completed. Extracted: {extracted_data}"
)
except Exception as e:
logger.error(
f"Error during background variable extraction: {str(e)}"
)
# Fire and forget - extraction happens in background without blocking
asyncio.create_task(_background_extraction())
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
"""Common method to set up LLM context and queue context frame for non-static nodes."""
async def _setup_llm_context(self, node: Node) -> None:
"""Common method to set up LLM context"""
# Set node name for tracing
try:
self.context.set_node_name(node.name)
@ -470,61 +483,54 @@ class PipecatEngine:
if node.is_static:
raise ValueError("Static nodes are not supported!")
else:
# Start generation for non-static start node
await self._setup_llm_context_and_start_generation(node)
# Setup LLM Context with Prompts and Functions
await self._setup_llm_context(node)
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:
raise ValueError("Static nodes are not supported!")
else:
await self._setup_llm_context_and_start_generation(node)
# If this end node has extraction enabled, perform extraction immediately
if node.extraction_enabled and node.extraction_variables:
await self._perform_variable_extraction_if_needed(node)
# Setup LLM Context with Prompts and Functions
await self._setup_llm_context(node)
async def _handle_agent_node(self, node: Node) -> None:
"""Handle agent node execution."""
if node.is_static:
raise ValueError("Static nodes are not supported!")
else:
# Set context and functions for non-static agent node
await self._setup_llm_context_and_start_generation(node)
# Setup LLM Context with Prompts and Functions
await self._setup_llm_context(node)
async def send_end_task_frame(
async def end_call_with_reason(
self,
reason: str,
abort_immediately: bool = False,
):
"""
Centralized method to send EndTaskFrame with metadata including
call_transfer_context and call_context_vars
Centralized method to end the call with disposition mapping
"""
if self._call_disposed or self._client_disconnected:
# Call is already disposed and client disconnected
logger.debug(
f"Not sending EndFrame since call is already disposed: Call Disposed: {self._call_disposed} Client Disconnected: {self._client_disconnected}"
)
if self._call_disposed:
logger.debug(f"Call already Disposed: {self._call_disposed}")
return
self._call_disposed = True
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
# Mute the pipeline
self._mute_pipeline = True
# Customer disposition code using their mapping
mapped_disposition = ""
# Perform final variable extraction synchronously before ending
await self._perform_variable_extraction_if_needed(
self._current_node, run_in_background=False
)
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
# Apply disposition mapping - first try call_disposition if it was
# extracted from the call conversation, then fall back to reason
call_disposition = self._gathered_context.get("call_disposition", "")
organization_id = await self._get_organization_id()
# If the client disconnected before we got a chance to disconnect from
# the bot, let's consider that the final disposition
if self._client_disconnected:
call_disposition = EndTaskReason.USER_HANGUP.value
if call_disposition:
# If call_disposition exists, map it
mapped_disposition = await apply_disposition_mapping(
@ -532,90 +538,16 @@ class PipecatEngine:
)
# Store the original and mapped values
self._gathered_context["extracted_call_disposition"] = call_disposition
self._gathered_context["call_disposition"] = mapped_disposition
self._gathered_context["call_disposition"] = call_disposition
self._gathered_context["mapped_call_disposition"] = mapped_disposition
else:
# Otherwise, map the disconnect reason
mapped_disposition = await apply_disposition_mapping(
reason, organization_id
)
# Store the mapped disconnect reason
self._gathered_context["call_disposition"] = mapped_disposition
# TODO: Generalise this
self._gathered_context["address"] = ", ".join(
[
self._call_context_vars.get("address1", ""),
self._call_context_vars.get("address2", ""),
self._call_context_vars.get("address3", ""),
self._call_context_vars.get("city", ""),
self._call_context_vars.get("state", ""),
self._call_context_vars.get("province", ""),
self._call_context_vars.get("postal_code", ""),
]
)
self._gathered_context["full_name"] = " ".join(
[
self._call_context_vars.get("first_name", ""),
self._call_context_vars.get("middle_initial", ""),
self._call_context_vars.get("last_name", ""),
]
)
self._gathered_context["agent_name"] = "Alex"
self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
"phone", ""
)
self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
self._gathered_context["vendor_id"] = self._call_context_vars.get(
"vendor_lead_code", ""
)
decision_maker = self._gathered_context.get("primary_cardholder", False)
employment_status = self._gathered_context.get("employment_status", "N/A")
call_transfer_context = {
"first_name": self._call_context_vars.get("first_name", ""),
"full_name": self._gathered_context.get("full_name", ""),
"phone": self._call_context_vars.get("phone", ""),
"lead_id": self._call_context_vars.get("lead_id"),
"disposition": mapped_disposition,
"agent_name": self._gathered_context.get("agent_name", "Alex"),
"decision_maker": str(decision_maker),
"employment": employment_status.title() if employment_status else "N/A",
"debts": self._gathered_context.get("total_debt", "N/A"),
"number_of_credit_cards": self._gathered_context.get(
"number_of_credit_cards", "N/A"
),
"time": self._gathered_context.get("time"),
}
logger.debug(
f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
)
# Initiate immediate transfer for Stasis connections when user is qualified
if (
reason == EndTaskReason.USER_QUALIFIED.value
and self._stasis_connection is not None
and not abort_immediately
):
try:
logger.info(
f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
)
await self._stasis_connection.transfer(call_transfer_context)
logger.info("Immediate transfer initiated successfully")
except Exception as e:
logger.error(f"Failed to initiate immediate transfer: {e}")
# Continue with normal flow even if immediate transfer fails
if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
await self.task.queue_frame(
TTSSpeakFrame(
"Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
)
)
# Store the original reason for later retrieval in event handler
self._call_disposition = mapped_disposition
self._gathered_context["call_disposition"] = reason
self._gathered_context["mapped_call_disposition"] = mapped_disposition
logger.debug(
f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
@ -678,11 +610,14 @@ class PipecatEngine:
return system_message, functions
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
async def should_mute_user(self, frame: "Frame") -> bool:
"""
This callback is called by STTMuteFilter to determine if the STT should be muted.
Callback for CallbackUserMuteStrategy to determine if the user should be muted.
Returns:
True if the user should be muted, False otherwise.
"""
return engine_callbacks.create_should_mute_callback(self)
return self._mute_pipeline
def create_user_idle_handler(self):
"""
@ -746,26 +681,10 @@ class PipecatEngine:
"""Accumulate LLM text frames to build reference text."""
self._current_llm_generation_reference_text += text
def handle_client_disconnected(self):
"""Handle client disconnected event."""
self._client_disconnected = True
def is_call_disposed(self):
"""Check whether a call has been disposed by the engine"""
return self._call_disposed
async def get_call_disposition(self) -> Optional[str]:
"""Get the disconnect reason set by the engine."""
if self._call_disposition:
# We would have a _call_disposition variable set if we have initiated
# a disconnect from the bot, i.e we have called send_end_task_frame.
return self._call_disposition
if self._client_disconnected:
return EndTaskReason.USER_HANGUP.value
else:
return EndTaskReason.UNKNOWN.value
async def get_gathered_context(self) -> dict:
"""Get the gathered context including extracted variables."""
return self._gathered_context.copy()