dograh/api/services/workflow/pipecat_engine.py
Abhishek Kumar 4f2a629340 Initial Commit 🚀 🚀
2025-09-09 14:37:32 +05:30

939 lines
39 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.enums import EndTaskReason
from api.constants import VOICEMAIL_RECORDING_DURATION
from api.services.gender.gender_service import GenderService
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
)
from api.services.workflow.pipecat_engine_voicemail_detector import (
VoicemailDetector,
)
from api.services.workflow.workflow import Node, WorkflowGraph
if TYPE_CHECKING:
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
LLMService = Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService]
import asyncio
from loguru import logger
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.tracing.context_registry import get_current_turn_context
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
render_template,
update_llm_context,
)
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
get_time_tools,
)
class PipecatEngine:
def __init__(
    self,
    *,
    task: Optional[PipelineTask] = None,
    llm: Optional["LLMService"] = None,
    context: Optional[OpenAILLMContext] = None,
    tts: Optional[Any] = None,
    transport: Optional[BaseTransport] = None,
    workflow: WorkflowGraph,
    call_context_vars: dict,
    audio_buffer: Optional["AudioBuffer"] = None,
    workflow_run_id: Optional[int] = None,
):
    """Store the pipeline collaborators and reset all per-call state.

    All arguments are keyword-only. ``task``, ``context`` and
    ``audio_buffer`` may be omitted here and supplied later through the
    corresponding ``set_*`` methods.

    Args:
        task: Pipeline task used to queue frames.
        llm: LLM service that transition/built-in functions are registered on.
        context: LLM context updated on every node change.
        tts: TTS service handle (only stored by this method).
        transport: Transport handle (only stored by this method).
        workflow: Graph of nodes the engine walks through.
        call_context_vars: Template variables used when rendering prompts.
        audio_buffer: Audio buffer processor used for voicemail detection.
        workflow_run_id: Identifier of the workflow run.
    """
    # --- Collaborators handed in by the caller --------------------------
    self.workflow = workflow
    self.task = task
    self.llm = llm
    self.context = context
    self.tts = tts
    self.transport = transport
    self._audio_buffer = audio_buffer
    self._workflow_run_id = workflow_run_id
    self._call_context_vars = call_context_vars

    # Stasis connection used for immediate transfers; see set_stasis_connection().
    self._stasis_connection: Optional["StasisRTPConnection"] = None

    # --- Lifecycle / bookkeeping ----------------------------------------
    self._initialized = False
    self._current_node: Optional[Node] = None
    self._pending_function_calls = 0
    self._gathered_context: dict = {}
    self._call_disposition: Optional[str] = None
    self._user_response_timeout_task: Optional[asyncio.Task] = None

    # --- Helpers ---------------------------------------------------------
    # The extraction manager is created in initialize().
    self._variable_extraction_manager = None
    self._gender_service = GenderService(confidence_threshold=0.5)
    # Built-in function schemas are computed on first access of the
    # builtin_function_schemas property.
    self._builtin_function_schemas: Optional[list[dict]] = None

    # --- Voicemail detection state --------------------------------------
    self._detect_voicemail = False
    self._voicemail_detector = None
    self._voicemail_detection_task: Optional[asyncio.Task] = None

    # --- Deferred transitions (run by flush_pending_transitions) --------
    # Transition requested by the LLM through a tool call. The same
    # generation may also contain text that is spoken via TTS, in which
    # case the transition is parked here until the context push. If the
    # bot is interrupted the transition is cancelled (currently when the
    # next generation starts, in the generation-started callback handler).
    self._pending_generated_transition_after_context_push: Optional[
        Callable[[], Awaitable[None]]
    ] = None
    # Programmatic transition that does not go through an LLM tool call.
    # It is not interrupted by the user and runs on context push.
    self._pending_control_transition_after_context_push: Optional[
        Callable[[], Awaitable[None]]
    ] = None

    # --- Per-generation flags -------------------------------------------
    # True when the current LLM generation contains text, so a transition
    # result must wait for the context push.
    self._defer_context_push: bool = False
    # Whether the next non-static node entry should queue a context frame.
    self._queue_context_frame: bool = True
    # LLM text accumulated as reference for TTS aggregation correction.
    self._current_llm_reference_text: str = ""
