Initial Commit 🚀 🚀

This commit is contained in:
Abhishek Kumar 2025-09-09 14:37:32 +05:30
commit 4f2a629340
444 changed files with 76863 additions and 0 deletions

View file

View file

@ -0,0 +1,77 @@
"""Utility module for applying disposition code mapping."""
from typing import Optional
from loguru import logger
from api.db import db_client
from api.enums import OrganizationConfigurationKey
async def apply_disposition_mapping(value: str, organization_id: Optional[int]) -> str:
    """Translate a disposition code through the organization's configured mapping.

    Args:
        value: The disposition code to translate.
        organization_id: The organization whose mapping configuration is read.

    Returns:
        The mapped code when the organization's mapping contains ``value``;
        otherwise ``value`` unchanged. ``value`` is also returned unchanged
        when either argument is falsy or when the lookup raises.
    """
    # Without both an organization and a value there is nothing to look up.
    if not (organization_id and value):
        return value

    try:
        disposition_mapping = await db_client.get_configuration_value(
            organization_id,
            OrganizationConfigurationKey.DISPOSITION_CODE_MAPPING.value,
            default={},
        )
        if disposition_mapping:
            # The mapping is a flat dict such as
            # {"user_idle_max_duration_exceeded": "DAIR"}; codes without an
            # entry pass through untouched.
            mapped_value = disposition_mapping.get(value, value)
            if mapped_value != value:
                logger.debug(
                    f"Mapped disposition code from '{value}' to '{mapped_value}' "
                    f"for organization {organization_id}"
                )
            return mapped_value
        return value
    except Exception as e:
        # Best effort: a failed lookup falls back to the original value.
        logger.error(f"Error applying disposition mapping: {e}")
        return value
async def get_organization_id_from_workflow_run(
    workflow_run_id: Optional[int],
) -> Optional[int]:
    """Resolve the organization behind a workflow run.

    Follows workflow run -> workflow -> user and returns that user's
    ``selected_organization_id``.

    Args:
        workflow_run_id: The workflow run ID.

    Returns:
        The organization ID, or ``None`` when the ID is falsy, any link in
        the chain is missing, or the lookup raises.
    """
    if not workflow_run_id:
        return None

    try:
        workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
        # Walk the relationship chain, stopping at the first missing link.
        workflow = workflow_run.workflow if workflow_run else None
        user = workflow.user if workflow else None
        return user.selected_organization_id if user else None
    except Exception as e:
        logger.error(f"Error getting organization_id from workflow_run: {e}")
        return None

View file

@ -0,0 +1,96 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, ValidationError, model_validator
class NodeType(str, Enum):
    """Kinds of node a workflow graph can contain.

    Each member's value is the string carried in a node's ``type`` field;
    note that the start/end members use "startCall"/"endCall" on the wire.
    """

    startNode = "startCall"
    endNode = "endCall"
    agentNode = "agentNode"
    globalNode = "globalNode"
# x/y coordinates of a node (used as RFNodeDTO.position).
class Position(BaseModel):
    x: float
    y: float
class VariableType(str, Enum):
    """Value types an extraction variable can be declared with."""

    string = "string"
    number = "number"
    boolean = "boolean"
# One variable to extract: a non-empty name, its declared type and an
# optional per-variable prompt.
class ExtractionVariableDTO(BaseModel):
    name: str = Field(..., min_length=1)  # required, must not be empty
    type: VariableType
    prompt: Optional[str] = None
# Payload carried by a workflow node (RFNodeDTO.data).
class NodeDataDTO(BaseModel):
    # Required, non-empty identification and prompt text.
    name: str = Field(..., min_length=1)
    prompt: str = Field(..., min_length=1)
    # Node behaviour flags.
    is_static: bool = False
    is_start: bool = False
    is_end: bool = False
    allow_interrupt: bool = False
    # Variable extraction settings.
    extraction_enabled: bool = False
    extraction_prompt: Optional[str] = None
    extraction_variables: Optional[list[ExtractionVariableDTO]] = None
    # Defaults to True: the global prompt is added unless switched off.
    add_global_prompt: bool = True
    wait_for_user_response: bool = False
    wait_for_user_response_timeout: Optional[float] = None  # presumably seconds - TODO confirm
    # Defaults to True: voicemail detection is on unless switched off.
    detect_voicemail: bool = True
    delayed_start: bool = False
    delayed_start_duration: Optional[float] = None  # presumably seconds - TODO confirm
# A single node of the graph: id, node type, position and data payload.
class RFNodeDTO(BaseModel):
    id: str
    type: NodeType = Field(default=NodeType.agentNode)  # agentNode when omitted
    position: Position
    data: NodeDataDTO
# Payload carried by an edge; both fields are required and must be non-empty.
class EdgeDataDTO(BaseModel):
    label: str = Field(..., min_length=1)
    condition: str = Field(..., min_length=1)
# A directed edge between two nodes, referenced by node id.
class RFEdgeDTO(BaseModel):
    id: str
    source: str  # id of the node the edge starts from
    target: str  # id of the node the edge points to
    data: EdgeDataDTO
# A complete graph: the list of nodes plus the edges connecting them.
# After field validation, every edge endpoint must refer to an existing node.
class ReactFlowDTO(BaseModel):
    nodes: List[RFNodeDTO]
    edges: List[RFEdgeDTO]

    @model_validator(mode="after")
    def _referential_integrity(self):
        """Check that each edge's source and target name an existing node.

        Returns:
            The model itself when every endpoint resolves.

        Raises:
            ValidationError: with one ``missing_node`` line error per
                dangling endpoint, located at ``("edges", <index>)`` and
                carrying ``edge_id`` / ``endpoint`` in its context.
        """
        # pydantic_core ships with pydantic v2 (it provides ValidationError);
        # imported here so this block stays self-contained.
        from pydantic_core import PydanticCustomError

        node_ids = {n.id for n in self.nodes}
        line_errors: list[dict] = []
        for idx, edge in enumerate(self.edges):
            for endpoint in (edge.source, edge.target):
                if endpoint not in node_ids:
                    line_errors.append(
                        {
                            "loc": ("edges", idx),
                            # A plain string here must be one of pydantic's
                            # built-in error types; an unknown string such as
                            # "missing_node" makes from_exception_data raise
                            # KeyError. A PydanticCustomError carries the
                            # custom type, message and context instead.
                            "type": PydanticCustomError(
                                "missing_node",
                                "Edge references missing node",
                                {"edge_id": edge.id, "endpoint": endpoint},
                            ),
                            "input": edge.model_dump(mode="python"),
                        }
                    )
        if line_errors:
            raise ValidationError.from_exception_data(
                title="ReactFlowDTO validation failed",
                line_errors=line_errors,
            )
        return self

View file

@ -0,0 +1,16 @@
# api/services/workflow/errors.py
from enum import Enum
from typing import TypedDict
class ItemKind(str, Enum):
    """What a workflow error refers to: a node, an edge, or the workflow."""

    node = "node"
    edge = "edge"
    workflow = "workflow"
class WorkflowError(TypedDict):
    """Structured description of a single workflow error."""

    kind: ItemKind  # "node" | "edge" | "workflow"
    id: str | None  # nodeId or edgeId, when the error is tied to one
    field: str | None  # "data.prompt", "position.x", ... (optional)
    message: str  # human-readable text

View file

@ -0,0 +1,939 @@
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.enums import EndTaskReason
from api.constants import VOICEMAIL_RECORDING_DURATION
from api.services.gender.gender_service import GenderService
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
)
from api.services.workflow.pipecat_engine_voicemail_detector import (
VoicemailDetector,
)
from api.services.workflow.workflow import Node, WorkflowGraph
if TYPE_CHECKING:
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
LLMService = Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService]
import asyncio
from loguru import logger
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.tracing.context_registry import get_current_turn_context
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
render_template,
update_llm_context,
)
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
get_time_tools,
)
class PipecatEngine:
def __init__(
    self,
    *,
    task: Optional[PipelineTask] = None,
    llm: Optional["LLMService"] = None,
    context: Optional[OpenAILLMContext] = None,
    tts: Optional[Any] = None,
    transport: Optional[BaseTransport] = None,
    workflow: WorkflowGraph,
    call_context_vars: dict,
    audio_buffer: Optional["AudioBuffer"] = None,
    workflow_run_id: Optional[int] = None,
):
    """Store the collaborators and reset all per-call engine state.

    All arguments are keyword-only. Only ``workflow`` and
    ``call_context_vars`` are required; the remaining collaborators default
    to ``None``.
    """
    # --- Collaborators handed in by the caller ---
    self.task = task
    self.llm = llm
    self.context = context
    self.tts = tts
    self.transport = transport
    self.workflow = workflow
    self._call_context_vars = call_context_vars
    self._audio_buffer = audio_buffer
    self._workflow_run_id = workflow_run_id

    # --- Lifecycle and bookkeeping ---
    self._initialized = False
    self._pending_function_calls = 0
    self._current_node: Optional[Node] = None
    self._gathered_context: dict = {}
    self._user_response_timeout_task: Optional[asyncio.Task] = None
    self._call_disposition: Optional[str] = None

    # Stasis connection used for immediate transfers.
    self._stasis_connection: Optional["StasisRTPConnection"] = None

    # Created later, in initialize().
    self._variable_extraction_manager = None

    self._gender_service = GenderService(confidence_threshold=0.5)

    # --- Voicemail detection state ---
    self._detect_voicemail = False
    self._voicemail_detector = None
    self._voicemail_detection_task: Optional[asyncio.Task] = None

    # Transition requested by the LLM through a tool call. It may come with
    # text that is spoken via TTS first; the original note says it is
    # cancelled when the next generation starts (generation-started handler).
    self._pending_generated_transition_after_context_push: Optional[
        Callable[[], Awaitable[None]]
    ] = None

    # Programmatic transition that is not an LLM tool call. Per the original
    # note it is not interrupted by the user and runs on context push.
    self._pending_control_transition_after_context_push: Optional[
        Callable[[], Awaitable[None]]
    ] = None

    # Whether the current LLM generation contains a text completion.
    self._defer_context_push: bool = False

    # Built-in function schemas, built lazily on first access.
    self._builtin_function_schemas: Optional[list[dict]] = None

    # Whether the next node setup should queue a context frame.
    self._queue_context_frame: bool = True

    # Current LLM reference text, used for TTS aggregation correction.
    self._current_llm_reference_text: str = ""
@property
def builtin_function_schemas(self) -> list[dict]:
    """Return the schemas of the built-in calculator and timezone tools.

    The list is built on first access from ``get_calculator_tools()`` and
    ``get_time_tools()`` (calculator entries first) and cached on the
    instance afterwards.
    """
    if self._builtin_function_schemas is None:
        # Build into a local list and publish it only once complete, so a
        # failure part-way through cannot leave a truncated list cached for
        # every later access.
        schemas: list[dict] = []
        for tool in (*get_calculator_tools(), *get_time_tools()):
            func = tool["function"]
            parameters = func["parameters"]
            schemas.append(
                get_function_schema(
                    func["name"],
                    func["description"],
                    properties=parameters["properties"],
                    required=parameters["required"],
                )
            )
        self._builtin_function_schemas = schemas
    return self._builtin_function_schemas
async def initialize(self):
    """Prepare the engine once and enter the workflow's start node.

    A second call only logs a warning. Any error during setup is logged and
    re-raised.
    """
    # TODO: May be set_node in a separate task so that we return from initialize immediately
    if self._initialized:
        logger.warning(f"{self.__class__.__name__} already initialized")
        return

    try:
        # Marked before the awaits below, so a second call arriving during
        # setup is rejected by the check above.
        self._initialized = True

        # Helper that owns the variable extraction logic.
        self._variable_extraction_manager = VariableExtractionManager(self)

        # Record the current America/New_York time in the gathered context.
        # Failure here is logged and does not abort initialization.
        try:
            est_time_result = get_current_time("America/New_York")
            # The result is a dict; its 'datetime' entry is kept under 'time'.
            self._gathered_context["time"] = est_time_result.get("datetime")
        except Exception as e:
            logger.error(f"Failed to fetch current EST time: {e}")

        await self._register_builtin_functions()

        # Derive a salutation from the first name when one was supplied.
        if "first_name" in self._call_context_vars:
            first_name = self._call_context_vars["first_name"]
            salutation = await self._gender_service.get_salutation(first_name)
            self._call_context_vars["salutation"] = salutation

        await self.set_node(self.workflow.start_node_id)
        logger.debug(f"{self.__class__.__name__} initialized")
    except Exception as e:
        logger.error(f"Error initializing {self.__class__.__name__}: {e}")
        raise
def _get_function_schema(self, function_name: str, description: str):
    """Return ``get_function_schema(function_name, description)``.

    Thin wrapper around the shared utility, kept for backwards
    compatibility; no properties or required fields are passed.
    """
    return get_function_schema(function_name, description)
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
    """Apply ``system_message`` and ``functions`` to ``self.context``.

    Delegates to the shared ``update_llm_context`` helper, which is called
    synchronously; this method is a coroutine only for its callers' sake.
    """
    update_llm_context(self.context, system_message, functions)
def _format_prompt(self, prompt: str) -> str:
    """Render ``prompt`` with the call context variables.

    Delegates to the shared ``render_template`` helper, passing
    ``self._call_context_vars`` as the template variables.
    """
    return render_template(prompt, self._call_context_vars)
