dograh/api/services/workflow/pipecat_engine.py

781 lines
32 KiB
Python
Raw Normal View History

2026-01-13 14:55:48 +05:30
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
2025-09-09 14:37:32 +05:30
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
)
from api.services.workflow.workflow import Node, WorkflowGraph
2025-09-09 14:37:32 +05:30
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
2025-09-09 14:37:32 +05:30
from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
2025-09-09 14:37:32 +05:30
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
LLMService = Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService]
import asyncio
from loguru import logger
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
2025-09-09 14:37:32 +05:30
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
render_template,
update_llm_context,
)
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.knowledge_base import (
get_knowledge_base_tool,
retrieve_from_knowledge_base,
)
2025-09-09 14:37:32 +05:30
from api.services.workflow.tools.timezone import (
convert_time,
get_current_time,
get_time_tools,
)
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.tracing.context_registry import get_current_turn_context
2025-09-09 14:37:32 +05:30
class PipecatEngine:
def __init__(
self,
*,
task: Optional[PipelineTask] = None,
llm: Optional["LLMService"] = None,
context: Optional[LLMContext] = None,
2025-09-09 14:37:32 +05:30
transport: Optional[BaseTransport] = None,
workflow: WorkflowGraph,
call_context_vars: dict,
workflow_run_id: Optional[int] = None,
node_transition_callback: Optional[
Callable[[str, Optional[str]], Awaitable[None]]
] = None,
embeddings_api_key: Optional[str] = None,
embeddings_model: Optional[str] = None,
2025-09-09 14:37:32 +05:30
):
self.task = task
self.llm = llm
self.context = context
self.transport = transport
self.workflow = workflow
self._call_context_vars = call_context_vars
self._workflow_run_id = workflow_run_id
self._node_transition_callback = node_transition_callback
2025-09-09 14:37:32 +05:30
self._initialized = False
self._client_disconnected = False
self._call_disposed = False
2025-09-09 14:37:32 +05:30
self._current_node: Optional[Node] = None
self._gathered_context: dict = {}
self._user_response_timeout_task: Optional[asyncio.Task] = None
self._call_disposition: Optional[str] = None
# Stasis connection for immediate transfers
self._stasis_connection: Optional["StasisRTPConnection"] = None
# Will be set later in initialize() when we have
# access to _context
self._variable_extraction_manager = None
# Lazy loaded built-in function schemas
self._builtin_function_schemas: Optional[list[dict]] = None
# Track current LLM reference text for TTS aggregation correction
self._current_llm_generation_reference_text: str = ""
2025-09-09 14:37:32 +05:30
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
# Embeddings configuration (passed from run_pipeline.py)
self._embeddings_api_key: Optional[str] = embeddings_api_key
self._embeddings_model: Optional[str] = embeddings_model
async def _get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._custom_tool_manager:
return await self._custom_tool_manager.get_organization_id()
# Fallback for when manager is not yet initialized
return await get_organization_id_from_workflow_run(self._workflow_run_id)
2025-09-09 14:37:32 +05:30
@property
def builtin_function_schemas(self) -> list[dict]:
"""Get built-in function schemas (calculator and timezone tools)."""
if self._builtin_function_schemas is None:
self._builtin_function_schemas = []
# Transform calculator tools to get_function_schema format
for tool in get_calculator_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
# Transform timezone tools to get_function_schema format
for tool in get_time_tools():
func = tool["function"]
schema = get_function_schema(
func["name"],
func["description"],
properties=func["parameters"]["properties"],
required=func["parameters"]["required"],
)
self._builtin_function_schemas.append(schema)
return self._builtin_function_schemas
async def initialize(self):
# TODO: May be set_node in a separate task so that we return from initialize immediately
if self._initialized:
logger.warning(f"{self.__class__.__name__} already initialized")
return
try:
self._initialized = True
# Helper that encapsulates variable extraction logic
self._variable_extraction_manager = VariableExtractionManager(self)
# Helper that encapsulates custom tool management
self._custom_tool_manager = CustomToolManager(self)
2025-09-09 14:37:32 +05:30
# Add current time in EST (America/New_York) to gathered context
try:
est_time_result = get_current_time("America/New_York")
# The get_current_time utility returns a dict with 'datetime' field
# Store the ISO formatted datetime string under the key 'time'
self._gathered_context["time"] = est_time_result.get("datetime")
except Exception as e:
logger.error(f"Failed to fetch current EST time: {e}")
# Register built-in functions with the LLM
await self._register_builtin_functions()
await self.set_node(self.workflow.start_node_id)
2025-09-09 14:37:32 +05:30
logger.debug(f"{self.__class__.__name__} initialized")
except Exception as e:
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
raise
def _get_function_schema(self, function_name: str, description: str):
"""Thin wrapper around utils.get_function_schema for backwards compatibility."""
return get_function_schema(function_name, description)
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
"""Delegate context update to the shared workflow.utils implementation."""
update_llm_context(self.context, system_message, functions)
def _format_prompt(self, prompt: str) -> str:
"""Delegate prompt formatting to the shared workflow.utils implementation."""
return render_template(prompt, self._call_context_vars)
async def _create_transition_func(self, name: str, transition_to_node: str):
async def transition_func(function_call_params: FunctionCallParams) -> None:
"""Inner function that handles the node change tool calls"""
logger.info(f"LLM Function Call EXECUTED: {name}")
logger.info(
f"Function: {name} -> transitioning to node: {transition_to_node}"
)
logger.info(f"Arguments: {function_call_params.arguments}")
# Perform variable extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(self._current_node)
# Set context for the new node, so that when the function call result
# frame is received by LLMContextAggregator and an LLM generation
# is done, we have updated context and functions
2026-01-22 20:28:13 +05:30
await self.set_node(transition_to_node)
2025-09-09 14:37:32 +05:30
try:
2025-09-09 14:37:32 +05:30
async def on_context_updated() -> None:
"""
pipecat framework will run this function after the function call result has been updated in the context.
2025-09-09 14:37:32 +05:30
This way, when we do set_node from within this function, and go for LLM completion with updated
system prompts, the context is updated with function call result.
"""
# Queue EndFrame if we just transitioned to EndNode
if self._current_node.is_end:
await self.send_end_task_frame(
EndTaskReason.USER_QUALIFIED.value
)
2025-09-09 14:37:32 +05:30
result = {"status": "done"}
properties = FunctionCallResultProperties(
on_context_updated=on_context_updated,
)
# Call results callback from the pipecat framework
# so that a new llm generation can be triggred if
# required
await function_call_params.result_callback(
result, properties=properties
)
2025-09-09 14:37:32 +05:30
except Exception as e:
logger.error(f"Error in transition function {name}: {str(e)}")
error_result = {"status": "error", "error": str(e)}
await function_call_params.result_callback(error_result)
return transition_func
async def _register_transition_function_with_llm(
self, name: str, transition_to_node: str
):
logger.debug(
f"Registering function {name} to transition to node {transition_to_node} with LLM"
)
# Create transition function
transition_func = await self._create_transition_func(name, transition_to_node)
# Register function with LLM
self.llm.register_function(
name,
transition_func,
cancel_on_interruption=True,
)
async def _register_builtin_functions(self):
"""Register built-in functions (calculator and timezone) with the LLM."""
logger.debug("Registering built-in functions with LLM")
# Register calculator function
async def calculate_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
logger.info(f"Arguments: {function_call_params.arguments}")
2025-09-09 14:37:32 +05:30
try:
expr = function_call_params.arguments.get("expression", "")
result = safe_calculator(expr)
await function_call_params.result_callback(
{"expression": expr, "result": result}
2025-09-09 14:37:32 +05:30
)
except Exception as e:
await function_call_params.result_callback({"error": str(e)})
2025-09-09 14:37:32 +05:30
# Register timezone functions
async def get_current_time_func(
function_call_params: FunctionCallParams,
) -> None:
logger.info(f"LLM Function Call EXECUTED: get_current_time")
logger.info(f"Arguments: {function_call_params.arguments}")
2025-09-09 14:37:32 +05:30
try:
timezone = function_call_params.arguments.get("timezone", "UTC")
result = get_current_time(timezone)
await function_call_params.result_callback(result)
2025-09-09 14:37:32 +05:30
except Exception as e:
await function_call_params.result_callback({"error": str(e)})
2025-09-09 14:37:32 +05:30
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: convert_time")
logger.info(f"Arguments: {function_call_params.arguments}")
2025-09-09 14:37:32 +05:30
try:
result = convert_time(
function_call_params.arguments.get("source_timezone"),
function_call_params.arguments.get("time"),
function_call_params.arguments.get("target_timezone"),
)
await function_call_params.result_callback(result)
2025-09-09 14:37:32 +05:30
except Exception as e:
await function_call_params.result_callback({"error": str(e)})
2025-09-09 14:37:32 +05:30
# Register all built-in functions
self.llm.register_function("safe_calculator", calculate_func)
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _register_knowledge_base_function(
self, document_uuids: list[str]
) -> None:
"""Register knowledge base retrieval function with the LLM.
Args:
document_uuids: List of document UUIDs to filter the search by
"""
logger.debug(
f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
)
async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
query = function_call_params.arguments.get("query", "")
organization_id = await self._get_organization_id()
if not organization_id:
raise ValueError(
"Organization ID not available for knowledge base retrieval"
)
if not self._embeddings_api_key:
raise ValueError(
"Embeddings API key not configured. Please set your API key in "
"Model Configurations > Embedding."
)
result = await retrieve_from_knowledge_base(
query=query,
organization_id=organization_id,
document_uuids=document_uuids,
limit=3, # Return top 3 most relevant chunks
embeddings_api_key=self._embeddings_api_key,
embeddings_model=self._embeddings_model,
)
await function_call_params.result_callback(result)
except Exception as e:
logger.error(f"Knowledge base retrieval failed: {e}")
await function_call_params.result_callback(
{"error": str(e), "chunks": [], "query": query, "total_results": 0}
)
# Register the function with the LLM
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
2025-09-09 14:37:32 +05:30
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
"""Perform variable extraction if the previous node had extraction enabled."""
if (
previous_node
and previous_node.extraction_enabled
and previous_node.extraction_variables
):
logger.debug(
f"Scheduling background variable extraction for node: {previous_node.name}"
)
# Capture the current turn context before creating the background task
parent_context = get_current_turn_context()
extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
extraction_variables = previous_node.extraction_variables
async def _background_extraction():
try:
extracted_data = (
await self._variable_extraction_manager._perform_extraction(
extraction_variables, parent_context, extraction_prompt
)
)
self._gathered_context.update(extracted_data)
logger.debug(
f"Background variable extraction completed. Extracted: {extracted_data}"
)
except Exception as e:
logger.error(
f"Error during background variable extraction: {str(e)}"
)
# Fire and forget - extraction happens in background without blocking
asyncio.create_task(_background_extraction())
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
"""Common method to set up LLM context and queue context frame for non-static nodes."""
# Set node name for tracing
try:
self.context.set_node_name(node.name)
except AttributeError:
logger.warning(f"context has no set_node_name method")
# Register transition functions if not an end node
if not node.is_end:
for outgoing_edge in node.out_edges:
await self._register_transition_function_with_llm(
outgoing_edge.get_function_name(), outgoing_edge.target
)
# Register custom tool handlers for this node
if node.tool_uuids and self._custom_tool_manager:
await self._custom_tool_manager.register_handlers(node.tool_uuids)
# Register knowledge base retrieval handler if node has documents
if node.document_uuids:
await self._register_knowledge_base_function(node.document_uuids)
2025-09-09 14:37:32 +05:30
# Set up system message and functions
(
system_message,
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
async def set_node(self, node_id: str):
"""
Simplified set_node implementation according to v2 PRD.
"""
node = self.workflow.nodes[node_id]
logger.debug(
f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
)
# Track previous node for transition event
previous_node_name = self._current_node.name if self._current_node else None
2025-09-09 14:37:32 +05:30
# Set current node for all nodes (including static ones) so STT mute filter works
self._current_node = node
# Send node transition event if callback is provided
if self._node_transition_callback:
try:
await self._node_transition_callback(node.name, previous_node_name)
except Exception as e:
# Log but don't fail - feedback is non-critical
logger.debug(f"Failed to send node transition event: {e}")
2025-09-09 14:37:32 +05:30
# Handle start nodes
if node.is_start:
await self._handle_start_node(node)
# Handle end nodes
elif node.is_end:
await self._handle_end_node(node)
# Handle normal agent nodes
else:
await self._handle_agent_node(node)
async def _handle_start_node(self, node: Node) -> None:
"""Handle start node execution."""
# Check if delayed start is enabled
if node.delayed_start:
# Use configured duration or default to 3 seconds
delay_duration = node.delayed_start_duration or 2.0
logger.debug(
f"Delayed start enabled - waiting {delay_duration} seconds before speaking"
)
await asyncio.sleep(delay_duration)
if node.is_static:
raise ValueError("Static nodes are not supported!")
2025-09-09 14:37:32 +05:30
else:
# Start generation for non-static start node
await self._setup_llm_context_and_start_generation(node)
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:
raise ValueError("Static nodes are not supported!")
2025-09-09 14:37:32 +05:30
else:
await self._setup_llm_context_and_start_generation(node)
# If this end node has extraction enabled, perform extraction immediately
if node.extraction_enabled and node.extraction_variables:
await self._perform_variable_extraction_if_needed(node)
async def _handle_agent_node(self, node: Node) -> None:
"""Handle agent node execution."""
if node.is_static:
raise ValueError("Static nodes are not supported!")
2025-09-09 14:37:32 +05:30
else:
# Set context and functions for non-static agent node
await self._setup_llm_context_and_start_generation(node)
async def send_end_task_frame(
self,
reason: str,
abort_immediately: bool = False,
):
"""
Centralized method to send EndTaskFrame with metadata including
call_transfer_context and call_context_vars
"""
if self._call_disposed or self._client_disconnected:
# Call is already disposed and client disconnected
logger.debug(
f"Not sending EndFrame since call is already disposed: Call Disposed: {self._call_disposed} Client Disconnected: {self._client_disconnected}"
)
return
self._call_disposed = True
2025-09-09 14:37:32 +05:30
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
# Customer disposition code using their mapping
mapped_disposition = ""
# Apply disposition mapping - first try call_disposition if it is,
# extracted from the call conversation then fall back to reason
call_disposition = self._gathered_context.get("call_disposition", "")
organization_id = await self._get_organization_id()
2025-09-09 14:37:32 +05:30
# If client is disconnected before we get a chance to disconnect from
# the bot, lets consider that as final disposition
if self._client_disconnected:
call_disposition = EndTaskReason.USER_HANGUP.value
2025-09-09 14:37:32 +05:30
if call_disposition:
# If call_disposition exists, map it
mapped_disposition = await apply_disposition_mapping(
call_disposition, organization_id
)
# Store the original and mapped values
self._gathered_context["extracted_call_disposition"] = call_disposition
self._gathered_context["call_disposition"] = mapped_disposition
else:
# Otherwise, map the disconnect reason
mapped_disposition = await apply_disposition_mapping(
reason, organization_id
)
# Store the mapped disconnect reason
self._gathered_context["call_disposition"] = mapped_disposition
# TODO: Generalise this
2025-09-09 14:37:32 +05:30
self._gathered_context["address"] = ", ".join(
[
self._call_context_vars.get("address1", ""),
self._call_context_vars.get("address2", ""),
self._call_context_vars.get("address3", ""),
self._call_context_vars.get("city", ""),
self._call_context_vars.get("state", ""),
self._call_context_vars.get("province", ""),
self._call_context_vars.get("postal_code", ""),
]
)
self._gathered_context["full_name"] = " ".join(
[
self._call_context_vars.get("first_name", ""),
self._call_context_vars.get("middle_initial", ""),
self._call_context_vars.get("last_name", ""),
]
)
self._gathered_context["agent_name"] = "Alex"
self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
"phone", ""
)
self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
self._gathered_context["vendor_id"] = self._call_context_vars.get(
"vendor_lead_code", ""
)
decision_maker = self._gathered_context.get("primary_cardholder", False)
employment_status = self._gathered_context.get("employment_status", "N/A")
call_transfer_context = {
"first_name": self._call_context_vars.get("first_name", ""),
"full_name": self._gathered_context.get("full_name", ""),
"phone": self._call_context_vars.get("phone", ""),
"lead_id": self._call_context_vars.get("lead_id"),
"disposition": mapped_disposition,
"agent_name": self._gathered_context.get("agent_name", "Alex"),
"decision_maker": str(decision_maker),
"employment": employment_status.title() if employment_status else "N/A",
"debts": self._gathered_context.get("total_debt", "N/A"),
"number_of_credit_cards": self._gathered_context.get(
"number_of_credit_cards", "N/A"
),
"time": self._gathered_context.get("time"),
}
logger.debug(
f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
)
# Initiate immediate transfer for Stasis connections when user is qualified
if (
reason == EndTaskReason.USER_QUALIFIED.value
and self._stasis_connection is not None
and not abort_immediately
):
try:
logger.info(
f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
)
await self._stasis_connection.transfer(call_transfer_context)
logger.info("Immediate transfer initiated successfully")
except Exception as e:
logger.error(f"Failed to initiate immediate transfer: {e}")
# Continue with normal flow even if immediate transfer fails
if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
await self.task.queue_frame(
TTSSpeakFrame(
"Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
)
)
# Store the original reason for later retrieval in event handler
self._call_disposition = mapped_disposition
logger.debug(
f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
)
await self.task.queue_frame(frame_to_push)
async def _compose_system_message_functions_for_node(
self, node: "Node"
) -> tuple[list[dict], list[dict]]:
"""Generate the system messages and function schemas for the given node.
This performs the same formatting logic used when entering a node but
does **not** register the functions with the LLM; callers are
responsible for that.
"""
global_prompt = ""
if self.workflow.global_node_id and node.add_global_prompt:
global_node = self.workflow.nodes[self.workflow.global_node_id]
global_prompt = self._format_prompt(global_node.prompt)
functions: list[dict] = []
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Add knowledge base retrieval tool if node has documents
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Add custom tools from node.tool_uuids
if node.tool_uuids and self._custom_tool_manager:
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
node.tool_uuids
)
functions.extend(custom_tool_schemas)
2025-09-09 14:37:32 +05:30
# Transition functions (schema only; registration handled elsewhere)
for outgoing_edge in node.out_edges:
function_schema = self._get_function_schema(
outgoing_edge.get_function_name(), outgoing_edge.condition
)
functions.append(function_schema)
formatted_node_prompt = self._format_prompt(node.prompt)
system_message = {
"role": "system",
"content": "\n\n".join(
p for p in (global_prompt, formatted_node_prompt) if p
),
}
return system_message, functions
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
"""
This callback is called by STTMuteFilter to determine if the STT should be muted.
"""
return engine_callbacks.create_should_mute_callback(self)
def create_user_idle_handler(self):
2025-09-09 14:37:32 +05:30
"""
Returns a UserIdleHandler that manages user-idle timeouts with state.
The handler tracks retry count and handles escalating prompts.
2025-09-09 14:37:32 +05:30
"""
return engine_callbacks.create_user_idle_handler(self)
2025-09-09 14:37:32 +05:30
def create_max_duration_callback(self):
"""
This callback is called when the call duration exceeds the max duration.
We use this to send the EndTaskFrame.
"""
return engine_callbacks.create_max_duration_callback(self)
def create_generation_started_callback(self):
"""
This callback is called when a new generation starts.
This is used to reset the flags that control the flow of the engine.
"""
return engine_callbacks.create_generation_started_callback(self)
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
"""Create a callback that corrects corrupted aggregation using reference text."""
return engine_callbacks.create_aggregation_correction_callback(self)
def set_context(self, context: LLMContext) -> None:
"""Set the LLM context.
2025-09-09 14:37:32 +05:30
This allows setting the context after the engine has been created,
which is useful when the context needs to be created after the engine.
"""
self.context = context
def set_task(self, task: PipelineTask) -> None:
"""Set the pipeline task.
This allows setting the task after the engine has been created,
which is useful when the task needs to be created after the engine.
"""
self.task = task
def set_stasis_connection(
self, connection: Optional["StasisRTPConnection"]
) -> None:
"""Set the Stasis RTP connection for immediate transfers.
This allows the engine to initiate transfers immediately when XFER
disposition is detected, without waiting for pipeline shutdown.
Args:
connection: The StasisRTPConnection instance, or None for non-Stasis transports
"""
self._stasis_connection = connection
if connection:
logger.debug(
f"Stasis connection set for immediate transfers: {connection.channel_id}"
)
async def handle_llm_text_frame(self, text: str):
"""Accumulate LLM text frames to build reference text."""
self._current_llm_generation_reference_text += text
2025-09-09 14:37:32 +05:30
def handle_client_disconnected(self):
"""Handle client disconnected event."""
self._client_disconnected = True
def is_call_disposed(self):
"""Check whether a call has been disposed by the engine"""
return self._call_disposed
async def get_call_disposition(self) -> Optional[str]:
"""Get the disconnect reason set by the engine."""
if self._call_disposition:
# We would have a _call_disposition variable set if we have initiated
# a disconnect from the bot, i.e we have called send_end_task_frame.
return self._call_disposition
if self._client_disconnected:
return EndTaskReason.USER_HANGUP.value
else:
return EndTaskReason.UNKNOWN.value
async def get_gathered_context(self) -> dict:
"""Get the gathered context including extracted variables."""
return self._gathered_context.copy()
2025-09-09 14:37:32 +05:30
async def cleanup(self):
"""Clean up engine resources on disconnect."""
# Cancel any pending timeout tasks
if (
self._user_response_timeout_task
and not self._user_response_timeout_task.done()
):
self._user_response_timeout_task.cancel()