@property
def builtin_function_schemas(self) -> list[dict]:
"""Get built-in function schemas (calculator and timezone tools)."""
if self._builtin_function_schemas is None:
self._builtin_function_schemas = []
# Transform calculator tools to get_function_schema format
for tool in get_calculator_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
# Transform timezone tools to get_function_schema format
for tool in get_time_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
return self._builtin_function_schemas
async def initialize(self):
    """Prepare the engine and enter the workflow's start node.

    Creates the variable extraction manager, seeds the gathered context
    with the current time, registers the built-in functions with the LLM,
    derives a salutation from ``first_name`` when present, and finally
    calls ``set_node`` with the workflow's start node id.

    Calling this again after a successful run only logs a warning. If
    initialization fails the engine is marked as not initialized again, so
    a retry is not silently ignored.

    Raises:
        Exception: Whatever the failing step raised, after logging it.
    """
    # TODO: May be set_node in a separate task so that we return from initialize immediately
    if self._initialized:
        logger.warning(f"{self.__class__.__name__} already initialized")
        return

    # Set the flag before the first await so that a second call made while
    # this one is suspended returns early.
    self._initialized = True
    try:
        # Helper that encapsulates variable extraction logic
        self._variable_extraction_manager = VariableExtractionManager(self)

        # Add the current time in the America/New_York zone to the gathered
        # context. Failure here is logged and does not abort initialization.
        try:
            est_time_result = get_current_time("America/New_York")
            # The get_current_time utility returns a dict with a 'datetime'
            # field; it is stored under the key 'time'.
            self._gathered_context["time"] = est_time_result.get("datetime")
        except Exception as e:
            logger.error(f"Failed to fetch current EST time: {e}")

        # Register built-in functions with the LLM
        await self._register_builtin_functions()

        # Set the salutation predicted from the first name
        if "first_name" in self._call_context_vars:
            salutation = await self._gender_service.get_salutation(
                self._call_context_vars["first_name"]
            )
            self._call_context_vars["salutation"] = salutation

        await self.set_node(self.workflow.start_node_id)
        logger.debug(f"{self.__class__.__name__} initialized")
    except Exception as e:
        # Roll the flag back: leaving it set would turn every later
        # initialize() call into a no-op even though nothing was started.
        self._initialized = False
        logger.error(f"Error initializing {self.__class__.__name__}: {e}")
        raise
def _get_function_schema(self, function_name: str, description: str):
    """Thin wrapper around ``get_function_schema``, kept for backwards compatibility.

    Args:
        function_name: Name of the function to describe.
        description: Description of the function.

    Returns:
        Whatever ``get_function_schema(function_name, description)`` returns.
    """
    return get_function_schema(function_name, description)
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
    """Delegate the context update to the shared ``update_llm_context`` helper.

    Args:
        system_message: System message for the node being entered.
        functions: Function schemas to expose to the LLM.
    """
    # The helper is called synchronously; this method is async only so that
    # callers can await it.
    update_llm_context(self.context, system_message, functions)
def _format_prompt(self, prompt: str) -> str:
    """Render ``prompt`` with ``render_template`` using the call context variables.

    Args:
        prompt: Prompt template text.

    Returns:
        The value returned by ``render_template``.
    """
    return render_template(prompt, self._call_context_vars)