async def _create_transition_func(self, name: str, transition_to_node: str):
    """Build the tool handler that moves the workflow to ``transition_to_node``.

    Args:
        name: Function name, used only in the error log message.
        transition_to_node: ID of the node to enter once the tool result has
            been written to the context.

    Returns:
        An async handler taking ``FunctionCallParams``.
    """

    async def transition_func(function_call_params: FunctionCallParams) -> None:
        """Handle one invocation of the transition tool."""
        try:
            # Track the pending function call.
            self._pending_function_calls += 1
            logger.debug(
                f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
            )

            async def on_context_updated() -> None:
                """Switch node once the tool result is in the context.

                Passed to the framework through the result properties, so
                that the completion started by set_node already sees the
                function call result.
                """
                self._pending_function_calls -= 1
                # Extract variables for the node being left before entering
                # the next one.
                await self._perform_variable_extraction_if_needed(
                    self._current_node
                )
                await self.set_node(transition_to_node)

            payload = {"status": "done"}
            # run_llm=False: no completion is started by the tool result
            # itself; set_node is responsible for the next generation.
            result_properties = FunctionCallResultProperties(
                run_llm=False,
                on_context_updated=on_context_updated,
            )

            async def _invoke_result_callback():
                """Report the tool result, which lets on_context_updated run.

                Kept separate so it can be postponed: when the same
                generation also produced text, the result is reported only
                after that text has been pushed to the context.
                """
                await function_call_params.result_callback(
                    payload, properties=result_properties
                )

            if not self._defer_context_push:
                # The generation held only the function call: report the
                # result now so the node switch can happen.
                await _invoke_result_callback()
            else:
                # The generation also produced text (flag set, per the
                # original note, in the handle_llm_generated_text callback).
                logger.debug(
                    "Deferring transition function result until context push"
                )
                # Only one deferred transition should exist at any time;
                # an already stored one (unexpected) is overwritten.
                self._pending_generated_transition_after_context_push = (
                    _invoke_result_callback
                )
        except Exception as e:
            logger.error(f"Error in transition function {name}: {str(e)}")
            self._pending_function_calls = 0
            failure = {"status": "error", "error": str(e)}
            await function_call_params.result_callback(failure)

    return transition_func
async def _register_transition_function_with_llm(
    self, name: str, transition_to_node: str
):
    """Create the transition handler for an edge and register it on the LLM.

    Args:
        name: Function name the LLM will call.
        transition_to_node: ID of the node the handler transitions to.
    """
    logger.debug(
        f"Registering function {name} to transition to node {transition_to_node} with LLM"
    )
    handler = await self._create_transition_func(name, transition_to_node)
    # Registered with cancel_on_interruption=True.
    self.llm.register_function(name, handler, cancel_on_interruption=True)
async def _register_builtin_functions(self):
    """Register the calculator and timezone tool handlers on the LLM.

    Every handler reports its result, or ``{"error": <message>}`` when the
    tool raises, with ``run_llm=True`` result properties.
    """
    logger.debug("Registering built-in functions with LLM")
    properties = FunctionCallResultProperties(run_llm=True)

    async def calculate_func(function_call_params: FunctionCallParams) -> None:
        """Evaluate the 'expression' argument with safe_calculator."""
        try:
            expr = function_call_params.arguments.get("expression", "")
            outcome = {"expression": expr, "result": safe_calculator(expr)}
            await function_call_params.result_callback(
                outcome, properties=properties
            )
        except Exception as e:
            await function_call_params.result_callback(
                {"error": str(e)}, properties=properties
            )

    async def get_current_time_func(
        function_call_params: FunctionCallParams,
    ) -> None:
        """Report the current time in the 'timezone' argument (default UTC)."""
        try:
            tz_name = function_call_params.arguments.get("timezone", "UTC")
            await function_call_params.result_callback(
                get_current_time(tz_name), properties=properties
            )
        except Exception as e:
            await function_call_params.result_callback(
                {"error": str(e)}, properties=properties
            )

    async def convert_time_func(function_call_params: FunctionCallParams) -> None:
        """Convert 'time' from 'source_timezone' to 'target_timezone'."""
        try:
            args = function_call_params.arguments
            converted = convert_time(
                args.get("source_timezone"),
                args.get("time"),
                args.get("target_timezone"),
            )
            await function_call_params.result_callback(
                converted, properties=properties
            )
        except Exception as e:
            await function_call_params.result_callback(
                {"error": str(e)}, properties=properties
            )

    # Register all built-in functions.
    for function_name, handler in (
        ("safe_calculator", calculate_func),
        ("get_current_time", get_current_time_func),
        ("convert_time", convert_time_func),
    ):
        self.llm.register_function(function_name, handler)
async def _queue_tts_response(self, text: str) -> None:
    """Queue ``text`` for speech, wrapped in LLM response start/end frames."""
    frames = [
        LLMFullResponseStartFrame(),
        TTSSpeakFrame(text=text),
        LLMFullResponseEndFrame(),
    ]
    await self.task.queue_frames(frames)
async def _setup_static_start_node_transition(self, node: Node) -> None:
    """Arm the follow-up transition of a static start node.

    Nothing is armed when the node has no outgoing edge or when it waits
    for a user response. Otherwise the first outgoing edge's target is
    stored as the pending control transition.
    """
    if not node.out_edges:
        return
    if node.wait_for_user_response:
        # No programmatic transition is set up by this method in that case.
        return

    next_node_id = node.out_edges[0].target

    async def _deferred_static_transition():
        try:
            await self.set_node(next_node_id)
        except Exception as exc:
            logger.error(
                f"Error executing deferred static node transition to {next_node_id}: {exc}"
            )

    # Normal static start node: transition right after the context push.
    self._pending_control_transition_after_context_push = (
        _deferred_static_transition
    )
async def _perform_variable_extraction_if_needed(
    self, previous_node: Optional[Node]
) -> None:
    """Schedule background variable extraction for ``previous_node``.

    Nothing happens unless the node exists, has extraction enabled and
    declares extraction variables. The extraction itself runs as a
    background task and merges its result into ``self._gathered_context``;
    this method returns without waiting for it.

    Args:
        previous_node: The node whose extraction settings are used.
    """
    if not (
        previous_node
        and previous_node.extraction_enabled
        and previous_node.extraction_variables
    ):
        return

    logger.debug(
        f"Scheduling background variable extraction for node: {previous_node.name}"
    )

    # Capture everything the task needs now, before control returns to the
    # caller: the current turn context, the rendered prompt and the variables.
    parent_context = get_current_turn_context()
    extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
    extraction_variables = previous_node.extraction_variables

    async def _background_extraction():
        try:
            extracted_data = (
                await self._variable_extraction_manager._perform_extraction(
                    extraction_variables, parent_context, extraction_prompt
                )
            )
            self._gathered_context.update(extracted_data)
            logger.debug(
                f"Background variable extraction completed. Extracted: {extracted_data}"
            )
        except Exception as e:
            logger.error(
                f"Error during background variable extraction: {str(e)}"
            )

    # Extraction runs in the background without blocking the caller. The
    # event loop keeps only weak references to tasks, so a task nobody
    # references can be garbage-collected before it finishes. Hold a strong
    # reference until the task is done.
    background_tasks = getattr(self, "_background_extraction_tasks", None)
    if background_tasks is None:
        background_tasks = self._background_extraction_tasks = set()
    extraction_task = asyncio.create_task(_background_extraction())
    background_tasks.add(extraction_task)
    extraction_task.add_done_callback(background_tasks.discard)
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
    """Point the LLM context at ``node`` and, normally, start a generation.

    Registers the node's transition functions (skipped for end nodes),
    installs the node's system message and function schemas, then queues a
    context frame unless ``_queue_context_frame`` was switched off. The flag
    is True again when this method returns.
    """
    # Set node name for tracing; contexts without the method are tolerated.
    try:
        self.context.set_node_name(node.name)
    except AttributeError:
        logger.warning(f"context has no set_node_name method")

    # End nodes get no transition functions registered.
    if not node.is_end:
        for edge in node.out_edges:
            await self._register_transition_function_with_llm(
                edge.get_function_name(), edge.target
            )

    system_message, functions = (
        await self._compose_system_message_functions_for_node(node)
    )
    await self._update_llm_context(system_message, functions)

    if not self._queue_context_frame:
        logger.debug(
            f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
        )
    else:
        await self.task.queue_frame(OpenAILLMContextFrame(self.context))

    # Queueing is the default behaviour; skipping applies to one call only.
    self._queue_context_frame = True
async def set_node(self, node_id: str):
    """Make ``node_id`` the current node and run the handler for its kind.

    Start nodes take precedence over end nodes; everything else is handled
    as an agent node.

    Raises:
        KeyError: if ``node_id`` is not present in ``self.workflow.nodes``.
    """
    node = self.workflow.nodes[node_id]
    logger.debug(
        f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
    )

    # Recorded for every node, static ones included, so the STT mute filter
    # works.
    self._current_node = node

    if node.is_start:
        handler = self._handle_start_node
    elif node.is_end:
        handler = self._handle_end_node
    else:
        handler = self._handle_agent_node
    await handler(node)
async def _handle_start_node(self, node: Node) -> None:
    """Handle start node execution.

    In order: set up voicemail detection when the node asks for it and an
    audio buffer is available, optionally wait for the delayed-start
    duration, then either speak the static prompt (and arm the follow-up
    transition) or start an LLM generation for the node.
    """
    # Handle voicemail detection setup (before any returns)
    if node.detect_voicemail:
        if not self._audio_buffer:
            logger.warning(
                "Voicemail detection enabled but no audio buffer available - skipping detection"
            )
        else:
            logger.debug(
                "Start node has detect_voicemail enabled - setting up audio-based detector"
            )
            self._detect_voicemail = True
            self._voicemail_detector = VoicemailDetector(
                detection_duration=VOICEMAIL_RECORDING_DURATION,
                workflow_run_id=self._workflow_run_id,
            )
            # Register audio handler on the audio buffer input processor
            audio_input = self._audio_buffer.input()

            @audio_input.event_handler("on_input_audio_data")
            async def handle_voicemail_audio(
                processor, pcm, sample_rate, num_channels
            ):
                # Forward audio only while a detector exists and is detecting.
                if (
                    self._voicemail_detector
                    and self._voicemail_detector.is_detecting
                ):
                    await self._voicemail_detector.handle_audio_data(
                        processor, pcm, sample_rate, num_channels
                    )

            # Start detection
            await self._voicemail_detector.start_detection(self)
    # Check if delayed start is enabled
    if node.delayed_start:
        # Use configured duration or default to 2 seconds (a configured
        # value of 0 also falls back to the default)
        delay_duration = node.delayed_start_duration or 2.0
        logger.debug(
            f"Delayed start enabled - waiting {delay_duration} seconds before speaking"
        )
        await asyncio.sleep(delay_duration)
    if node.is_static:
        # Queue TTS for static start node
        formatted_prompt = self._format_prompt(node.prompt)
        await self._queue_tts_response(formatted_prompt)
        # Set up deferred transition for static start nodes
        await self._setup_static_start_node_transition(node)
    else:
        # Start generation for non-static start node
        await self._setup_llm_context_and_start_generation(node)
async def _handle_end_node(self, node: Node) -> None:
    """Run an end node and arm the deferred end-of-task frame.

    The node's prompt is spoken (static) or generated (non-static),
    extraction is scheduled when enabled, and the end-task step is stored
    as the pending control transition.
    """
    if not node.is_static:
        # Start generation for non-static end node.
        await self._setup_llm_context_and_start_generation(node)
    else:
        # Queue TTS for static end node.
        await self._queue_tts_response(self._format_prompt(node.prompt))

    # Schedule extraction for this end node right away when it is enabled.
    if node.extraction_enabled and node.extraction_variables:
        await self._perform_variable_extraction_if_needed(node)

    # TODO: Extract disposition code from extracted variables

    async def _deferred_end_task():
        # The end-task reason is decided when this runs, from the
        # call_disposition gathered from the conversation so far.
        # TODO: Make this more generic based on configuration or llm prompting
        qualified = self._gathered_context.get("call_disposition") == "XFER"
        end_reason = (
            EndTaskReason.USER_QUALIFIED
            if qualified
            else EndTaskReason.USER_DISQUALIFIED
        )
        await self.send_end_task_frame(end_reason.value)

    self._pending_control_transition_after_context_push = _deferred_end_task
async def _handle_agent_node(self, node: Node) -> None:
    """Run a regular agent node.

    Non-static nodes get their LLM context set up and a generation started;
    static nodes have their prompt spoken and their follow-up transition
    armed.
    """
    if not node.is_static:
        await self._setup_llm_context_and_start_generation(node)
        return

    spoken_text = self._format_prompt(node.prompt)
    await self._queue_tts_response(spoken_text)
    await self._setup_agent_node_transition(node)
async def _setup_agent_node_transition(self, node: Node) -> None:
    """Arm the follow-up transition of a static agent node.

    The first outgoing edge's target is stored as the pending control
    transition; nodes without outgoing edges are left alone.
    """
    edges = node.out_edges
    if not edges:
        return

    next_node_id = edges[0].target

    async def _deferred_static_transition():
        try:
            await self.set_node(next_node_id)
        except Exception as exc:
            logger.error(
                f"Error executing deferred static node transition to {next_node_id}: {exc}"
            )

    self._pending_control_transition_after_context_push = (
        _deferred_static_transition
    )
async def send_end_task_frame(
    self,
    reason: str,
    additional_metadata: Optional[dict] = None,
    abort_immediately: bool = False,
):
    """Finish the run by queueing an EndFrame (or CancelFrame) with metadata.

    Central place for ending a call. It maps the disposition, fills the
    gathered context with caller details, builds ``call_transfer_context``,
    optionally starts an immediate Stasis transfer, and queues the final
    frame.

    Args:
        reason: Why the task ends. Kept unmapped in the frame metadata.
        additional_metadata: Extra entries merged into the frame metadata.
        abort_immediately: Queue a ``CancelFrame`` instead of an ``EndFrame``
            and skip the immediate transfer.
    """

    def _text(key: str) -> str:
        """Return a call-context value as text; missing or None becomes ''."""
        value = self._call_context_vars.get(key)
        return "" if value is None else str(value)

    frame_to_push = CancelFrame() if abort_immediately else EndFrame()

    # Customer disposition code using their mapping. A call_disposition
    # extracted from the conversation wins; otherwise the end reason is
    # mapped instead.
    call_disposition = self._gathered_context.get("call_disposition", "")
    organization_id = await get_organization_id_from_workflow_run(
        self._workflow_run_id
    )
    mapped_disposition = await apply_disposition_mapping(
        call_disposition or reason, organization_id
    )
    if call_disposition:
        # Keep the original extracted value next to the mapped one.
        self._gathered_context["extracted_call_disposition"] = call_disposition
    self._gathered_context["call_disposition"] = mapped_disposition

    # TODO: Generalise this, currently tailored to Kapil's use case
    # Joined via _text() so that a None or non-string context value cannot
    # raise TypeError and abort the end-of-call handling; string values give
    # exactly the same result as a plain join.
    self._gathered_context["address"] = ", ".join(
        _text(key)
        for key in (
            "address1",
            "address2",
            "address3",
            "city",
            "state",
            "province",
            "postal_code",
        )
    )
    self._gathered_context["full_name"] = " ".join(
        _text(key) for key in ("first_name", "middle_initial", "last_name")
    )
    self._gathered_context["agent_name"] = "Alex"
    self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
        "phone", ""
    )
    # NOTE(review): timezone is filled from "province" - confirm this is intended.
    self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
    self._gathered_context["vendor_id"] = self._call_context_vars.get(
        "vendor_lead_code", ""
    )

    decision_maker = self._gathered_context.get("primary_cardholder", False)
    employment_status = self._gathered_context.get("employment_status", "N/A")

    call_transfer_context = {
        "first_name": self._call_context_vars.get("first_name", ""),
        "full_name": self._gathered_context.get("full_name", ""),
        "phone": self._call_context_vars.get("phone", ""),
        "lead_id": self._call_context_vars.get("lead_id"),
        "disposition": mapped_disposition,
        "agent_name": self._gathered_context.get("agent_name", "Alex"),
        "decision_maker": str(decision_maker),
        # str() guards against a non-string extracted value.
        "employment": (
            str(employment_status).title() if employment_status else "N/A"
        ),
        "debts": self._gathered_context.get("total_debt", "N/A"),
        "number_of_credit_cards": self._gathered_context.get(
            "number_of_credit_cards", "N/A"
        ),
        "time": self._gathered_context.get("time"),
    }

    logger.debug(
        f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
    )

    # Initiate immediate transfer for Stasis connections when user is qualified
    if (
        reason == EndTaskReason.USER_QUALIFIED.value
        and self._stasis_connection is not None
        and not abort_immediately
    ):
        try:
            logger.info(
                f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
            )
            await self._stasis_connection.transfer(call_transfer_context)
            logger.info("Immediate transfer initiated successfully")
        except Exception as e:
            logger.error(f"Failed to initiate immediate transfer: {e}")
            # Continue with normal flow even if immediate transfer fails

    if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
        await self.task.queue_frame(
            TTSSpeakFrame(
                "Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
            )
        )

    metadata = {
        # Keep original reason in metadata, which would be used to decide
        # whether to disconnect or to transfer the call in the transport
        "reason": reason,
        "call_transfer_context": call_transfer_context,
    }
    if additional_metadata:
        metadata.update(additional_metadata)
    frame_to_push.metadata = metadata

    # Remember the mapped disposition; get_call_disposition() returns it.
    self._call_disposition = mapped_disposition

    logger.debug(
        f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
    )
    await self.task.queue_frame(frame_to_push)
async def _compose_system_message_functions_for_node(
    self, node: "Node"
) -> tuple[dict, list[dict]]:
    """Generate the system message and function schemas for the given node.

    This performs the same formatting logic used when entering a node but
    does **not** register the functions with the LLM; callers are
    responsible for that.

    Args:
        node: The node to compose the message and schemas for.

    Returns:
        ``(system_message, functions)``. ``system_message`` is a single
        ``{"role": "system", "content": ...}`` dict whose content is the
        global prompt (when applicable) and the node prompt, joined by a
        blank line. ``functions`` holds the built-in tool schemas followed
        by one transition schema per outgoing edge.
    """
    # The global prompt is used only when the workflow has a global node and
    # this node opts in.
    global_prompt = ""
    if self.workflow.global_node_id and node.add_global_prompt:
        global_node = self.workflow.nodes[self.workflow.global_node_id]
        global_prompt = self._format_prompt(global_node.prompt)

    # Built-in schemas (calculator and timezone tools) first, then the
    # transition functions (schema only; registration handled elsewhere).
    functions: list[dict] = [
        *self.builtin_function_schemas,
        *(
            self._get_function_schema(
                outgoing_edge.get_function_name(), outgoing_edge.condition
            )
            for outgoing_edge in node.out_edges
        ),
    ]

    formatted_node_prompt = self._format_prompt(node.prompt)
    system_message = {
        "role": "system",
        # Empty parts are dropped, so there is no stray leading separator.
        "content": "\n\n".join(
            p for p in (global_prompt, formatted_node_prompt) if p
        ),
    }
    return system_message, functions
# ------------------------------------------------------------------
# Pending transition handling
# ------------------------------------------------------------------
async def flush_pending_transitions(self, *, source: str = "context_push"):
    """Run and clear the pending transitions, generated one first.

    Each pending callback is cleared before it is awaited, and an error
    raised by a callback is logged rather than propagated.

    Args:
        source: What triggered the flush. Only "context_push" (the
            assistant context aggregator completed a push) is accepted.

    Raises:
        ValueError: if ``source`` is anything other than "context_push".
    """
    if source != "context_push":
        raise ValueError("Invalid flush source expected 'context_push'")

    len_pending_functions = sum(
        pending is not None
        for pending in (
            self._pending_generated_transition_after_context_push,
            self._pending_control_transition_after_context_push,
        )
    )
    if not len_pending_functions:
        # Nothing to do.
        return

    logger.debug(
        f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
    )

    # Generated transition.
    generated_cb = self._pending_generated_transition_after_context_push
    if generated_cb is not None:
        self._pending_generated_transition_after_context_push = None
        try:
            await generated_cb()
        except Exception as exc:  # pragma: no cover
            logger.error(f"Error executing deferred transition: {exc}")

    # Control transition. Read only now, after the generated transition has
    # run, so one installed while it ran is picked up too.
    control_cb = self._pending_control_transition_after_context_push
    if control_cb is not None:
        logger.debug("Executing control transition after context push")
        self._pending_control_transition_after_context_push = None
        try:
            await control_cb()
        except Exception as exc:  # pragma: no cover
            logger.error(f"Error executing deferred static node transition: {exc}")
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
    """Build the callback STTMuteFilter uses to decide whether to mute STT.

    Delegates to ``engine_callbacks.create_should_mute_callback``, passing
    this engine.
    """
    return engine_callbacks.create_should_mute_callback(self)
def create_user_idle_callback(self):
    """Build the callback invoked when the user has been idle for a while.

    It is used to either play the static text or end the call. Delegates to
    ``engine_callbacks.create_user_idle_callback``, passing this engine.
    """
    return engine_callbacks.create_user_idle_callback(self)
def create_max_duration_callback(self):
    """Build the callback invoked when the call exceeds its max duration.

    It is used to send the EndTaskFrame. Delegates to
    ``engine_callbacks.create_max_duration_callback``, passing this engine.
    """
    return engine_callbacks.create_max_duration_callback(self)
def create_llm_generated_text_callback(self):
    """Build the callback fired when the LLM produces text in a generation.

    When ``set_node`` is called in the same generation that also produced
    text, the node-transition result callback is deferred so that the context
    is sent with the next generation from the new node. The implementation
    lives in ``engine_callbacks``.
    """
    generated_text_callback = engine_callbacks.create_llm_generated_text_callback(self)
    return generated_text_callback
def create_generation_started_callback(self):
    """Build the callback fired at the start of every LLM generation.

    The returned handler resets the flags that steer the engine's flow; the
    implementation lives in ``engine_callbacks``.
    """
    generation_started_callback = engine_callbacks.create_generation_started_callback(
        self
    )
    return generation_started_callback
def create_user_stopped_speaking_callback(self):
    """Build the callback fired when the user stops speaking.

    Used to drive node transitions when ``wait_for_user_response`` is
    enabled; the implementation lives in ``engine_callbacks``.
    """
    stopped_speaking_callback = engine_callbacks.create_user_stopped_speaking_callback(
        self
    )
    return stopped_speaking_callback
def create_user_started_speaking_callback(self):
    """Build the callback fired when the user starts speaking.

    Used for the wait-for-user-greeting behaviour; the implementation lives
    in ``engine_callbacks``.
    """
    started_speaking_callback = engine_callbacks.create_user_started_speaking_callback(
        self
    )
    return started_speaking_callback
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
    """Build a callback that repairs corrupted aggregated text.

    The returned callable takes the corrupted text and returns a corrected
    version based on the engine's reference text; the implementation lives in
    ``engine_callbacks``.
    """
    correction_callback = engine_callbacks.create_aggregation_correction_callback(self)
    return correction_callback
def get_call_disposition(self) -> Optional[str]:
    """Return the call disposition (disconnect reason) recorded by the engine."""
    disposition = self._call_disposition
    return disposition
def get_gathered_context(self) -> dict:
    """Return a shallow copy of the gathered context.

    The copy includes the extracted variables; because it is shallow, nested
    values are still shared with the engine's own dictionary.
    """
    snapshot = self._gathered_context.copy()
    return snapshot
def set_context(self, context: OpenAILLMContext) -> None:
    """Attach the OpenAI LLM context to the engine.

    Exists so the context can be supplied after construction, for setups
    where the context is only created once the engine already exists.
    """
    self.context = context
def set_task(self, task: PipelineTask) -> None:
    """Attach the pipeline task to the engine.

    Exists so the task can be supplied after construction, for setups where
    the task is only created once the engine already exists.
    """
    self.task = task
def set_audio_buffer(self, audio_buffer: "AudioBuffer") -> None:
    """Attach the audio buffer to the engine.

    Exists so the buffer can be supplied after construction, for setups where
    the buffer is only created once the engine already exists.
    """
    self._audio_buffer = audio_buffer
def set_stasis_connection(
    self, connection: Optional["StasisRTPConnection"]
) -> None:
    """Store the Stasis RTP connection used for immediate transfers.

    With a connection available the engine can start a transfer as soon as an
    XFER disposition is detected, rather than waiting for pipeline shutdown.

    Args:
        connection: The StasisRTPConnection instance, or None for non-Stasis
            transports.
    """
    self._stasis_connection = connection
    if not connection:
        # Nothing to report for non-Stasis transports.
        return
    logger.debug(
        f"Stasis connection set for immediate transfers: {connection.channel_id}"
    )
async def handle_llm_text_frame(self, text: str):
    """Append an LLM text fragment to the running reference text."""
    self._current_llm_reference_text = self._current_llm_reference_text + text
async def cleanup(self):
    """Release engine resources when the call disconnects.

    Cancels the user-response timeout task if it is still running and stops
    the voicemail detector when one is attached.
    """
    # Cancel the timeout task only while it is still pending.
    timeout_task = self._user_response_timeout_task
    if timeout_task and not timeout_task.done():
        timeout_task.cancel()

    # Stop voicemail detection if a detector exposing stop_detection is set.
    detector = self._voicemail_detector
    if detector and hasattr(detector, "stop_detection"):
        await detector.stop_detection()

View file

@ -0,0 +1,305 @@
from __future__ import annotations
"""Callback factory helpers for :class:`~api.services.workflow.pipecat_engine.PipecatEngine`.
Each helper takes a :class:`PipecatEngine` instance and returns an async
callback function suitable for passing to the various pipeline processors.
Separating these helpers into their own module keeps
``pipecat_engine.py`` focused on high-level engine orchestration logic while
encapsulating the callback implementations here for easier maintenance and
unit-testing.
"""
import re
from typing import TYPE_CHECKING, Awaitable, Callable
from loguru import logger
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from pipecat.processors.user_idle_processor import UserIdleProcessor
from api.services.workflow.pipecat_engine import PipecatEngine
# ---------------------------------------------------------------------------
# STT mute handling
# ---------------------------------------------------------------------------
def create_should_mute_callback(
    engine: "PipecatEngine",
) -> Callable[[STTMuteFilter], Awaitable[bool]]:
    """Return an async predicate telling the STT mute filter whether to mute.

    STT is muted exactly when the engine's current node does not allow
    interruptions. With no active node the predicate answers ``False``.
    """

    async def should_mute(_: STTMuteFilter) -> bool:  # noqa: D401
        node = engine._current_node
        if node is None:
            # No active node yet: default to leaving STT unmuted.
            return False

        logger.debug(
            f"STT mute callback: allow_interrupt={engine._current_node.allow_interrupt}"
        )
        return not node.allow_interrupt

    return should_mute