async def _create_transition_func(self, name: str, transition_to_node: str):
    """Build the LLM tool handler that moves the workflow to ``transition_to_node``.

    Args:
        name: Name the handler will be registered under; used here only in
            the error log message.
        transition_to_node: Id of the node passed to ``set_node`` once the
            function-call result has been written to the context.

    Returns:
        An async callable that accepts ``FunctionCallParams``.
    """

    async def transition_func(function_call_params: FunctionCallParams) -> None:
        """Handle one invocation of the transition tool."""
        try:
            # Track pending function call
            self._pending_function_calls += 1
            logger.debug(
                f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
            )

            async def on_context_updated() -> None:
                """Switch node once the function-call result is in the context.

                The framework runs this after the function call result has
                been added to the context, so when ``set_node`` starts the
                next LLM completion with the new system prompt the context
                already contains that result.
                """
                self._pending_function_calls -= 1
                # Perform variable extraction before transitioning to new node
                await self._perform_variable_extraction_if_needed(
                    self._current_node
                )
                await self.set_node(transition_to_node)

            result = {"status": "done"}
            # run_llm=False: the tool result itself must not trigger an LLM
            # completion; the node switch in on_context_updated drives what
            # happens next.
            properties = FunctionCallResultProperties(
                run_llm=False,
                on_context_updated=on_context_updated,
            )

            async def _invoke_result_callback():
                """Report the tool result to the framework.

                Functions are executed as soon as the LLM emits them. When
                the same completion also contains text, the result must not
                be reported if the user interrupts the speech, and the call
                must not be added to the context, so that the LLM can call
                the function again. In that case this coroutine is parked
                and only awaited once the context push happens, i.e. after
                the bot has finished speaking the generated text.
                """
                await function_call_params.result_callback(
                    result, properties=properties
                )

            if self._defer_context_push:
                # _defer_context_push is set when text was received in the
                # current generation (in the handle_llm_generated_text
                # callback handler).
                logger.debug(
                    "Deferring transition function result until context push"
                )
                # Only one deferred transition should exist at any time.
                # Overwrite if one is somehow already set (unexpected).
                self._pending_generated_transition_after_context_push = (
                    _invoke_result_callback
                )
            else:
                # No text in the current generation, only the function
                # call: report the result right away so the framework can
                # call on_context_updated and the node switch can happen.
                await _invoke_result_callback()
        except Exception as e:
            logger.error(f"Error in transition function {name}: {str(e)}")
            # NOTE(review): this resets the counter for every pending call,
            # not just this one - confirm that is intended.
            self._pending_function_calls = 0
            error_result = {"status": "error", "error": str(e)}
            await function_call_params.result_callback(error_result)

    return transition_func
async def _register_transition_function_with_llm(
    self, name: str, transition_to_node: str
):
    """Create the transition handler for an edge and register it on the LLM.

    Args:
        name: Function name the LLM uses to trigger the transition.
        transition_to_node: Id of the node the handler transitions to.
    """
    logger.debug(
        f"Registering function {name} to transition to node {transition_to_node} with LLM"
    )
    handler = await self._create_transition_func(name, transition_to_node)
    # Registered with cancel_on_interruption=True.
    self.llm.register_function(name, handler, cancel_on_interruption=True)
async def _register_builtin_functions(self):
    """Register built-in functions (calculator and timezone) with the LLM.

    Three handlers are registered: ``safe_calculator``,
    ``get_current_time`` and ``convert_time``. Each reports its result
    with ``run_llm=True``; if anything raises while computing or reporting,
    ``{"error": <message>}`` is reported instead.
    """
    logger.debug("Registering built-in functions with LLM")
    followup = FunctionCallResultProperties(run_llm=True)

    def _as_handler(compute):
        """Turn ``compute(arguments) -> payload`` into a tool handler."""

        async def _handler(function_call_params: FunctionCallParams) -> None:
            try:
                await function_call_params.result_callback(
                    compute(function_call_params.arguments),
                    properties=followup,
                )
            except Exception as e:
                await function_call_params.result_callback(
                    {"error": str(e)}, properties=followup
                )

        return _handler

    def _calculate(arguments):
        expr = arguments.get("expression", "")
        return {"expression": expr, "result": safe_calculator(expr)}

    def _current_time(arguments):
        return get_current_time(arguments.get("timezone", "UTC"))

    def _convert(arguments):
        return convert_time(
            arguments.get("source_timezone"),
            arguments.get("time"),
            arguments.get("target_timezone"),
        )

    # Register all built-in functions
    builtin_tools = (
        ("safe_calculator", _calculate),
        ("get_current_time", _current_time),
        ("convert_time", _convert),
    )
    for tool_name, compute in builtin_tools:
        self.llm.register_function(tool_name, _as_handler(compute))
async def _queue_tts_response(self, text: str) -> None:
    """Queue ``text`` to be spoken, wrapped in LLM response start/end frames.

    Args:
        text: Static text for the TTS speak frame.
    """
    frames = [
        LLMFullResponseStartFrame(),
        TTSSpeakFrame(text=text),
        LLMFullResponseEndFrame(),
    ]
    await self.task.queue_frames(frames)