# ---------------------------------------------------------------------------
# User-idle handling
# ---------------------------------------------------------------------------
def create_user_idle_callback(engine: "PipecatEngine"):
    """Return the async handler invoked on user-idle timeouts.

    The handler returns ``True`` after queueing the check-in prompt and
    ``False`` once it has asked the engine to end the call.
    """

    async def handle_user_idle(
        user_idle: "UserIdleProcessor", retry_count: int
    ) -> bool:
        logger.debug(f"Handling user_idle, attempt: {retry_count}")

        async def _end_call_for_idle() -> bool:
            # Shared tail of both termination paths.
            await engine.send_end_task_frame(
                EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value
            )
            return False

        # Idle on a StartNode: disconnect straight away, no check-in prompt.
        if engine._current_node and engine._current_node.is_start:
            logger.debug("User idle on StartNode - disconnecting immediately")
            return await _end_call_for_idle()

        if retry_count != 1:
            # Any attempt other than the first: say goodbye and end the call.
            await user_idle.push_frame(
                TTSSpeakFrame("It seems like you're busy right now. Have a nice day!")
            )
            return await _end_call_for_idle()

        # First attempt: wrap the prompt in response start/end frames to
        # simulate an LLM generation, so the LLM context picks up the message.
        check_in_frames = [
            LLMFullResponseStartFrame(),
            TTSSpeakFrame("Just checking in to see if you're still there."),
            LLMFullResponseEndFrame(),
        ]
        await engine.task.queue_frames(check_in_frames)
        return True

    return handle_user_idle
# ---------------------------------------------------------------------------
# Max-duration handling
# ---------------------------------------------------------------------------
def create_max_duration_callback(engine: "PipecatEngine"):
    """Return the async handler that ends the task once the call runs too long."""

    async def handle_max_duration():
        logger.debug("Max call duration exceeded. Terminating call")
        reason = EndTaskReason.CALL_DURATION_EXCEEDED.value
        await engine.send_end_task_frame(reason)

    return handle_max_duration
# ---------------------------------------------------------------------------
# LLM-generated-text handling
# ---------------------------------------------------------------------------
def create_llm_generated_text_callback(engine: "PipecatEngine"):
    """Return the async handler run when the LLM emits text, not just tool calls.

    The handler sets the engine's ``_defer_context_push`` flag.
    """

    async def handle_llm_generated_text():  # noqa: D401
        logger.debug(
            "Generation has text content in current response - deferring context push from set_node"
        )
        setattr(engine, "_defer_context_push", True)

    return handle_llm_generated_text
# ---------------------------------------------------------------------------
# Generation-started handling
# ---------------------------------------------------------------------------
def create_generation_started_callback(engine: "PipecatEngine"):
    """Return the async handler that resets per-generation engine state.

    On each new LLM generation the handler clears the defer flag, the pending
    function-call counter, any pending generated transition and the
    accumulated reference text.
    """

    async def handle_generation_started():  # noqa: D401
        logger.debug("LLM generation started - resetting defer flags and tool counters")

        # Flow-control state from the previous generation.
        engine._defer_context_push = False
        engine._pending_function_calls = 0
        engine._pending_generated_transition_after_context_push = None

        # Reference text is rebuilt from scratch for every generation.
        engine._current_llm_reference_text = ""

    return handle_generation_started
# ---------------------------------------------------------------------------
# User-stopped-speaking handling
# ---------------------------------------------------------------------------
def create_user_stopped_speaking_callback(engine: "PipecatEngine"):
    """Return the async handler run when the user stops speaking.

    The handler only acts when the current node is a start node with
    ``wait_for_user_response`` enabled and at least one outgoing edge. In
    that case it:

    - cancels the user-response timeout task if it is still pending, and
    - moves to the first outgoing edge's target with
      ``_queue_context_frame`` set to ``False``.
    """

    async def handle_user_stopped_speaking():
        node = engine._current_node
        waiting_on_start_node = (
            node and node.is_start and node.wait_for_user_response and node.out_edges
        )
        if not waiting_on_start_node:
            return

        # The user answered, so the timeout is no longer needed.
        timeout_task = engine._user_response_timeout_task
        if timeout_task and not timeout_task.done():
            logger.debug("Cancelling user response timeout - user responded")
            timeout_task.cancel()
            engine._user_response_timeout_task = None

        next_node_id = engine._current_node.out_edges[0].target
        logger.debug(
            f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}"
        )

        # The user context aggregator will push the context frame itself;
        # here we only load the next node's prompts and functions.
        engine._queue_context_frame = False
        await engine.set_node(next_node_id)

    return handle_user_stopped_speaking
# ---------------------------------------------------------------------------
# User-started-speaking handling
# ---------------------------------------------------------------------------
def create_user_started_speaking_callback(engine: "PipecatEngine"):
    """Return the async handler run when the user starts speaking.

    On a start node with ``wait_for_user_response`` enabled, a still-pending
    timeout task is cancelled. The task reference is deliberately left in
    place.
    """

    async def handle_user_started_speaking():
        node = engine._current_node
        if not (node and node.is_start and node.wait_for_user_response):
            return

        timeout_task = engine._user_response_timeout_task
        if not timeout_task or timeout_task.done():
            return

        logger.debug(
            "User started speaking during wait_for_user_response - cancelling timeout timer"
        )
        # Only cancel; the reference is not cleared here so that the
        # user-stopped-speaking handler can take care of the transition.
        timeout_task.cancel()

    return handle_user_started_speaking
def create_aggregation_correction_callback(engine: "PipecatEngine"):
    """Return a callback that repairs corrupted text using the engine's reference text.

    The returned callable takes the corrupted aggregation and returns either
    a re-aligned version built from ``engine._current_llm_reference_text`` or,
    whenever alignment is not possible or not trustworthy, the input itself.
    """

    def _alnum_only(text: str) -> str:
        # Alphanumeric skeleton used to compare texts regardless of spacing
        # and punctuation.
        return "".join(filter(str.isalnum, text))

    def _realign(ref: str, corrupted: str) -> str:
        """Align ``corrupted`` against ``ref``; needs no engine state."""
        alnum_corr = _alnum_only(corrupted)
        alnum_ref = _alnum_only(ref)

        # Bail out when nothing needs fixing (already a substring of ref),
        # when ref cannot contain the text (fewer alphanumerics), or when the
        # text has fewer than 10 alphanumerics and is returned untouched.
        if corrupted in ref:
            return corrupted
        if len(alnum_ref) < len(alnum_corr) or len(alnum_corr) < 10:
            return corrupted

        # Anchor on the last (non-overlapping) occurrence in ref of the first
        # 10 characters of the corrupted text; fall back to the beginning.
        head = corrupted[:10]
        anchor = 0
        for match in re.finditer(re.escape(head), ref):
            anchor = match.start()

        # Two-pointer walk: i over ref (from the anchor), j over corrupted.
        i, j = anchor, 0
        rebuilt = []
        structural = ".,;:!?"
        while i < len(ref) and j < len(corrupted):
            r_ch = ref[i]
            c_ch = corrupted[j]

            if c_ch != r_ch and c_ch == " ":
                # Spurious space in the corrupted text: drop it.
                j += 1
                continue

            # Every remaining case emits the reference character.
            rebuilt.append(r_ch)
            i += 1

            # A space or punctuation mark missing from the corrupted text
            # consumes nothing from it; matches and plain mismatches do.
            missing_structural = r_ch != c_ch and (r_ch == " " or r_ch in structural)
            if not missing_structural:
                j += 1

        # Accept the result only if its alphanumeric skeleton is exactly the
        # corrupted text's skeleton; otherwise hand back the input.
        if _alnum_only(rebuilt) != alnum_corr:
            return corrupted
        return "".join(rebuilt)

    def correct_aggregation(corrupted: str) -> str:
        reference = engine._current_llm_reference_text
        if not reference:
            logger.warning("No reference text available for aggregation correction")
            return corrupted
        return _realign(reference, corrupted)

    return correct_aggregation

View file

@ -0,0 +1,90 @@
from __future__ import annotations
from typing import Any, Dict, List
from google.genai.types import (
Content,
Part,
)
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.services.google.llm import GoogleLLMContext
from pipecat.services.openai.llm import OpenAILLMContext
from api.utils.template_renderer import render_template
__all__ = [
"get_function_schema",
"update_llm_context",
"render_template",
]
def get_function_schema(
    function_name: str,
    description: str,
    *,
    properties: Dict[str, Any] | None = None,
    required: List[str] | None = None,
) -> FunctionSchema:
    """Build a provider-neutral ``FunctionSchema`` for a tool.

    The schema can later be converted into the provider-specific format
    (OpenAI, Gemini, etc.). ``properties`` and ``required`` are optional:
    callers that pass only ``function_name`` and ``description`` get a
    parameter-less function.
    """
    schema_fields = {
        "name": function_name,
        "description": description,
        "properties": properties if properties else {},
        "required": required if required else [],
    }
    return FunctionSchema(**schema_fields)
def update_llm_context(
    context: OpenAILLMContext,
    system_message: Dict[str, Any],
    functions: List[FunctionSchema],
) -> None:
    """Update *context* in place with a new system message and tool list.

    For a ``GoogleLLMContext`` the system message is stored on the context's
    ``system_message`` attribute and a placeholder user message is appended
    when the history is empty or does not end with a user message. For any
    other context, previous system messages are dropped and *system_message*
    is inserted at the top of the conversation history.

    Args:
        context: The LLM context to mutate.
        system_message: Message dict; its ``"content"`` entry is what a
            Google context stores.
        functions: Tool schemas available to the LLM. Tools are only set when
            this list is non-empty.
    """
    # Wrap the schemas in a ToolsSchema so the adapter of the active LLM
    # service can convert them to its provider-specific representation.
    tools_schema = ToolsSchema(standard_tools=functions)
    # NOTE(review): when `functions` is empty the tools already set on the
    # context are left untouched, so the previous tool list presumably stays
    # active - confirm this is intended.

    if isinstance(context, GoogleLLMContext):
        context.system_message = system_message["content"]

        if functions:
            # Only call set_tools when there are functions, otherwise Gemini
            # throws an exception.
            context.set_tools(tools_schema)

        # Google expects the conversation to end with a user message. Checking
        # for an empty history first avoids an IndexError on `messages[-1]`
        # when no message has been added yet.
        if not context.messages or context.messages[-1].role != "user":
            context.add_message(Content(role="user", parts=[Part(text="...")]))
        return

    # OpenAI-style context: replace any existing system message with the
    # incoming one while keeping user/assistant/function content in order.
    messages: List[Dict[str, Any]] = [system_message]
    messages.extend(
        interaction
        for interaction in context.messages
        if interaction["role"] != "system"
    )
    context.set_messages(messages)

    if functions:
        context.set_tools(tools_schema)

View file

@ -0,0 +1,192 @@
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any, List
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import trace
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.services.workflow.dto import ExtractionVariableDTO
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
class VariableExtractionManager:
    """Run variable extraction over the engine's conversation history.

    The manager builds a prompt from the requested variables and the
    conversation so far, sends it to OpenAI through an independent client
    (outside the pipeline), and returns the parsed JSON result. When tracing
    is enabled the call is recorded as an OpenTelemetry span.
    """

    def __init__(self, engine: "PipecatEngine") -> None:  # noqa: F821
        # Keep a reference to the engine so its conversation context can be
        # reused when extracting.
        self._engine = engine
        # Context as it was at construction time; used as a fallback only.
        self._context = engine.context
        self._model = "gpt-4o"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
        """Return ``(role, content)`` for an OpenAI- or Gemini-style message.

        Supports OpenAI-style dict messages and objects exposing ``role`` and
        ``parts`` attributes (Google ``Content``). Only plain text is
        extracted; image parts, tool-call placeholders, etc. are ignored.
        Unrecognised message formats yield ``(None, None)``.
        """
        # OpenAI format: a dict with ``role`` and ``content`` keys.
        if isinstance(msg, dict):
            role = msg.get("role")
            content_field = msg.get("content")

            # Content can be a str, a list of segments, or None.
            if isinstance(content_field, str):
                return role, content_field
            if isinstance(content_field, list):
                # Collapse all text segments into a single string.
                texts = [
                    segment.get("text", "")
                    for segment in content_field
                    if isinstance(segment, dict) and segment.get("type") == "text"
                ]
                return role, (" ".join(texts) if texts else None)
            return role, None

        # Google Gemini format: an object with ``role`` and ``parts``.
        role_attr = getattr(msg, "role", None)
        parts_attr = getattr(msg, "parts", None)
        if role_attr is None or parts_attr is None:
            return None, None  # Unrecognised message format

        # Normalise Gemini's "model" role to "assistant".
        role = "assistant" if role_attr == "model" else role_attr

        # Collect textual parts only (ignore images, function calls, etc.).
        texts = [
            text_val
            for text_val in (getattr(part, "text", None) for part in parts_attr)
            if text_val
        ]
        return role, (" ".join(texts) if texts else None)

    def _build_conversation_history(self) -> str:
        """Render the user/assistant turns of the conversation as text lines."""
        # Prefer the engine's current context: the engine exposes set_context()
        # to install the context after construction, in which case the
        # reference captured in __init__ would be stale.
        context = self._engine.context
        if context is None:
            context = self._context

        conversation_lines: list[str] = []
        for msg in context.messages:
            role, content = self._get_role_and_content(msg)
            if role in ("assistant", "user") and content:
                conversation_lines.append(f"{role}: {content}")
        return "\n".join(conversation_lines)

    @staticmethod
    def _parse_llm_response(llm_response: Any) -> dict:
        """Parse the model output as JSON, falling back to the raw content.

        A missing response (``None``) makes ``json.loads`` raise ``TypeError``
        rather than ``JSONDecodeError``; both are handled the same way.
        """
        try:
            return json.loads(llm_response)
        except (json.JSONDecodeError, TypeError):
            logger.warning(
                "Extractor returned invalid JSON; storing raw content instead."
            )
            return {"raw": llm_response}

    async def _perform_extraction(
        self,
        extraction_variables: List[ExtractionVariableDTO],
        parent_ctx: Any,
        extraction_prompt: str = "",
    ) -> dict:
        """Run the extraction chat completion and post-process the result.

        Args:
            extraction_variables: Variables to extract; each contributes its
                ``name``, ``type`` and ``prompt`` to the request.
            parent_ctx: OpenTelemetry context used as the parent of the
                tracing span (only used when tracing is enabled).
            extraction_prompt: Optional extra instructions appended to the
                default system prompt.

        Returns:
            The parsed JSON object, or ``{"raw": <content>}`` when the model
            output is not valid JSON.
        """
        # Describe the variables the model should extract.
        vars_description = "\n".join(
            f"- {v.name} ({v.type}): {v.prompt}" for v in extraction_variables
        )
        conversation_history = self._build_conversation_history()

        system_prompt = (
            "You are an assistant tasked with extracting structured data from the conversation. "
            "Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown."  # noqa: E501
        )
        # Append the caller-supplied extraction prompt, if any.
        if extraction_prompt:
            system_prompt = system_prompt + "\n\n" + extraction_prompt

        user_prompt = (
            "\n\nVariables to extract:\n"
            f"{vars_description}"
            "\n\nConversation history:\n"
            f"{conversation_history}"
        )

        extraction_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # Independent OpenAI client: a direct API call with no pipeline
        # involvement.
        client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        response = await client.chat.completions.create(
            model=self._model,
            messages=extraction_messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )
        llm_response = response.choices[0].message.content

        if is_tracing_enabled():
            tracer = trace.get_tracer("pipecat")
            with tracer.start_as_current_span(
                "variable_extraction", context=parent_ctx
            ) as span:
                add_llm_span_attributes(
                    span,
                    service_name="OpenAILLMService",
                    model=self._model,
                    operation_name="variable_extraction",
                    messages=json.dumps(extraction_messages),
                    output=llm_response,
                    stream=False,
                    parameters={"temperature": 0.0, "response_format": "json_object"},
                )

        extracted = self._parse_llm_response(llm_response)
        logger.debug(f"Extracted variables: {extracted}")
        return extracted

View file

@ -0,0 +1,448 @@
from __future__ import annotations
import asyncio
import io
import json
import os
import tempfile
import wave
from typing import TYPE_CHECKING, Optional
from langfuse import get_client
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import context as otel_context
from pipecat.utils.enums import EndTaskReason
from pipecat.utils.tracing.context_registry import get_current_turn_context
from api.db import db_client
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
# Fallback prompt used by VoicemailDetector._analyze_transcript when the
# Langfuse prompt cannot be fetched. The "{transcript}" placeholder is filled
# in with str.replace (not str.format), which is why the literal braces of the
# JSON example below do not need to be escaped.
DEFAULT_VOICEMAIL_PROMPT = """
You are analyzing the beginning of a phone call to determine if it's a voicemail greeting.
Common voicemail indicators:
- "You've reached the voicemail of..."
- "Please leave a message after the beep"
- "I'm not available right now"
- "Press 1 to leave a message"
- Robotic or pre-recorded voice quality mentioned
- Background music or hold music references
Transcript: {transcript}
Respond with a JSON object:
{
"is_voicemail": true/false,
"confidence": 0.0-1.0,
"reasoning": "Brief explanation"
}
"""
class VoicemailDetector:
"""
Autonomous voicemail detection system that operates independently of the main pipeline.
"""
def __init__(
    self, detection_duration: float = 15.0, workflow_run_id: Optional[int] = None
):
    """Initialise the detector in an idle state.

    Args:
        detection_duration: Seconds of audio to collect before analysing.
        workflow_run_id: Workflow run the detection belongs to, or ``None``
            when there is no run to attach results to.
    """
    self.detection_duration = detection_duration
    self.audio_buffer = bytearray()
    self.is_detecting = False
    self.workflow_run_id = workflow_run_id
    self._langfuse_client = get_client()

    # The sample rate is taken from the first audio packet received.
    self._sample_rate = None

    # Task management
    self._detection_task: Optional[asyncio.Task] = None
    self._is_cancelled = False
    self._engine: Optional[PipecatEngine] = None

    # Set once enough audio has been collected (or detection is stopped).
    self._audio_collected_event = asyncio.Event()
# ------------------------------------------------------------------
# Utility helpers
# ------------------------------------------------------------------
def _current_duration_seconds(self) -> float:
"""Return the duration (in seconds) of the audio currently in the buffer."""
if self._sample_rate:
return len(self.audio_buffer) / (self._sample_rate * 2)
return 0.0
async def handle_audio_data(
    self, processor, pcm: bytes, sample_rate: int, num_channels: int
):
    """Buffer an incoming audio chunk while detection is running.

    Chunks are ignored unless detection is active and not cancelled. Once the
    buffered duration reaches ``detection_duration`` the collection event is
    set. ``processor`` and ``num_channels`` are accepted but not used.
    """
    if self._is_cancelled or not self.is_detecting:
        return

    # Remember the sample rate reported by the first audio packet.
    if self._sample_rate is None:
        self._sample_rate = sample_rate
        logger.debug(f"Voicemail detector using sample rate: {sample_rate}")

    # Append the audio as received, without resampling.
    self.audio_buffer.extend(pcm)

    # Signal completion as soon as enough audio has been collected.
    if self._current_duration_seconds() >= self.detection_duration:
        self._audio_collected_event.set()
async def start_detection(self, engine: PipecatEngine):
    """Begin collecting audio and launch the background detection task.

    Args:
        engine: Engine stored on the detector for use by the detection task.
    """
    logger.info("Starting voicemail detection")

    # Reset state left over from any earlier run.
    self._engine = engine
    self._is_cancelled = False
    self._audio_collected_event.clear()
    self.is_detecting = True

    # Detection runs in the background so the caller is not blocked.
    detection_coro = self._run_detection_with_timeout()
    self._detection_task = asyncio.create_task(detection_coro)
async def stop_detection(self):
    """Stop detection immediately (called on disconnect).

    Marks the detector as cancelled, wakes any waiter on the collection
    event, cancels the detection task if it is still running, clears the
    audio buffer and then waits for the task to finish.
    """
    logger.info("Stopping voicemail detection due to disconnect")
    self._is_cancelled = True
    self.is_detecting = False

    # Unblock anything still waiting for audio collection.
    self._audio_collected_event.set()

    task = self._detection_task
    if task and not task.done():
        task.cancel()

    self.audio_buffer.clear()

    # Wait for the task to finish processing the cancellation.
    # NOTE(review): this awaits the detection task itself, so it presumably
    # must not be called from inside that task - confirm against callers.
    if not task:
        return
    try:
        await task
    except asyncio.CancelledError:
        pass
async def _run_detection_with_timeout(self):
    """Background task: collect audio, then analyse it unless cancelled.

    Cancellation and any other exception are logged and swallowed here;
    ``is_detecting`` is always reset to ``False`` on exit.
    """
    try:
        # Returns when enough audio arrived, the wait timed out, or the
        # detector was stopped.
        await self._wait_for_audio_collection()

        if not self._is_cancelled:
            await self._process_detection()
        else:
            logger.info("Detection cancelled during audio collection")
    except asyncio.CancelledError:
        logger.info("Voicemail detection task cancelled")
    except Exception as e:
        logger.error(f"Error in voicemail detection: {e}")
    finally:
        self.is_detecting = False
async def _wait_for_audio_collection(self):
    """Wait until the collection event is set or the wait times out.

    The timeout is ``detection_duration`` plus two seconds. A timeout is not
    an error: whatever audio has been buffered so far is used as is.
    """
    wait_limit = self.detection_duration + 2.0
    try:
        await asyncio.wait_for(
            self._audio_collected_event.wait(),
            timeout=wait_limit,
        )
    except asyncio.TimeoutError:
        if self._is_cancelled:
            return
        current_duration = self._current_duration_seconds()
        logger.warning("Audio collection timeout exceeded")
        logger.info(
            f"Proceeding with {current_duration:.1f}s of audio (sample rate: {self._sample_rate}Hz)"
        )
    else:
        if self._is_cancelled:
            return
        current_duration = self._current_duration_seconds()
        logger.info(
            f"Collected {current_duration:.1f}s of audio for voicemail detection (sample rate: {self._sample_rate}Hz)"
        )
async def _process_detection(self):
    """Transcribe and analyse the collected audio, ending the call on voicemail.

    Steps: convert the buffered PCM to WAV, transcribe it, analyse the
    transcript, store the recording when a workflow run id is set and, if a
    voicemail was detected, tag the workflow run and send the end-task frame.
    Any exception is logged and swallowed.
    """
    if not self.audio_buffer or not self._engine:
        logger.warning("No audio buffer or engine available for detection")
        return

    try:
        # Convert PCM to WAV once for both transcription and storage.
        wav_data = self._create_wav_from_pcm(bytes(self.audio_buffer))

        logger.info("Transcribing audio for voicemail detection")
        transcript = await self._transcribe_audio(wav_data)

        if not transcript:
            logger.warning("No transcript obtained from audio")
            # Still upload the raw recording so the data pipeline has it.
            if self.workflow_run_id:
                await self._save_voicemail_audio(wav_data, 0.0, False)
            return

        logger.info(
            f"Voicemail detection transcript obtained: {transcript[:100]}..."
        )

        result = await self._analyze_transcript(transcript)

        confidence = result.get("confidence", 0.0)
        reasoning = result.get("reasoning", "No reasoning provided")
        # Read the verdict once so storage and the decision below agree, and
        # a missing key is stored as False rather than None.
        is_voicemail = bool(result.get("is_voicemail", False))

        # Save the recording once for the data pipeline.
        s3_path = None
        if self.workflow_run_id:
            s3_path = await self._save_voicemail_audio(
                wav_data, confidence, is_voicemail
            )

        if not is_voicemail:
            logger.info("No voicemail detected, continuing normal conversation")
            return

        logger.info(f"Voicemail detected with confidence {confidence}: {reasoning}")

        # Tag the workflow run. This is best-effort bookkeeping: a failure
        # here must not prevent the end-task frame below from being sent,
        # otherwise the call would carry on against a voicemail.
        if self.workflow_run_id:
            try:
                workflow_run = await db_client.get_workflow_run_by_id(
                    self.workflow_run_id
                )
                if workflow_run:
                    # gathered_context may be missing; copy the tag list so
                    # the loaded object is not mutated in place.
                    existing_context = workflow_run.gathered_context or {}
                    call_tags = list(existing_context.get("call_tags") or [])
                    call_tags.extend(["voicemail_detected", "not_connected"])
                    await db_client.update_workflow_run(
                        run_id=workflow_run.id,
                        gathered_context={
                            "call_tags": call_tags,
                            "voicemail_transcript": transcript,
                            "voicemail_confidence": confidence,
                        },
                    )
            except Exception as e:
                logger.error(f"Error tagging workflow run as voicemail: {e}")

        # Send the end-task frame with metadata (including the optional S3 path).
        await self._engine.send_end_task_frame(
            reason=EndTaskReason.VOICEMAIL_DETECTED.value,
            additional_metadata={
                "voicemail_transcript": transcript,
                "voicemail_confidence": confidence,
                "voicemail_reasoning": reasoning,
                "voicemail_detection_duration": self.detection_duration,
                "voicemail_audio_s3_path": s3_path,
            },
            abort_immediately=True,
        )
    except Exception as e:
        logger.error(f"Error processing voicemail detection: {e}")
async def _transcribe_audio(self, wav_data: bytes) -> str:
    """Transcribe WAV audio with a direct OpenAI API call.

    Args:
        wav_data: WAV formatted audio data.

    Returns:
        The transcription text with surrounding whitespace stripped.
    """
    # Independent client: a direct API call with no pipeline involvement.
    api_key = os.environ.get("OPENAI_API_KEY")
    client = AsyncOpenAI(api_key=api_key)

    audio_file = ("audio.wav", wav_data, "audio/wav")
    response = await client.audio.transcriptions.create(
        file=audio_file,
        model="whisper-1",  # Using whisper-1 as it's more stable for transcription
        language="en",
        temperature=0.0,
    )
    transcript_text = response.text
    return transcript_text.strip()