async def _setup_static_start_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static start nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
if not node.wait_for_user_response:
# Normal static start node - transition immediately after context push
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
"""Perform variable extraction if the previous node had extraction enabled."""
if (
previous_node
and previous_node.extraction_enabled
and previous_node.extraction_variables
):
logger.debug(
f"Scheduling background variable extraction for node: {previous_node.name}"
)
# Capture the current turn context before creating the background task
parent_context = get_current_turn_context()
extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
extraction_variables = previous_node.extraction_variables
async def _background_extraction():
try:
extracted_data = (
await self._variable_extraction_manager._perform_extraction(
extraction_variables, parent_context, extraction_prompt
)
)
self._gathered_context.update(extracted_data)
logger.debug(
f"Background variable extraction completed. Extracted: {extracted_data}"
)
except Exception as e:
logger.error(
f"Error during background variable extraction: {str(e)}"
)
# Fire and forget - extraction happens in background without blocking
asyncio.create_task(_background_extraction())
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
    """Point the LLM context at ``node`` and, unless suppressed, start generation.

    Used for every non-static node: tags the context with the node name,
    registers the node's transition functions (skipped for end nodes),
    installs the node's system message and function schemas, and queues a
    context frame unless ``_queue_context_frame`` was cleared. That flag is
    set back to True before returning.

    Args:
        node: The non-static node being entered.
    """
    # Set node name for tracing
    try:
        self.context.set_node_name(node.name)
    except AttributeError:
        logger.warning(f"context has no set_node_name method")

    # End nodes get no transition functions registered.
    edges_to_register = () if node.is_end else node.out_edges
    for edge in edges_to_register:
        await self._register_transition_function_with_llm(
            edge.get_function_name(), edge.target
        )

    composed = await self._compose_system_message_functions_for_node(node)
    system_message, functions = composed
    await self._update_llm_context(system_message, functions)

    if not self._queue_context_frame:
        logger.debug(
            f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
        )
    else:
        await self.task.queue_frame(OpenAILLMContextFrame(self.context))

    # Queueing is the default behavior; suppression applies to one entry only.
    self._queue_context_frame = True
async def set_node(self, node_id: str):
    """Make ``node_id`` the current node and run the handler for its kind.

    Simplified set_node implementation according to v2 PRD. Start nodes
    take precedence over end nodes; everything else is treated as an agent
    node.

    Args:
        node_id: Key of the node in ``self.workflow.nodes``.

    Raises:
        KeyError: If ``node_id`` is not present in the workflow.
    """
    node = self.workflow.nodes[node_id]
    logger.debug(
        f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
    )

    # Recorded for every node (including static ones) so STT mute filter works
    self._current_node = node

    if node.is_start:
        handler = self._handle_start_node
    elif node.is_end:
        handler = self._handle_end_node
    else:
        handler = self._handle_agent_node
    await handler(node)