def _create_wav_from_pcm(self, pcm_data: bytes) -> bytes:
"""Convert raw PCM data to WAV format."""
wav_buffer = io.BytesIO()
with wave.open(wav_buffer, "wb") as wav_file:
wav_file.setnchannels(1) # Mono
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(self._sample_rate)
wav_file.writeframes(pcm_data)
wav_buffer.seek(0)
return wav_buffer.read()
async def _analyze_transcript(self, transcript: str) -> dict:
    """Classify a transcript as voicemail or not using an independent OpenAI client.

    The prompt is fetched from Langfuse ("production/voicemail_detection") and
    falls back to ``DEFAULT_VOICEMAIL_PROMPT`` when that fails. When a parent
    turn context exists and tracing is enabled, the request is recorded as a
    Langfuse generation nested under that context.

    Args:
        transcript: Transcribed audio text to analyze.

    Returns:
        The JSON object returned by the model, or a fallback dict with
        ``is_voicemail=False`` when the reply is missing or not valid JSON.

    Raises:
        Exception: Errors from the OpenAI request itself are propagated.
    """
    # Capture the current turn context for proper span nesting
    parent_context = get_current_turn_context()
    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    langfuse_prompt = None
    try:
        langfuse_prompt = self._langfuse_client.get_prompt(
            "production/voicemail_detection"
        )
        prompt = langfuse_prompt.compile(transcript=transcript)
    except Exception as e:
        logger.warning(f"Error getting prompt from Langfuse: {e}")
        # The default prompt is what actually gets sent, so do not link the
        # generation to a Langfuse prompt that failed to compile.
        langfuse_prompt = None
        prompt = DEFAULT_VOICEMAIL_PROMPT.replace("{transcript}", transcript)

    messages = [
        {
            "role": "system",
            "content": prompt,
        }
    ]

    async def _request_completion():
        # Direct API call, shared by the traced and untraced paths.
        return await client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )

    # When we have a parent OpenTelemetry context, we need to activate it
    # so that Langfuse's OTEL tracer will automatically pick it up
    if parent_context and is_tracing_enabled():
        # Activate the parent context for this scope
        token = otel_context.attach(parent_context)
        langfuse_generation = None
        try:
            # Start Langfuse generation - it will automatically use the active OTEL context
            try:
                langfuse_generation = self._langfuse_client.start_generation(
                    name="voicemail_detection",
                    model="gpt-4o",
                    input=messages,
                    metadata={
                        "temperature": 0.0,
                        "detection_duration": self.detection_duration,
                        "transcript_length": len(transcript),
                    },
                    prompt=langfuse_prompt,
                )
            except Exception as e:
                logger.warning(f"Error starting Langfuse generation: {e}")

            response = await _request_completion()
            llm_response = response.choices[0].message.content

            if langfuse_generation is not None:
                try:
                    usage = response.usage
                    langfuse_generation.update(
                        output=llm_response,
                        usage_details={
                            "prompt_tokens": usage.prompt_tokens if usage else 0,
                            "completion_tokens": usage.completion_tokens
                            if usage
                            else 0,
                            "total_tokens": usage.total_tokens if usage else 0,
                        },
                    )
                except Exception as e:
                    logger.warning(f"Error updating Langfuse generation: {e}")
        finally:
            # End the generation even when the API call or the update failed,
            # otherwise the span is left open.
            if langfuse_generation is not None:
                try:
                    langfuse_generation.end()
                except Exception as e:
                    logger.warning(f"Error ending Langfuse generation: {e}")
            # Detach the context
            otel_context.detach(token)
    else:
        # No parent context or tracing disabled - just make the API call
        response = await _request_completion()
        llm_response = response.choices[0].message.content

    # Parse response. The message content may be None, which json.loads
    # rejects with TypeError rather than JSONDecodeError.
    try:
        return json.loads(llm_response)
    except (json.JSONDecodeError, TypeError):
        logger.warning("Invalid JSON response from voicemail detection")
        return {
            "is_voicemail": False,
            "confidence": 0.0,
            "reasoning": "Invalid response",
        }
async def _save_voicemail_audio(
    self, wav_data: bytes, confidence: float, is_voicemail: bool
) -> Optional[str]:
    """Save voicemail audio to temp file and enqueue task to upload to S3.

    Args:
        wav_data: WAV formatted audio data
        confidence: Detection confidence score
        is_voicemail: Whether it was detected as voicemail

    Returns:
        The expected S3 object key (bucket path); the actual upload happens
        asynchronously. ``None`` if saving or enqueueing failed.
    """
    # Tracked outside the try block so the error handler can clean up without
    # inspecting locals().
    temp_file_path: Optional[str] = None
    try:
        # Create filename with prediction, confidence and duration
        duration_seconds = self._current_duration_seconds()
        prediction = "voicemail" if is_voicemail else "not_voicemail"
        confidence_int = int(confidence * 100)
        duration_int = int(duration_seconds)
        s3_key = f"voicemail_detections/{self.workflow_run_id}_{prediction}_{confidence_int}_{duration_int}.wav"

        # Write WAV data to temp file - DO NOT delete it here, the async task will handle cleanup
        with tempfile.NamedTemporaryFile(
            suffix=".wav",
            delete=False,  # Important: don't delete immediately
            prefix=f"voicemail_{self.workflow_run_id}_",
        ) as tmp_file:
            # Record the path before writing so a failed write does not leave
            # an orphaned file behind.
            temp_file_path = tmp_file.name
            tmp_file.write(wav_data)
            tmp_file.flush()

        logger.info(f"Saved voicemail audio to temp file: {temp_file_path}")

        # Enqueue async task to upload to S3
        await enqueue_job(
            FunctionNames.UPLOAD_VOICEMAIL_AUDIO_TO_S3,
            self.workflow_run_id,
            temp_file_path,
            s3_key,
        )
        logger.info(f"Enqueued voicemail audio upload task for: {s3_key}")
        return s3_key
    except Exception as e:
        logger.error(f"Failed to save voicemail audio: {e}")
        # Nobody else will delete the file if the task was never enqueued.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Failed to cleanup temp file after error: {cleanup_error}"
                )
        return None

View file

View file