async def _handle_start_node(self, node: Node) -> None:
"""Handle start node execution."""
# Handle voicemail detection setup (before any returns)
if node.detect_voicemail:
if not self._audio_buffer:
logger.warning(
"Voicemail detection enabled but no audio buffer available - skipping detection"
)
else:
logger.debug(
"Start node has detect_voicemail enabled - setting up audio-based detector"
)
self._detect_voicemail = True
self._voicemail_detector = VoicemailDetector(
detection_duration=VOICEMAIL_RECORDING_DURATION,
workflow_run_id=self._workflow_run_id,
)
# Register audio handler on the audio buffer input processor
audio_input = self._audio_buffer.input()
@audio_input.event_handler("on_input_audio_data")
async def handle_voicemail_audio(
processor, pcm, sample_rate, num_channels
):
if (
self._voicemail_detector
and self._voicemail_detector.is_detecting
):
await self._voicemail_detector.handle_audio_data(
processor, pcm, sample_rate, num_channels
)
# Start detection
await self._voicemail_detector.start_detection(self)
# Check if delayed start is enabled
if node.delayed_start:
# Use configured duration or default to 3 seconds
delay_duration = node.delayed_start_duration or 2.0
logger.debug(
f"Delayed start enabled - waiting {delay_duration} seconds before speaking"
)
await asyncio.sleep(delay_duration)
if node.is_static:
# Queue TTS for static start node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static start nodes
await self._setup_static_start_node_transition(node)
else:
# Start generation for non-static start node
await self._setup_llm_context_and_start_generation(node)
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:
# Queue TTS for static end node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
else:
# Start generation for non-static end node
await self._setup_llm_context_and_start_generation(node)
# If this end node has extraction enabled, perform extraction immediately
if node.extraction_enabled and node.extraction_variables:
await self._perform_variable_extraction_if_needed(node)
# TODO: Extract disposition code from extracted variables
# Defer send_end_task_frame using _pending_control_transition_after_context_push
# Decide the end-task reason dynamically depending on call_disposition.
async def _deferred_end_task():
# call_disposition is the disposition which is generated from
# llm call based on the conversation so far.
# TODO: Make this more generic based on configuration or llm prompting
disposition = self._gathered_context.get("call_disposition")
if disposition == "XFER":
reason = EndTaskReason.USER_QUALIFIED.value
else:
reason = EndTaskReason.USER_DISQUALIFIED.value
await self.send_end_task_frame(reason)
self._pending_control_transition_after_context_push = _deferred_end_task
async def _handle_agent_node(self, node: Node) -> None:
"""Handle agent node execution."""
if node.is_static:
# Queue TTS for static agent node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static agent nodes
await self._setup_agent_node_transition(node)
else:
# Set context and functions for non-static agent node
await self._setup_llm_context_and_start_generation(node)
async def _setup_agent_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static agent nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def send_end_task_frame(
    self,
    reason: str,
    additional_metadata: Optional[dict] = None,
    abort_immediately: bool = False,
):
    """Queue the frame that ends the pipeline, with call metadata attached.

    Queues an ``EndFrame`` (or a ``CancelFrame`` when ``abort_immediately``
    is true) whose metadata carries ``reason`` and a
    ``call_transfer_context`` built from the call context variables and
    the gathered context. Also maps the disposition, fills several derived
    entries in the gathered context, and, for qualified users on a Stasis
    connection, starts the transfer right away.

    Args:
        reason: End-task reason; kept unmapped in the frame metadata.
        additional_metadata: Extra entries merged into the frame metadata.
        abort_immediately: Use ``CancelFrame`` and skip the immediate
            Stasis transfer.
    """

    def _join_present(separator: str, keys: tuple) -> str:
        """Join the call-context values for ``keys``, skipping missing ones.

        Values that are absent, None or empty are dropped so the result
        has no dangling separators, and the rest are converted with
        ``str`` because ``str.join`` rejects non-string items.
        """
        values = (self._call_context_vars.get(key) for key in keys)
        return separator.join(
            str(value) for value in values if value is not None and value != ""
        )

    frame_to_push = CancelFrame() if abort_immediately else EndFrame()

    # Customer disposition code using their mapping
    mapped_disposition = ""

    # Apply disposition mapping: prefer the call_disposition extracted
    # from the conversation, otherwise fall back to the end-task reason.
    # NOTE(review): a second call would see the already-mapped value under
    # "call_disposition" and map it again - confirm this is called once.
    call_disposition = self._gathered_context.get("call_disposition", "")
    organization_id = await get_organization_id_from_workflow_run(
        self._workflow_run_id
    )
    if call_disposition:
        mapped_disposition = await apply_disposition_mapping(
            call_disposition, organization_id
        )
        # Store the original and mapped values
        self._gathered_context["extracted_call_disposition"] = call_disposition
        self._gathered_context["call_disposition"] = mapped_disposition
    else:
        mapped_disposition = await apply_disposition_mapping(
            reason, organization_id
        )
        # Store the mapped disconnect reason
        self._gathered_context["call_disposition"] = mapped_disposition

    # TODO: Generalise this, currently tailored to Kapil's use case
    self._gathered_context["address"] = _join_present(
        ", ",
        (
            "address1",
            "address2",
            "address3",
            "city",
            "state",
            "province",
            "postal_code",
        ),
    )
    self._gathered_context["full_name"] = _join_present(
        " ", ("first_name", "middle_initial", "last_name")
    )
    self._gathered_context["agent_name"] = "Alex"
    self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
        "phone", ""
    )
    # NOTE(review): timezone is filled from "province" - confirm intended.
    self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
    self._gathered_context["vendor_id"] = self._call_context_vars.get(
        "vendor_lead_code", ""
    )

    decision_maker = self._gathered_context.get("primary_cardholder", False)
    employment_status = self._gathered_context.get("employment_status", "N/A")
    call_transfer_context = {
        "first_name": self._call_context_vars.get("first_name", ""),
        "full_name": self._gathered_context.get("full_name", ""),
        "phone": self._call_context_vars.get("phone", ""),
        "lead_id": self._call_context_vars.get("lead_id"),
        "disposition": mapped_disposition,
        "agent_name": self._gathered_context.get("agent_name", "Alex"),
        "decision_maker": str(decision_maker),
        # str() guards against a non-string extracted value, which has no .title().
        "employment": str(employment_status).title()
        if employment_status
        else "N/A",
        "debts": self._gathered_context.get("total_debt", "N/A"),
        "number_of_credit_cards": self._gathered_context.get(
            "number_of_credit_cards", "N/A"
        ),
        "time": self._gathered_context.get("time"),
    }
    logger.debug(
        f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
    )

    # Initiate immediate transfer for Stasis connections when user is qualified
    if (
        reason == EndTaskReason.USER_QUALIFIED.value
        and self._stasis_connection is not None
        and not abort_immediately
    ):
        try:
            logger.info(
                f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
            )
            await self._stasis_connection.transfer(call_transfer_context)
            logger.info("Immediate transfer initiated successfully")
        except Exception as e:
            logger.error(f"Failed to initiate immediate transfer: {e}")
            # Continue with normal flow even if immediate transfer fails

    # The apology is queued before the end frame.
    if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
        await self.task.queue_frame(
            TTSSpeakFrame(
                "Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
            )
        )

    metadata = {
        # Keep original reason in metadata, which would be used to decide
        # whether to disconnect or to transfer the call in the transport
        "reason": reason,
        "call_transfer_context": call_transfer_context,
    }
    # Add any additional metadata
    if additional_metadata:
        metadata.update(additional_metadata)
    frame_to_push.metadata = metadata

    # Store the mapped disposition for retrieval via get_call_disposition()
    self._call_disposition = mapped_disposition
    logger.debug(
        f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
    )
    await self.task.queue_frame(frame_to_push)