@ -0,0 +1,164 @@
{
"nodes": [
{
"id": "915",
"type": "agentNode",
"position": {
"x": 633,
"y": 324
},
"data": {
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "7598",
"type": "agentNode",
"position": {
"x": 460.1247806640531,
"y": 610.3714977079578
},
"data": {
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6919",
"type": "agentNode",
"position": {
"x": 914.666735413607,
"y": 642.9800281289787
},
"data": {
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6581",
"type": "startCall",
"position": {
"x": 648,
"y": 35
},
"data": {
"prompt": "Hello, I am Abhishek from Dograh. ",
"is_static": true,
"name": "Start Call",
"is_start": true
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "1802",
"type": "endCall",
"position": {
"x": 666.7733431033548,
"y": 987.4345801025363
},
"data": {
"prompt": "Thank you for calling Dograh. Have a great day!",
"is_static": true,
"name": "End Call"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
}
],
"edges": [
{
"animated": true,
"type": "custom",
"source": "915",
"target": "7598",
"id": "xy-edge__915-7598",
"selected": false,
"data": {
"condition": "The customer wants to talk to a customer service agent",
"label": "customer service agent"
}
},
{
"animated": true,
"type": "custom",
"source": "915",
"target": "6919",
"id": "xy-edge__915-6919",
"selected": false,
"data": {
"condition": "customer wants to talk to a sales representative",
"label": "sales representative"
}
},
{
"animated": true,
"type": "custom",
"source": "6581",
"target": "915",
"id": "xy-edge__6581-915",
"selected": false,
"data": {
"condition": "Always take this route",
"label": "Always take this route"
}
},
{
"animated": true,
"type": "custom",
"source": "7598",
"target": "1802",
"id": "xy-edge__7598-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
},
{
"animated": true,
"type": "custom",
"source": "6919",
"target": "1802",
"id": "xy-edge__6919-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
}
],
"viewport": {
"x": 0,
"y": 0,
"zoom": 1
}
}

View file

@ -0,0 +1,192 @@
from unittest.mock import Mock
from api.services.workflow.pipecat_engine_callbacks import (
create_aggregation_correction_callback,
)
def test_aggregation_fixer():
    """Check the aggregation correction algorithm over (reference, corrupted) pairs.

    The production callback reads its reference from an engine's
    `_current_llm_reference_text`, so every case gets a fresh callback bound to
    a mock object that carries only that attribute.
    """

    def fixer(reference: str, corrupted: str) -> str:  # noqa: D401
        engine_stub = Mock()
        engine_stub._current_llm_reference_text = reference
        correct = create_aggregation_correction_callback(engine_stub)
        return correct(corrupted)

    tail = " The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
    reference = (
        "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services."
        + tail
    )

    # (corrupted, expected, label) - all checked against the full reference.
    full_reference_cases = [
        ##### Trailing extra Chars #####
        (
            "My name is Alex and I am calling you from Cons umer Services." + tail,
            "My name is Alex and I am calling you from Consumer Services." + tail,
            "leading_whole_sentence",
        ),
        # Whole sentences
        (
            "Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services."
            + tail,
            reference,
            "whole_sentences",
        ),
        # With a period in the end
        (
            "Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services.",
            "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services.",
            "period_end",
        ),
        # without a period in the end
        (
            "Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services",
            "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services",
            "without_period_end",
        ),
        # Extra space in the end
        (
            "Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services ",
            "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services",
            "extra_space",
        ),
        # Multiple spaces in corruption
        (
            "Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces ",
            "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services",
            "multiple_space",
        ),
        # Multiple spaces in corruption ending in a whitespace
        (
            "Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces. ",
            "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. ",
            "multiple_space_end_ws",
        ),
        ##### Leading extra Chars #####
        # Whole sentences
        (
            "My name is Alex and I am calling you from Cons umer Services." + tail,
            "My name is Alex and I am calling you from Consumer Services." + tail,
            "leading_whole_sentence",
        ),
        # With a period in the end
        (
            "My name is Alex and I am calling you from Cons umer Services.",
            "My name is Alex and I am calling you from Consumer Services.",
            "leading_period_end",
        ),
        # without a period in the end
        (
            "My name is Alex and I am calling you from Cons umer Services",
            "My name is Alex and I am calling you from Consumer Services",
            "leading_without_period_end",
        ),
        # Extra space in the end
        (
            "My name is Alex and I am calling you from Cons umer Services ",
            "My name is Alex and I am calling you from Consumer Services",
            "leading_extra_space",
        ),
        # Multiple spaces in corruption
        (
            "My name is Alex and I am calling you from Cons umer Servi ces ",
            "My name is Alex and I am calling you from Consumer Services",
            "leading_multiple_space",
        ),
        # Multiple spaces in corruption ending in a whitespace
        (
            "My name is Alex and I am calling you from Cons umer Servi ces. ",
            "My name is Alex and I am calling you from Consumer Services. ",
            "leading_multiple_space_end_ws",
        ),
    ]
    for corrupted, expected, label in full_reference_cases:
        assert fixer(reference, corrupted) == expected, label

    # Whitespace
    assert fixer("", "") == ""

    # With an unusable reference the corrupted text must come back unchanged.
    untouched = "My name is Alex and I am calling you from Cons umer Servi ces."
    unusable_reference_cases = [
        # Missing reference
        ("", "missing_reference"),
        # Smaller reference
        ("My name is Alex", "smaller_reference"),
        # Unrelated reference
        ("Hello Hello", "unrelated_reference"),
    ]
    for weak_reference, label in unusable_reference_cases:
        assert fixer(weak_reference, untouched) == untouched, label
def test_create_aggregation_correction_callback():
    """Check the callback factory with and without reference text on the engine."""
    clean_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."

    # Mock engine carrying the reference text
    engine_stub = Mock()
    engine_stub._current_llm_reference_text = clean_text

    correct = create_aggregation_correction_callback(engine_stub)

    # Corrupted input is repaired against the reference
    repaired = correct(
        "Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."
    )
    assert (
        repaired
        == "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
    )

    # The same callback sees the engine's updated (now empty) reference
    engine_stub._current_llm_reference_text = ""
    passthrough = correct("Some corrupted text")
    assert passthrough == "Some corrupted text"  # Should return as-is when no reference

View file

@ -0,0 +1,128 @@
from unittest.mock import Mock
import pytest
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_callbacks import (
create_generation_started_callback,
)
class TestAggregationIntegration:
    """Integration tests for the TTS aggregation correction flow."""

    @pytest.mark.asyncio
    async def test_engine_reference_text_tracking(self):
        """The engine accumulates LLM text and a new generation resets it."""
        # Minimal workflow: a single static start node with no outgoing edges
        workflow = Mock()
        workflow.start_node_id = "start"
        start_node = Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
        workflow.nodes = {"start": start_node}

        engine = PipecatEngine(
            task=Mock(),
            llm=Mock(),
            context=Mock(spec=OpenAILLMContext),
            tts=Mock(),
            workflow=workflow,
            call_context_vars={},
            workflow_run_id=1,
        )

        # Nothing has been generated yet
        assert engine._current_llm_reference_text == ""

        # Each text frame is appended to the reference
        await engine.handle_llm_text_frame("Hello ")
        assert engine._current_llm_reference_text == "Hello "
        await engine.handle_llm_text_frame("world!")
        assert engine._current_llm_reference_text == "Hello world!"

        # The generation-started callback clears the reference text
        on_generation_started = create_generation_started_callback(engine)
        await on_generation_started()
        assert engine._current_llm_reference_text == ""

    @pytest.mark.asyncio
    async def test_aggregation_correction_callback_creation(self):
        """The engine can build a correction callback bound to its reference text."""
        engine = PipecatEngine(
            task=Mock(),
            llm=Mock(),
            context=Mock(spec=OpenAILLMContext),
            workflow=Mock(),
            call_context_vars={},
            workflow_run_id=1,
        )
        engine._current_llm_reference_text = "Hello, world! How are you?"

        correct = engine.create_aggregation_correction_callback()

        # Trailing punctuation might be stripped if not in the corrupted text
        assert correct("Hello world How are you") == "Hello, world! How are you"

    def test_llm_assistant_aggregator_params_with_callback(self):
        """LLMAssistantAggregatorParams accepts and exposes a correction callback."""

        def shout(text: str) -> str:
            return text.upper()

        params = LLMAssistantAggregatorParams(
            expect_stripped_words=True, correct_aggregation_callback=shout
        )

        assert params.expect_stripped_words is True
        stored_callback = params.correct_aggregation_callback
        assert stored_callback is not None
        assert stored_callback("hello") == "HELLO"

    @pytest.mark.asyncio
    async def test_pipeline_callbacks_processor_llm_text_frame(self):
        """PipelineEngineCallbacksProcessor forwards LLMTextFrame text to its callback."""
        from pipecat.frames.frames import LLMTextFrame
        from pipecat.processors.frame_processor import FrameDirection

        from api.services.pipecat.pipeline_engine_callbacks_processor import (
            PipelineEngineCallbacksProcessor,
        )

        # Records what the callback saw
        observed = {"invoked": False, "text": None}

        async def record_llm_text(text: str):
            observed["invoked"] = True
            observed["text"] = text

        processor = PipelineEngineCallbacksProcessor(
            llm_text_frame_callback=record_llm_text
        )

        await processor.process_frame(
            LLMTextFrame(text="Hello world"), FrameDirection.DOWNSTREAM
        )

        assert observed["invoked"] is True
        assert observed["text"] == "Hello world"

View file

@ -0,0 +1,31 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
    """Verify per-service and total costs computed from a sample usage payload."""
    sample_usage = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 45380,
                "completion_tokens": 496,
                "total_tokens": 45876,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
        "stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
        "call_duration_seconds": 179,
    }

    result = cost_calculator.calculate_total_cost(sample_usage)

    # Costs are floating-point values; compare with a tolerance instead of
    # ``==`` so the test does not depend on the exact order of arithmetic
    # inside the calculator.
    tolerance = 1e-10
    expected_llm_cost = 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
    expected_tts_cost = 2399 * 0.0256 / 1_000
    expected_stt_cost = 177.21536946296692 / 60 * 0.0077

    assert abs(result["llm_cost"] - expected_llm_cost) < tolerance
    assert abs(result["tts_cost"] - expected_tts_cost) < tolerance
    assert abs(result["stt_cost"] - expected_stt_cost) < tolerance
    assert (
        abs(
            result["total"]
            - (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
        )
        < tolerance
    )

View file

@ -0,0 +1,11 @@
import pytest
from api.services.workflow.dto import ReactFlowDTO
@pytest.mark.asyncio
async def test_dto():
    """The rf-1.json workflow definition validates as a ReactFlowDTO without raising."""
    with open("services/workflow/test/definitions/rf-1.json", "r") as definition_file:
        raw_definition = definition_file.read()

    parsed = ReactFlowDTO.model_validate_json(raw_definition)
    assert parsed is not None

View file

@ -0,0 +1,159 @@
from unittest.mock import AsyncMock, Mock
import pytest
from pipecat.frames.frames import StartInterruptionFrame
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAILLMContext,
)
class TestInterruptionCorrection:
    """Test that TTS aggregation correction works during interruptions."""

    @staticmethod
    def _make_context() -> Mock:
        """Build a mocked LLM context with no messages and a recording add_message."""
        context = Mock(spec=OpenAILLMContext)
        context.get_messages.return_value = []
        context.add_message = Mock()
        return context

    @staticmethod
    def _make_interrupted_aggregator(aggregator_cls, context, callback, pending_text):
        """Build an aggregator that is mid-response with `pending_text` aggregated."""
        params = LLMAssistantAggregatorParams(
            expect_stripped_words=True, correct_aggregation_callback=callback
        )
        aggregator = aggregator_cls(context=context, params=params)

        # Set up aggregation state
        aggregator._aggregation = pending_text
        aggregator._current_llm_response_id = "test-id"
        aggregator._response_function_messages = {}
        aggregator._function_calls_in_progress = {}
        aggregator._started = 1

        # Mock push_context_frame and reset methods
        aggregator.push_context_frame = AsyncMock()
        aggregator.reset = AsyncMock()
        return aggregator

    @pytest.mark.asyncio
    async def test_openai_interruption_with_correction(self):
        """Test OpenAI assistant context aggregator applies correction during interruption."""
        context = self._make_context()

        def correction_callback(text: str) -> str:
            # Simulate fixing corrupted text
            return "Hello world, how are you" if text == "Hello world how are you" else text

        aggregator = self._make_interrupted_aggregator(
            OpenAIAssistantContextAggregator,
            context,
            correction_callback,
            "Hello world how are you",
        )

        # Process interruption
        await aggregator._handle_interruptions(StartInterruptionFrame())

        # Verify the corrected text was added to context
        context.add_message.assert_called_once()
        (message,) = context.add_message.call_args[0][:1]
        assert message["role"] == "assistant"
        assert message["content"] == "Hello world, how are you <<interrupted_by_user>>"

    @pytest.mark.asyncio
    async def test_google_interruption_with_correction(self):
        """Test Google assistant context aggregator applies correction during interruption."""
        from pipecat.services.google.llm import (
            Content,
            GoogleAssistantContextAggregator,
        )

        context = self._make_context()

        def correction_callback(text: str) -> str:
            # Simulate fixing corrupted text
            return "I am here to help" if text == "I am here to help" else text

        aggregator = self._make_interrupted_aggregator(
            GoogleAssistantContextAggregator,
            context,
            correction_callback,
            "I am here to help",
        )

        # Process interruption
        await aggregator._handle_interruptions(StartInterruptionFrame())

        # Verify the corrected text was added to context
        context.add_message.assert_called_once()
        content = context.add_message.call_args[0][0]

        # Google uses Content objects
        assert isinstance(content, Content)
        assert content.role == "model"
        assert len(content.parts) == 1
        assert content.parts[0].text == "I am here to help <<interrupted_by_user>>"

    @pytest.mark.asyncio
    async def test_interruption_correction_error_handling(self):
        """Test that interruption handling continues even if correction callback fails."""
        context = self._make_context()

        def failing_callback(text: str) -> str:
            raise ValueError("Correction failed")

        aggregator = self._make_interrupted_aggregator(
            OpenAIAssistantContextAggregator, context, failing_callback, "Some text"
        )

        # Process interruption - should not raise
        await aggregator._handle_interruptions(StartInterruptionFrame())

        # Verify the original text was still added (fallback behavior)
        context.add_message.assert_called_once()
        message = context.add_message.call_args[0][0]
        assert message["role"] == "assistant"
        assert message["content"] == "Some text <<interrupted_by_user>>"

View file

View file

@ -0,0 +1,51 @@
import ast
from typing import Any, Dict
# Upper bound on the estimated bit length of an integer power result. Keeps
# inputs such as "9 ** 9 ** 9 ** 9" from exhausting CPU and memory.
_MAX_POW_RESULT_BITS = 10_000


def _evaluate_node(node: ast.AST) -> float:
    """Recursively evaluate a whitelisted arithmetic AST node."""
    if isinstance(node, ast.Expression):
        return _evaluate_node(node.body)

    if isinstance(node, ast.Constant):
        value = node.value
        # ast.Constant also covers str/bytes/None/bool/complex; only plain
        # numbers are valid operands for a calculator.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError("Unsupported expression")
        return value

    if isinstance(node, ast.UnaryOp):
        operand = _evaluate_node(node.operand)
        if isinstance(node.op, ast.USub):
            return -operand
        if isinstance(node.op, ast.UAdd):
            return +operand
        raise ValueError("Unsupported expression")

    if isinstance(node, ast.BinOp):
        left = _evaluate_node(node.left)
        right = _evaluate_node(node.right)
        op = node.op
        if isinstance(op, ast.Add):
            return left + right
        if isinstance(op, ast.Sub):
            return left - right
        if isinstance(op, ast.Mult):
            return left * right
        if isinstance(op, ast.Div):
            return left / right
        if isinstance(op, ast.Mod):
            return left % right
        if isinstance(op, ast.Pow):
            # Only int ** positive-int can grow without bound; float powers
            # overflow with OverflowError on their own.
            if (
                isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > _MAX_POW_RESULT_BITS
            ):
                raise ValueError("Unsupported expression")
            return left**right
        raise ValueError("Unsupported expression")

    raise ValueError("Unsupported expression")


def safe_calculator(expr: str) -> float:
    """
    Parse arithmetic expressions using ast and support + - * / ** % and parentheses.

    The expression is evaluated by walking the parsed tree instead of calling
    ``eval``: the previous whitelist admitted any ``ast.Constant`` (including
    strings) and unbounded exponentiation, both of which are unsafe for
    expressions that originate from an LLM tool call.

    Args:
        expr: Arithmetic expression, e.g. "2000 + 5000".

    Returns:
        The numeric result (an int when all operands and operators yield ints).

    Raises:
        ValueError: If the expression contains anything other than numeric
            literals and the supported operators, or an integer power whose
            result would be excessively large.
        SyntaxError: If ``expr`` is not a valid Python expression.
        ZeroDivisionError: On division or modulo by zero.
    """
    allowed_nodes = (
        ast.Expression,
        ast.BinOp,
        ast.UnaryOp,
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.Pow,
        ast.USub,
        ast.UAdd,
        ast.Constant,
        ast.Load,
        ast.Mod,
    )

    tree = ast.parse(expr, mode="eval")

    # Reject unsupported syntax up front, before any part is evaluated.
    if not all(isinstance(n, allowed_nodes) for n in ast.walk(tree)):
        raise ValueError("Unsupported expression")

    return _evaluate_node(tree)
def get_calculator_tools() -> list[Dict[str, Any]]:
    """Return the function-calling tool schema that exposes `safe_calculator` to the LLM."""
    expression_param: Dict[str, Any] = {
        "type": "string",
        "description": "Arithmetic expression to evaluate (supports +, -, *, /, **, %, and parentheses). Example: 2000 + 5000",
    }
    parameters: Dict[str, Any] = {
        "type": "object",
        "properties": {"expression": expression_param},
        "required": ["expression"],
    }
    calculator_function: Dict[str, Any] = {
        "name": "safe_calculator",
        "description": "Perform simple arithmetic calculations",
        "parameters": parameters,
    }
    return [{"type": "function", "function": calculator_function}]

View file

@ -0,0 +1,196 @@
"""Time tools for LLM function calling - timezone and time conversion utilities."""
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from zoneinfo import ZoneInfo
from pydantic import BaseModel
class TimeResult(BaseModel):
    """Result model for time queries."""

    # Timezone name exactly as supplied by the caller
    timezone: str
    # ISO 8601 timestamp with second precision (from isoformat(timespec="seconds"))
    datetime: str
    # True when the timestamp's DST offset is non-zero
    is_dst: bool
class TimeConversionResult(BaseModel):
    """Result model for time conversions."""

    # The requested time in the source timezone
    source: TimeResult
    # The same instant expressed in the target timezone
    target: TimeResult
    # Target UTC offset minus source UTC offset, formatted like "+9h" or "+5h45m"
    time_difference: str
def get_local_timezone(local_tz_override: Optional[str] = None) -> str:
    """
    Get the local timezone name using system timezone.

    Falls back to UTC when the system timezone cannot be expressed as a name
    that ``ZoneInfo`` accepts. ``datetime.now().astimezone()`` typically yields
    a fixed-offset tzinfo whose name is an abbreviation such as "IST" or "PST";
    those are not IANA keys and would fail later when handed to ``ZoneInfo``.

    Args:
        local_tz_override: If given, returned as-is without any lookup.

    Returns:
        The override, a timezone key usable with ``ZoneInfo``, or "UTC".
    """
    if local_tz_override:
        return local_tz_override

    try:
        # Try to get timezone from datetime
        local_tz = datetime.now().astimezone().tzinfo
        candidate = getattr(local_tz, "key", None) or str(local_tz)
        if not candidate or candidate.startswith("UTC"):
            return "UTC"
        # Only hand out names that resolve; raises for unknown abbreviations.
        ZoneInfo(candidate)
        return candidate
    except Exception:
        # Best effort: any failure to determine the timezone means UTC.
        return "UTC"
def get_current_time(timezone: str) -> Dict[str, Any]:
    """
    Get current time in specified timezone.

    Args:
        timezone: IANA timezone name (e.g., 'America/New_York', 'Europe/London')

    Returns:
        Dict containing timezone, datetime, and DST status

    Raises:
        ValueError: If ``timezone`` cannot be resolved by ``ZoneInfo``.
    """
    # Only the lookup is guarded, so unrelated failures are not misreported
    # as an invalid timezone.
    try:
        tz = ZoneInfo(timezone)
    except Exception as e:
        raise ValueError(f"Invalid timezone '{timezone}': {str(e)}") from e

    current_time = datetime.now(tz)
    result = TimeResult(
        timezone=timezone,
        datetime=current_time.isoformat(timespec="seconds"),
        is_dst=bool(current_time.dst()),
    )
    return result.model_dump()
def convert_time(
    source_timezone: str, time: str, target_timezone: str
) -> Dict[str, Any]:
    """
    Convert time between timezones.

    The time is interpreted on today's date in the source timezone.

    Args:
        source_timezone: Source IANA timezone name
        time: Time to convert in 24-hour format (HH:MM)
        target_timezone: Target IANA timezone name

    Returns:
        Dict containing source time, target time, and time difference
        (target UTC offset minus source UTC offset, e.g. "+9h", "-0h30m").

    Raises:
        ValueError: If either timezone is invalid or ``time`` is not HH:MM.
    """
    try:
        source_tz = ZoneInfo(source_timezone)
        target_tz = ZoneInfo(target_timezone)
    except Exception as e:
        raise ValueError(f"Invalid timezone: {str(e)}") from e

    # Parse time
    try:
        parsed_time = datetime.strptime(time, "%H:%M").time()
    except ValueError as e:
        raise ValueError(
            "Invalid time format. Expected HH:MM in 24-hour format"
        ) from e

    # Create datetime objects
    now = datetime.now(source_tz)
    source_time = datetime(
        now.year,
        now.month,
        now.day,
        parsed_time.hour,
        parsed_time.minute,
        tzinfo=source_tz,
    )

    # Convert to target timezone
    target_time = source_time.astimezone(target_tz)

    # Calculate time difference
    source_offset = source_time.utcoffset() or timedelta()
    target_offset = target_time.utcoffset() or timedelta()
    difference_seconds = int((target_offset - source_offset).total_seconds())

    # Format from the magnitude with an explicit sign. Deriving the sign from
    # the truncated hour count drops it for differences between -1h and 0h
    # (e.g. -30 minutes would render as "0h30m").
    sign = "-" if difference_seconds < 0 else "+"
    hours, leftover_seconds = divmod(abs(difference_seconds), 3600)
    if leftover_seconds == 0:
        time_diff_str = f"{sign}{hours}h"
    else:
        # For fractional hours like Nepal's UTC+5:45
        time_diff_str = f"{sign}{hours}h{leftover_seconds // 60:02d}m"

    result = TimeConversionResult(
        source=TimeResult(
            timezone=source_timezone,
            datetime=source_time.isoformat(timespec="seconds"),
            is_dst=bool(source_time.dst()),
        ),
        target=TimeResult(
            timezone=target_timezone,
            datetime=target_time.isoformat(timespec="seconds"),
            is_dst=bool(target_time.dst()),
        ),
        time_difference=time_diff_str,
    )
    return result.model_dump()
# Tool definitions for LLM function calling
def get_time_tools(local_tz_override: Optional[str] = None) -> list[Dict[str, Any]]:
    """Build the function-calling tool definitions for the time utilities.

    Args:
        local_tz_override: Timezone name to advertise as the local timezone
            in the parameter descriptions. When empty or None, the value
            returned by get_local_timezone() is used instead.

    Returns:
        Two tool definitions, in order: "get_current_time" and "convert_time".
    """
    local_tz = local_tz_override if local_tz_override else get_local_timezone()

    def string_param(description: str) -> Dict[str, Any]:
        # Every parameter of these tools is a plain string.
        return {"type": "string", "description": description}

    def function_tool(
        name: str, description: str, properties: Dict[str, Any]
    ) -> Dict[str, Any]:
        # All declared properties are required, in declaration order.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": list(properties),
                },
            },
        }

    current_time_tool = function_tool(
        "get_current_time",
        "Get current time in a specific timezone",
        {
            "timezone": string_param(
                f"IANA timezone name (e.g., 'America/New_York', 'Europe/London'). Use '{local_tz}' as local timezone if no timezone provided by the user."
            ),
        },
    )
    convert_time_tool = function_tool(
        "convert_time",
        "Convert time between timezones",
        {
            "source_timezone": string_param(
                f"Source IANA timezone name (e.g., 'America/New_York', 'Europe/London'). Use '{local_tz}' as local timezone if no source timezone provided by the user."
            ),
            "time": string_param("Time to convert in 24-hour format (HH:MM)"),
            "target_timezone": string_param(
                f"Target IANA timezone name (e.g., 'Asia/Tokyo', 'America/San_Francisco'). Use '{local_tz}' as local timezone if no target timezone provided by the user."
            ),
        },
    )
    return [current_time_tool, convert_time_tool]

View file

@ -0,0 +1,261 @@
import re
from collections import Counter
from typing import Dict, List
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO, NodeType, ReactFlowDTO
from api.services.workflow.errors import ItemKind, WorkflowError
class Edge:
    """Directed connection from one workflow node to another.

    Keeps the source and target node ids together with the edge payload;
    `label` and `condition` are copied from the payload for direct access.
    Equality and hashing consider only the (source, target) pair.
    """

    def __init__(self, source: str, target: str, data: EdgeDataDTO):
        self.source, self.target = source, target
        self.label, self.condition = data.label, data.condition
        self.data = data

    def get_function_name(self):
        """Return the lowercased label with every character outside a-z / 0-9 replaced by "_"."""
        lowered_label = self.label.lower()
        return re.sub(r"[^a-z0-9]", "_", lowered_label)

    def __eq__(self, other):
        if isinstance(other, Edge):
            # Label, condition and payload are deliberately not compared.
            return self.source == other.source and self.target == other.target
        return False

    def __hash__(self):
        # Must stay consistent with __eq__: only the endpoints participate.
        endpoints = (self.source, self.target)
        return hash(endpoints)
class Node:
    """A workflow node together with its outgoing connections.

    The constructor stores the raw DTO payload in `data` and also mirrors a
    number of its fields as plain attributes. `out` and `out_edges` start
    empty and are filled in while the graph is being built.
    """

    def __init__(self, id: str, node_type: NodeType, data: NodeDataDTO):
        # Identity and raw payload
        self.id = id
        self.node_type = node_type
        self.data = data

        # Outgoing connections
        self.out: Dict[str, "Node"] = {}  # forward nodes, keyed by node id
        self.out_edges: List[Edge] = []  # forward edges with properties

        # Prompt and role flags mirrored from the payload
        self.name = data.name
        self.prompt = data.prompt
        self.is_static = data.is_static
        self.is_start = data.is_start
        self.is_end = data.is_end
        self.allow_interrupt = data.allow_interrupt

        # Variable-extraction settings
        self.extraction_enabled = data.extraction_enabled
        self.extraction_prompt = data.extraction_prompt
        self.extraction_variables = data.extraction_variables

        # Remaining per-node behaviour settings
        self.add_global_prompt = data.add_global_prompt
        self.wait_for_user_response = data.wait_for_user_response
        self.wait_for_user_response_timeout = data.wait_for_user_response_timeout
        self.detect_voicemail = data.detect_voicemail
        self.delayed_start = data.delayed_start
        self.delayed_start_duration = data.delayed_start_duration
class WorkflowGraph:
"""
*All* business invariants (acyclic, cardinality, etc.) are verified here.
The constructor accepts a validated ReactFlowDTO.
"""
def __init__(self, dto: ReactFlowDTO):
    """Build the node/edge structures from `dto`, then validate the graph.

    Raises:
        KeyError: If an edge refers to a node id that is not in `dto.nodes`.
        ValueError: If `_validate_graph` finds any violation.
    """
    # Index every node by id
    self.nodes: Dict[str, Node] = {}
    for node_dto in dto.nodes:
        self.nodes[node_dto.id] = Node(node_dto.id, node_dto.type, node_dto.data)

    # Wire up the edges
    self.edges: List[Edge] = []
    for edge_dto in dto.edges:
        origin = self.nodes[edge_dto.source]
        destination = self.nodes[edge_dto.target]

        link = Edge(
            source=edge_dto.source, target=edge_dto.target, data=edge_dto.data
        )
        self.edges.append(link)
        origin.out_edges.append(link)
        # Node-to-node references kept for backward compatibility
        origin.out[destination.id] = destination

    self._validate_graph()

    # Validation has passed, so a start node is present
    start_ids = [n.id for n in dto.nodes if n.data.is_start]
    self.start_node_id = start_ids[0]

    # The global node is optional
    global_ids = [n.id for n in dto.nodes if n.type == NodeType.globalNode]
    self.global_node_id = global_ids[0] if global_ids else None
# -----------------------------------------------------------
# validators
# -----------------------------------------------------------
def _validate_graph(self) -> None:
    """Run every enabled validator and raise if any of them reports an error.

    Raises:
        ValueError: Carrying the combined list of WorkflowError objects.
    """
    # TODO: Figure out what kind of cyclic constraints can be applied, since
    # there can be a cycle in the graph. Until then `_assert_acyclic` is not
    # part of the checks below.
    checks = (
        self._assert_start_node,
        self._assert_end_node,
        self._assert_connection_counts,
        self._assert_global_node,
        self._assert_node_configs,
    )

    errors: list[WorkflowError] = []
    for check in checks:
        errors.extend(check())

    if errors:
        raise ValueError(errors)
def _assert_acyclic(self):
    """Depth-first walk over `out` links that raises when it meets a back-edge.

    Raises:
        ValueError: If a node is reached again while it is still being explored.
    """
    # Missing entry = unvisited, "gray" = being explored, "black" = finished
    state: Dict[str, str] = {}

    def visit(node: Node):
        current = state.get(node.id)
        if current == "gray":  # back-edge
            raise ValueError("workflow contains a cycle")
        if current == "black":
            return
        state[node.id] = "gray"
        for successor in node.out.values():
            visit(successor)
        state[node.id] = "black"

    for root in self.nodes.values():
        visit(root)
def _assert_start_node(self):
    """Return one workflow-level error unless exactly one node has `data.is_start` set."""
    start_nodes = [node for node in self.nodes.values() if node.data.is_start]
    if len(start_nodes) == 1:
        return []

    # Zero and multiple start nodes are reported with the same message.
    return [
        WorkflowError(
            kind=ItemKind.workflow,
            id=None,
            field=None,
            message="Workflow must have exactly one start node",
        )
    ]
def _assert_global_node(self):
    """Return one workflow-level error when more than one global node exists."""
    global_count = 0
    for node in self.nodes.values():
        if node.node_type == NodeType.globalNode:
            global_count += 1

    if global_count <= 1:
        return []

    return [
        WorkflowError(
            kind=ItemKind.workflow,
            id=None,
            field=None,
            message="Workflow must have at most one global node",
        )
    ]
def _assert_end_node(self):
    """Return one workflow-level error unless exactly one node has `data.is_end` set."""
    end_nodes = [node for node in self.nodes.values() if node.data.is_end]
    if len(end_nodes) == 1:
        return []

    # Zero and multiple end nodes are reported with the same message.
    return [
        WorkflowError(
            kind=ItemKind.workflow,
            id=None,
            field=None,
            message="Workflow must have exactly one end node",
        )
    ]
def _assert_connection_counts(self):
    """Check in/out degree rules for start, end and agent nodes.

    Degrees are counted over the `out` mapping of each node, so several
    edges between the same pair of nodes count once.

    Returns:
        One node-level WorkflowError per node that violates its rule.
    """
    # Counters return 0 for ids that never appear, matching "no connections".
    outgoing = Counter()
    incoming = Counter()
    for node_key, node in self.nodes.items():
        outgoing[node_key] += len(node.out)
        for successor in node.out.values():
            incoming[successor.id] += 1

    errors: list[WorkflowError] = []
    for node in self.nodes.values():
        n_in, n_out = incoming[node.id], outgoing[node.id]

        # NOTE(review): the start/end conditions also reject a start node with
        # incoming edges and an end node with outgoing edges, but the messages
        # mention only the other half of the rule.
        if node.node_type == NodeType.startNode:
            violated = n_in != 0 or n_out < 1
            message = "StartNode must have at least 1 outgoing edge"
        elif node.node_type == NodeType.endNode:
            violated = n_in < 1 or n_out != 0
            message = "EndNode must have at least 1 incoming edge"
        elif node.node_type == NodeType.agentNode:
            violated = n_in < 1 or n_out < 1
            message = "Worker must have at least 1 incoming and 1 outgoing edge"
        else:
            # Other node types have no degree constraints here.
            continue

        if violated:
            errors.append(
                WorkflowError(
                    kind=ItemKind.node,
                    id=node.id,
                    field=None,
                    message=message,
                )
            )
    return errors
def _assert_node_configs(self):
    """Validate node-specific configuration constraints.

    No constraint is implemented yet, so the returned list is always empty;
    the loop is kept as the place where per-type checks will go.
    """
    errors: list[WorkflowError] = []
    for node in self.nodes.values():
        if node.node_type != NodeType.startNode:
            continue
        # No specific validations for start node at this time
    return errors