async def _compose_system_message_functions_for_node(
self, node: "Node"
) -> tuple[list[dict], list[dict]]:
"""Generate the system messages and function schemas for the given node.
This performs the same formatting logic used when entering a node but
does **not** register the functions with the LLM; callers are
responsible for that.
"""
global_prompt = ""
if self.workflow.global_node_id and node.add_global_prompt:
global_node = self.workflow.nodes[self.workflow.global_node_id]
global_prompt = self._format_prompt(global_node.prompt)
functions: list[dict] = []
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Transition functions (schema only; registration handled elsewhere)
for outgoing_edge in node.out_edges:
function_schema = self._get_function_schema(
outgoing_edge.get_function_name(), outgoing_edge.condition
)
functions.append(function_schema)
formatted_node_prompt = self._format_prompt(node.prompt)
system_message = {
"role": "system",
"content": "\n\n".join(
p for p in (global_prompt, formatted_node_prompt) if p
),
}
return system_message, functions
# ------------------------------------------------------------------
# Pending transition handling
# ------------------------------------------------------------------
async def flush_pending_transitions(self, *, source: str = "context_push"):
"""Execute and clear any pending transitions.
Args:
source: Indicates the trigger that caused this flush:
- "context_push": the assistant context aggregator completed a push.
"""
if source != "context_push":
raise ValueError("Invalid flush source expected 'context_push'")
len_pending_functions = 0
if self._pending_generated_transition_after_context_push is not None:
len_pending_functions += 1
if self._pending_control_transition_after_context_push is not None:
len_pending_functions += 1
# Nothing to do
if len_pending_functions == 0:
return
logger.debug(
f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
)
# Generated transition
if self._pending_generated_transition_after_context_push is not None:
pending_cb = self._pending_generated_transition_after_context_push
self._pending_generated_transition_after_context_push = None
try:
await pending_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred transition: {exc}")
# Control transition (context push)
if self._pending_control_transition_after_context_push is not None:
logger.debug("Executing control transition after context push")
static_cb = self._pending_control_transition_after_context_push
self._pending_control_transition_after_context_push = None
try:
await static_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred static node transition: {exc}")
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
    """Return the should-mute callback built by ``engine_callbacks`` for this engine.

    The callback is called by STTMuteFilter to determine if the STT
    should be muted.
    """
    return engine_callbacks.create_should_mute_callback(self)
def create_user_idle_callback(self):
    """Return the user-idle callback built by ``engine_callbacks`` for this engine.

    The callback is called when the user is idle for a certain duration.
    It is used to either play the static text or end the call.
    """
    return engine_callbacks.create_user_idle_callback(self)
def create_max_duration_callback(self):
    """Return the max-duration callback built by ``engine_callbacks`` for this engine.

    The callback is called when the call duration exceeds the max
    duration. It is used to end the call (see ``send_end_task_frame``).
    """
    return engine_callbacks.create_max_duration_callback(self)
def create_llm_generated_text_callback(self):
    """Return the LLM-generated-text callback built by ``engine_callbacks``.

    The callback is called when some text is generated by the LLM. It is
    used to defer the result_callback of the node transition functions
    when a transition is requested along with generated text, so that the
    context is sent in the next generation from the new node.
    """
    return engine_callbacks.create_llm_generated_text_callback(self)
def create_generation_started_callback(self):
    """Return the generation-started callback built by ``engine_callbacks``.

    The callback is called when a new generation starts. It is used to
    reset the flags that control the flow of the engine.
    """
    return engine_callbacks.create_generation_started_callback(self)
def create_user_stopped_speaking_callback(self):
    """Return the user-stopped-speaking callback built by ``engine_callbacks``.

    The callback is called when the user stops speaking. It is used to
    handle transitions when wait_for_user_response is enabled.
    """
    return engine_callbacks.create_user_stopped_speaking_callback(self)
def create_user_started_speaking_callback(self):
    """Return the user-started-speaking callback built by ``engine_callbacks``.

    The callback is called when the user starts speaking. It is used to
    handle wait_for_user_greeting functionality.
    """
    return engine_callbacks.create_user_started_speaking_callback(self)
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
    """Return the aggregation-correction callback built by ``engine_callbacks``.

    The callback corrects corrupted aggregation using reference text.
    """
    return engine_callbacks.create_aggregation_correction_callback(self)
def get_call_disposition(self) -> Optional[str]:
    """Return the mapped disposition stored by ``send_end_task_frame``.

    Returns:
        The stored disposition, or None when ``send_end_task_frame`` has
        not stored one yet.
    """
    return self._call_disposition
def get_gathered_context(self) -> dict:
    """Return a shallow copy of the gathered context, including extracted variables."""
    # A copy is returned so callers cannot add or remove keys on the
    # engine's own dict.
    return self._gathered_context.copy()
def set_context(self, context: OpenAILLMContext) -> None:
    """Set the OpenAI LLM context.

    This allows setting the context after the engine has been created,
    which is useful when the context needs to be created after the engine.

    Args:
        context: Context object stored on ``self.context``.
    """
    self.context = context
def set_task(self, task: PipelineTask) -> None:
    """Set the pipeline task.

    This allows setting the task after the engine has been created,
    which is useful when the task needs to be created after the engine.

    Args:
        task: Pipeline task stored on ``self.task``; used to queue frames.
    """
    self.task = task
def set_audio_buffer(self, audio_buffer: "AudioBuffer") -> None:
    """Set the audio buffer.

    This allows setting the audio buffer after the engine has been created,
    which is useful when the audio buffer needs to be created after the engine.

    Args:
        audio_buffer: Audio buffer processor; the start node uses it for
            voicemail detection.
    """
    self._audio_buffer = audio_buffer
def set_stasis_connection(
self, connection: Optional["StasisRTPConnection"]
) -> None:
"""Set the Stasis RTP connection for immediate transfers.
This allows the engine to initiate transfers immediately when XFER
disposition is detected, without waiting for pipeline shutdown.
Args:
connection: The StasisRTPConnection instance, or None for non-Stasis transports
"""
self._stasis_connection = connection
if connection:
logger.debug(
f"Stasis connection set for immediate transfers: {connection.channel_id}"
)
async def handle_llm_text_frame(self, text: str):
    """Accumulate LLM text frames to build reference text.

    Args:
        text: Text chunk appended to ``_current_llm_reference_text``.
    """
    self._current_llm_reference_text += text
async def cleanup(self):
"""Clean up engine resources on disconnect."""
# Cancel any pending timeout tasks
if (
self._user_response_timeout_task
and not self._user_response_timeout_task.done()
):
self._user_response_timeout_task.cancel()
# Stop voicemail detection if active
if self._voicemail_detector and hasattr(
self._voicemail_detector, "stop_detection"
):
await self._voicemail_detector.stop_detection()