mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* feat: add stt evals * add smart turn as provider * chore: remove deprecations * chore: format files * fix: remove deprecated UserIdleProcessor * fix: remove deprecated TranscriptProcessor * chore: update pipecat submodule * feat: add evals visualisation * fix: trigger llm generation on client connected and pipeline started * chore: update pipecat * chore: update pipecat submodule * Add tests * fix: slow loading of workflow page * chore: update pipecat submodule * Show version after release * Fixes #99 * fix: provider check for websocket connection * Fixes #107 * Fix #96 * chore: fix documentation * fix: cloudonix campaign call error --------- Co-authored-by: Sabiha Khan <sabihak89@gmail.com>
778 lines
32 KiB
Python
778 lines
32 KiB
Python
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
|
|
|
|
from api.services.workflow.disposition_mapper import (
|
|
apply_disposition_mapping,
|
|
get_organization_id_from_workflow_run,
|
|
)
|
|
from api.services.workflow.workflow import Node, WorkflowGraph
|
|
from pipecat.frames.frames import (
|
|
CancelFrame,
|
|
EndFrame,
|
|
FunctionCallResultProperties,
|
|
TTSSpeakFrame,
|
|
)
|
|
from pipecat.pipeline.task import PipelineTask
|
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
|
from pipecat.services.llm_service import FunctionCallParams
|
|
from pipecat.transports.base_transport import BaseTransport
|
|
from pipecat.utils.enums import EndTaskReason
|
|
|
|
if TYPE_CHECKING:
|
|
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
|
from pipecat.services.anthropic.llm import AnthropicLLMService
|
|
from pipecat.services.google.llm import GoogleLLMService
|
|
from pipecat.services.openai.llm import OpenAILLMService
|
|
|
|
LLMService = Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService]
|
|
|
|
import asyncio
|
|
|
|
from loguru import logger
|
|
|
|
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
|
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
|
from api.services.workflow.pipecat_engine_utils import (
|
|
get_function_schema,
|
|
render_template,
|
|
update_llm_context,
|
|
)
|
|
from api.services.workflow.pipecat_engine_variable_extractor import (
|
|
VariableExtractionManager,
|
|
)
|
|
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
|
from api.services.workflow.tools.knowledge_base import (
|
|
get_knowledge_base_tool,
|
|
retrieve_from_knowledge_base,
|
|
)
|
|
from api.services.workflow.tools.timezone import (
|
|
convert_time,
|
|
get_current_time,
|
|
get_time_tools,
|
|
)
|
|
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
|
from pipecat.utils.tracing.context_registry import get_current_turn_context
|
|
|
|
|
|
class PipecatEngine:
    """Drive a pipecat pipeline through the nodes of a ``WorkflowGraph``.

    The engine owns the per-call workflow state:
    - the node currently being executed;
    - variables gathered during the call;
    - the final call disposition.

    On every node change it registers that node's tool handlers with the LLM
    and rewrites the LLM context (system prompt and function schemas).

    ``task`` and ``context`` may be supplied after construction via
    ``set_task`` / ``set_context``.
    """

    def __init__(
        self,
        *,
        task: Optional[PipelineTask] = None,
        llm: Optional["LLMService"] = None,
        context: Optional[LLMContext] = None,
        transport: Optional[BaseTransport] = None,
        workflow: WorkflowGraph,
        call_context_vars: dict,
        workflow_run_id: Optional[int] = None,
        node_transition_callback: Optional[
            Callable[[str, Optional[str]], Awaitable[None]]
        ] = None,
        embeddings_api_key: Optional[str] = None,
        embeddings_model: Optional[str] = None,
    ):
        """Create an engine for a single workflow run.

        Args:
            task: Pipeline task that frames are queued on; may be set later
                with ``set_task``.
            llm: LLM service that tool handlers are registered with.
            context: LLM context rewritten on every node change; may be set
                later with ``set_context``.
            transport: Pipeline transport; stored on the instance.
            workflow: Graph of nodes/edges the call walks through.
            call_context_vars: Per-call variables used to render prompt
                templates and to build the end-of-call transfer context.
            workflow_run_id: Used to look up the organization ID when the
                custom tool manager is not yet available.
            node_transition_callback: Awaited with
                ``(new_node_name, previous_node_name)`` on every node change.
                Failures are logged and ignored.
            embeddings_api_key: API key for knowledge-base retrieval.
            embeddings_model: Embeddings model for knowledge-base retrieval.
        """
        self.task = task
        self.llm = llm
        self.context = context
        self.transport = transport
        self.workflow = workflow
        self._call_context_vars = call_context_vars
        self._workflow_run_id = workflow_run_id
        self._node_transition_callback = node_transition_callback
        self._initialized = False
        self._client_disconnected = False
        self._call_disposed = False
        self._current_node: Optional[Node] = None
        self._gathered_context: dict = {}
        self._user_response_timeout_task: Optional[asyncio.Task] = None
        self._call_disposition: Optional[str] = None

        # Stasis connection for immediate transfers
        self._stasis_connection: Optional["StasisRTPConnection"] = None

        # Will be set later in initialize() when we have
        # access to _context
        self._variable_extraction_manager = None

        # Lazy loaded built-in function schemas
        self._builtin_function_schemas: Optional[list[dict]] = None

        # Track current LLM reference text for TTS aggregation correction
        self._current_llm_generation_reference_text: str = ""

        # Custom tool manager (initialized in initialize())
        self._custom_tool_manager: Optional[CustomToolManager] = None

        # Embeddings configuration (passed from run_pipeline.py)
        self._embeddings_api_key: Optional[str] = embeddings_api_key
        self._embeddings_model: Optional[str] = embeddings_model

        # Strong references to fire-and-forget background tasks (variable
        # extraction). The event loop only keeps weak references to tasks,
        # so a pending task that nothing references can be garbage-collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    async def _get_organization_id(self) -> Optional[int]:
        """Return the organization ID that owns this workflow run.

        Delegates to the custom tool manager once ``initialize`` has created
        it. Before that, it looks the ID up from ``workflow_run_id``.
        NOTE(review): any caching presumably lives in the manager.
        """
        if self._custom_tool_manager:
            return await self._custom_tool_manager.get_organization_id()
        # Fallback for when manager is not yet initialized
        return await get_organization_id_from_workflow_run(self._workflow_run_id)

    @property
    def builtin_function_schemas(self) -> list[dict]:
        """Function schemas for the built-in calculator and timezone tools.

        Built on first access and cached on the instance. The cache is only
        stored once every schema has been built, so a failure part-way
        through never leaves a partial list behind.
        """
        if self._builtin_function_schemas is None:
            schemas: list[dict] = []
            # Each tool definition is read as
            # {"function": {"name", "description", "parameters": {...}}}.
            for tool in (*get_calculator_tools(), *get_time_tools()):
                func = tool["function"]
                schemas.append(
                    get_function_schema(
                        func["name"],
                        func["description"],
                        properties=func["parameters"]["properties"],
                        required=func["parameters"]["required"],
                    )
                )
            self._builtin_function_schemas = schemas

        return self._builtin_function_schemas

    async def initialize(self):
        """Prepare the engine for a call and enter the workflow's start node.

        Steps:
        1. Create the variable-extraction and custom-tool helpers.
        2. Store the current America/New_York time under ``"time"`` in the
           gathered context.
        3. Register the built-in tool handlers with the LLM.
        4. Call ``set_node`` on the start node.

        Calling it again only logs a warning.

        Raises:
            Exception: Anything raised during initialization is logged and
                re-raised.
        """
        # TODO: May be set_node in a separate task so that we return from initialize immediately
        if self._initialized:
            logger.warning(f"{self.__class__.__name__} already initialized")
            return

        try:
            # Set before the work below, so a call made after a failed
            # attempt only logs the "already initialized" warning.
            self._initialized = True

            # Helper that encapsulates variable extraction logic
            self._variable_extraction_manager = VariableExtractionManager(self)

            # Helper that encapsulates custom tool management
            self._custom_tool_manager = CustomToolManager(self)

            # Add current time in EST (America/New_York) to gathered context
            try:
                est_time_result = get_current_time("America/New_York")
                # The get_current_time utility returns a dict with 'datetime' field
                # Store the ISO formatted datetime string under the key 'time'
                self._gathered_context["time"] = est_time_result.get("datetime")
            except Exception as e:
                logger.error(f"Failed to fetch current EST time: {e}")

            # Register built-in functions with the LLM
            await self._register_builtin_functions()

            await self.set_node(self.workflow.start_node_id)

            logger.debug(f"{self.__class__.__name__} initialized")
        except Exception as e:
            logger.error(f"Error initializing {self.__class__.__name__}: {e}")
            raise

    def _get_function_schema(self, function_name: str, description: str):
        """Thin wrapper around utils.get_function_schema for backwards compatibility."""
        return get_function_schema(function_name, description)

    async def _update_llm_context(self, system_message: dict, functions: list[dict]):
        """Delegate context update to the shared workflow.utils implementation."""
        update_llm_context(self.context, system_message, functions)

    def _format_prompt(self, prompt: str) -> str:
        """Render ``prompt`` as a template using this call's context variables."""
        return render_template(prompt, self._call_context_vars)

    async def _create_transition_func(self, name: str, transition_to_node: str):
        """Build the LLM tool handler that moves the workflow to another node.

        Args:
            name: Function name exposed to the LLM (used in logs).
            transition_to_node: ID of the node to enter when the LLM calls
                the function.

        Returns:
            An async handler suitable for ``llm.register_function``.
        """

        async def transition_func(function_call_params: FunctionCallParams) -> None:
            """Handle a node-change tool call issued by the LLM."""
            logger.info(f"LLM Function Call EXECUTED: {name}")
            logger.info(
                f"Function: {name} -> transitioning to node: {transition_to_node}"
            )
            logger.info(f"Arguments: {function_call_params.arguments}")

            # Remember the node being left. set_node() replaces _current_node,
            # and variable extraction must run for the node in which the
            # conversation just happened, not for the node being entered.
            previous_node = self._current_node

            try:
                # Switch node (context + tools) before reporting the result,
                # since result_callback may trigger the next LLM generation.
                await self.set_node(transition_to_node)

                async def on_context_updated() -> None:
                    """
                    pipecat framework will run this function after the function call result has been updated in the context.
                    This way, work done here (variable extraction, ending the
                    call) sees a context that already contains the function
                    call result.
                    """
                    # Extract variables collected in the node we just left
                    await self._perform_variable_extraction_if_needed(previous_node)

                    # Queue EndFrame if we just transitioned to EndNode
                    if self._current_node.is_end:
                        await self.send_end_task_frame(
                            EndTaskReason.USER_QUALIFIED.value
                        )

                result = {"status": "done"}

                properties = FunctionCallResultProperties(
                    on_context_updated=on_context_updated,
                )

                # Call results callback from the pipecat framework
                # so that a new llm generation can be triggered if
                # required
                await function_call_params.result_callback(
                    result, properties=properties
                )
            except Exception as e:
                logger.error(f"Error in transition function {name}: {str(e)}")
                error_result = {"status": "error", "error": str(e)}
                await function_call_params.result_callback(error_result)

        return transition_func

    async def _register_transition_function_with_llm(
        self, name: str, transition_to_node: str
    ):
        """Register a node-transition tool named ``name`` with the LLM.

        The handler is registered with ``cancel_on_interruption=True``.
        """
        logger.debug(
            f"Registering function {name} to transition to node {transition_to_node} with LLM"
        )

        # Create transition function
        transition_func = await self._create_transition_func(name, transition_to_node)

        # Register function with LLM
        self.llm.register_function(
            name,
            transition_func,
            cancel_on_interruption=True,
        )

    async def _register_builtin_functions(self):
        """Register built-in functions (calculator and timezone) with the LLM.

        Every handler reports failures to the LLM as ``{"error": str(e)}``
        instead of raising.
        """
        logger.debug("Registering built-in functions with LLM")

        # Register calculator function
        async def calculate_func(function_call_params: FunctionCallParams) -> None:
            logger.info("LLM Function Call EXECUTED: safe_calculator")
            logger.info(f"Arguments: {function_call_params.arguments}")
            try:
                expr = function_call_params.arguments.get("expression", "")
                result = safe_calculator(expr)
                await function_call_params.result_callback(
                    {"expression": expr, "result": result}
                )
            except Exception as e:
                await function_call_params.result_callback({"error": str(e)})

        # Register timezone functions
        async def get_current_time_func(
            function_call_params: FunctionCallParams,
        ) -> None:
            logger.info("LLM Function Call EXECUTED: get_current_time")
            logger.info(f"Arguments: {function_call_params.arguments}")
            try:
                timezone = function_call_params.arguments.get("timezone", "UTC")
                result = get_current_time(timezone)
                await function_call_params.result_callback(result)
            except Exception as e:
                await function_call_params.result_callback({"error": str(e)})

        async def convert_time_func(function_call_params: FunctionCallParams) -> None:
            logger.info("LLM Function Call EXECUTED: convert_time")
            logger.info(f"Arguments: {function_call_params.arguments}")
            try:
                result = convert_time(
                    function_call_params.arguments.get("source_timezone"),
                    function_call_params.arguments.get("time"),
                    function_call_params.arguments.get("target_timezone"),
                )
                await function_call_params.result_callback(result)
            except Exception as e:
                await function_call_params.result_callback({"error": str(e)})

        # Register all built-in functions
        self.llm.register_function("safe_calculator", calculate_func)
        self.llm.register_function("get_current_time", get_current_time_func)
        self.llm.register_function("convert_time", convert_time_func)

    async def _register_knowledge_base_function(
        self, document_uuids: list[str]
    ) -> None:
        """Register knowledge base retrieval function with the LLM.

        The handler returns the top 3 matching chunks. On failure it returns
        an error payload with an empty ``chunks`` list instead of raising.

        Args:
            document_uuids: List of document UUIDs to filter the search by
        """
        logger.debug(
            f"Registering knowledge base retrieval function with {len(document_uuids)} document(s)"
        )

        async def retrieve_kb_func(function_call_params: FunctionCallParams) -> None:
            logger.info("LLM Function Call EXECUTED: retrieve_from_knowledge_base")
            logger.info(f"Arguments: {function_call_params.arguments}")

            # Bound before the try so the error payload below can always
            # reference it, even if reading the arguments fails.
            query = ""
            try:
                query = function_call_params.arguments.get("query", "")
                organization_id = await self._get_organization_id()

                if not organization_id:
                    raise ValueError(
                        "Organization ID not available for knowledge base retrieval"
                    )

                if not self._embeddings_api_key:
                    raise ValueError(
                        "Embeddings API key not configured. Please set your API key in "
                        "Model Configurations > Embedding."
                    )

                result = await retrieve_from_knowledge_base(
                    query=query,
                    organization_id=organization_id,
                    document_uuids=document_uuids,
                    limit=3,  # Return top 3 most relevant chunks
                    embeddings_api_key=self._embeddings_api_key,
                    embeddings_model=self._embeddings_model,
                )

                await function_call_params.result_callback(result)

            except Exception as e:
                logger.error(f"Knowledge base retrieval failed: {e}")
                await function_call_params.result_callback(
                    {"error": str(e), "chunks": [], "query": query, "total_results": 0}
                )

        # Register the function with the LLM
        self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)

    async def _perform_variable_extraction_if_needed(
        self, previous_node: Optional[Node]
    ) -> None:
        """Start background variable extraction for ``previous_node`` if it has extraction configured.

        Extraction does not block the caller. Results are merged into the
        gathered context when the background task completes, and errors are
        logged.
        """
        if not (
            previous_node
            and previous_node.extraction_enabled
            and previous_node.extraction_variables
        ):
            return

        logger.debug(
            f"Scheduling background variable extraction for node: {previous_node.name}"
        )

        # Capture the current turn context before creating the background task
        parent_context = get_current_turn_context()
        extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
        extraction_variables = previous_node.extraction_variables

        async def _background_extraction():
            try:
                extracted_data = (
                    await self._variable_extraction_manager._perform_extraction(
                        extraction_variables, parent_context, extraction_prompt
                    )
                )
                self._gathered_context.update(extracted_data)
                logger.debug(
                    f"Background variable extraction completed. Extracted: {extracted_data}"
                )
            except Exception as e:
                logger.error(
                    f"Error during background variable extraction: {str(e)}"
                )

        # Fire and forget - extraction happens in background without blocking.
        # Keep a strong reference until the task finishes so it cannot be
        # garbage-collected mid-run.
        task = asyncio.create_task(_background_extraction())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
        """Register ``node``'s tool handlers with the LLM and rewrite the LLM context for it.

        Steps:
        - Register transition handlers (skipped for end nodes).
        - Register custom-tool and knowledge-base handlers.
        - Replace the context's system message and function schemas via
          ``_update_llm_context``.

        The next LLM generation is not triggered directly here.
        """
        # Set node name for tracing
        try:
            self.context.set_node_name(node.name)
        except AttributeError:
            logger.warning("context has no set_node_name method")

        # Register transition functions if not an end node
        if not node.is_end:
            for outgoing_edge in node.out_edges:
                await self._register_transition_function_with_llm(
                    outgoing_edge.get_function_name(), outgoing_edge.target
                )

        # Register custom tool handlers for this node
        if node.tool_uuids and self._custom_tool_manager:
            await self._custom_tool_manager.register_handlers(node.tool_uuids)

        # Register knowledge base retrieval handler if node has documents
        if node.document_uuids:
            await self._register_knowledge_base_function(node.document_uuids)

        # Set up system message and functions
        (
            system_message,
            functions,
        ) = await self._compose_system_message_functions_for_node(node)
        await self._update_llm_context(system_message, functions)

    async def set_node(self, node_id: str):
        """Make ``node_id`` the current node and run its start/end/agent handler.

        The node transition callback, if any, is awaited first with the new
        and previous node names; its failures are only logged.

        Raises:
            KeyError: If ``node_id`` is not in the workflow.
            ValueError: If the node is static (unsupported).
        """
        node = self.workflow.nodes[node_id]

        logger.debug(
            f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
        )

        # Track previous node for transition event
        previous_node_name = self._current_node.name if self._current_node else None

        # Set current node for all nodes (including static ones) so STT mute filter works
        self._current_node = node

        # Send node transition event if callback is provided
        if self._node_transition_callback:
            try:
                await self._node_transition_callback(node.name, previous_node_name)
            except Exception as e:
                # Log but don't fail - feedback is non-critical
                logger.debug(f"Failed to send node transition event: {e}")

        # Handle start nodes
        if node.is_start:
            await self._handle_start_node(node)
        # Handle end nodes
        elif node.is_end:
            await self._handle_end_node(node)
        # Handle normal agent nodes
        else:
            await self._handle_agent_node(node)

    async def _handle_start_node(self, node: Node) -> None:
        """Handle start node execution, sleeping first if delayed start is enabled."""
        # Check if delayed start is enabled
        if node.delayed_start:
            # Use configured duration or default to 2 seconds
            delay_duration = node.delayed_start_duration or 2.0
            logger.debug(
                f"Delayed start enabled - waiting {delay_duration} seconds before speaking"
            )
            await asyncio.sleep(delay_duration)

        if node.is_static:
            raise ValueError("Static nodes are not supported!")

        # Start generation for non-static start node
        await self._setup_llm_context_and_start_generation(node)

    async def _handle_end_node(self, node: Node) -> None:
        """Handle end node execution, extracting its variables immediately if configured."""
        if node.is_static:
            raise ValueError("Static nodes are not supported!")

        await self._setup_llm_context_and_start_generation(node)

        # If this end node has extraction enabled, perform extraction immediately
        if node.extraction_enabled and node.extraction_variables:
            await self._perform_variable_extraction_if_needed(node)

    async def _handle_agent_node(self, node: Node) -> None:
        """Handle agent node execution."""
        if node.is_static:
            raise ValueError("Static nodes are not supported!")

        # Set context and functions for non-static agent node
        await self._setup_llm_context_and_start_generation(node)

    async def send_end_task_frame(
        self,
        reason: str,
        abort_immediately: bool = False,
    ):
        """End the call by queueing an ``EndFrame``, or a ``CancelFrame`` if aborting.

        This is a no-op if the call was already disposed or the client has
        disconnected. Otherwise it:

        1. Maps the call disposition, in order of preference:
           - ``USER_HANGUP``, if the client disconnected meanwhile;
           - the extracted ``call_disposition`` variable;
           - ``reason``.
        2. Fills derived fields (address, full name, ...) into the gathered
           context and builds the transfer context.
        3. For a qualified user on a Stasis connection, starts an immediate
           transfer.
        4. Queues the final frame.

        Args:
            reason: ``EndTaskReason`` value describing why the call ends.
            abort_immediately: Queue ``CancelFrame`` instead of ``EndFrame``
                and skip the Stasis transfer.
        """
        if self._call_disposed or self._client_disconnected:
            # Call is already disposed and client disconnected
            logger.debug(
                f"Not sending EndFrame since call is already disposed: Call Disposed: {self._call_disposed} Client Disconnected: {self._client_disconnected}"
            )
            return

        self._call_disposed = True

        frame_to_push = CancelFrame() if abort_immediately else EndFrame()

        # Customer disposition code using their mapping
        mapped_disposition = ""

        # Apply disposition mapping - first try call_disposition if it was
        # extracted from the call conversation, then fall back to reason
        call_disposition = self._gathered_context.get("call_disposition", "")
        organization_id = await self._get_organization_id()

        # If client is disconnected before we get a chance to disconnect from
        # the bot, lets consider that as final disposition. (The flag can flip
        # during the await above.)
        if self._client_disconnected:
            call_disposition = EndTaskReason.USER_HANGUP.value

        if call_disposition:
            # If call_disposition exists, map it
            mapped_disposition = await apply_disposition_mapping(
                call_disposition, organization_id
            )
            # Store the original and mapped values
            self._gathered_context["extracted_call_disposition"] = call_disposition
            self._gathered_context["call_disposition"] = mapped_disposition
        else:
            # Otherwise, map the disconnect reason
            mapped_disposition = await apply_disposition_mapping(
                reason, organization_id
            )
            # Store the mapped disconnect reason
            self._gathered_context["call_disposition"] = mapped_disposition

        call_vars = self._call_context_vars

        def _var_as_text(key: str) -> str:
            # Missing or None values become "". Non-strings are converted so
            # that str.join below cannot raise TypeError.
            value = call_vars.get(key)
            return "" if value is None else str(value)

        # TODO: Generalise this
        self._gathered_context["address"] = ", ".join(
            _var_as_text(key)
            for key in (
                "address1",
                "address2",
                "address3",
                "city",
                "state",
                "province",
                "postal_code",
            )
        )
        self._gathered_context["full_name"] = " ".join(
            _var_as_text(key) for key in ("first_name", "middle_initial", "last_name")
        )
        self._gathered_context["agent_name"] = "Alex"
        self._gathered_context["customer_phone_number"] = call_vars.get("phone", "")
        self._gathered_context["timezone"] = call_vars.get("province", "")
        self._gathered_context["vendor_id"] = call_vars.get("vendor_lead_code", "")

        decision_maker = self._gathered_context.get("primary_cardholder", False)
        employment_status = self._gathered_context.get("employment_status", "N/A")
        call_transfer_context = {
            "first_name": call_vars.get("first_name", ""),
            "full_name": self._gathered_context.get("full_name", ""),
            "phone": call_vars.get("phone", ""),
            "lead_id": call_vars.get("lead_id"),
            "disposition": mapped_disposition,
            "agent_name": self._gathered_context.get("agent_name", "Alex"),
            "decision_maker": str(decision_maker),
            # Extracted values are not guaranteed to be strings
            "employment": (
                str(employment_status).title() if employment_status else "N/A"
            ),
            "debts": self._gathered_context.get("total_debt", "N/A"),
            "number_of_credit_cards": self._gathered_context.get(
                "number_of_credit_cards", "N/A"
            ),
            "time": self._gathered_context.get("time"),
        }

        logger.debug(
            f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
        )

        # Initiate immediate transfer for Stasis connections when user is qualified
        if (
            reason == EndTaskReason.USER_QUALIFIED.value
            and self._stasis_connection is not None
            and not abort_immediately
        ):
            try:
                logger.info(
                    f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
                )
                await self._stasis_connection.transfer(call_transfer_context)
                logger.info("Immediate transfer initiated successfully")
            except Exception as e:
                logger.error(f"Failed to initiate immediate transfer: {e}")
                # Continue with normal flow even if immediate transfer fails

        if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
            await self.task.queue_frame(
                TTSSpeakFrame(
                    "Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
                )
            )

        # Store the mapped disposition for later retrieval in the event handler
        self._call_disposition = mapped_disposition

        logger.debug(
            f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
        )
        await self.task.queue_frame(frame_to_push)

    async def _compose_system_message_functions_for_node(
        self, node: "Node"
    ) -> tuple[dict, list[dict]]:
        """Generate the system message and function schemas for the given node.

        This performs the same formatting logic used when entering a node but
        does **not** register the functions with the LLM; callers are
        responsible for that.

        Returns:
            ``(system_message, functions)``, where ``system_message`` is a
            ``{"role": "system", "content": ...}`` dict. Its content joins the
            global prompt (if the node opts in) and the node prompt.
            ``functions`` lists, in order:
            - the built-in tool schemas;
            - the knowledge-base schema, if the node has documents;
            - the custom tool schemas;
            - one transition schema per outgoing edge.
        """
        global_prompt = ""
        if self.workflow.global_node_id and node.add_global_prompt:
            global_node = self.workflow.nodes[self.workflow.global_node_id]
            global_prompt = self._format_prompt(global_node.prompt)

        functions: list[dict] = []

        # Add built-in function schemas (calculator and timezone tools)
        functions.extend(self.builtin_function_schemas)

        # Add knowledge base retrieval tool if node has documents
        if node.document_uuids:
            kb_tool_def = get_knowledge_base_tool(node.document_uuids)
            kb_schema = get_function_schema(
                kb_tool_def["function"]["name"],
                kb_tool_def["function"]["description"],
                properties=kb_tool_def["function"]["parameters"].get("properties", {}),
                required=kb_tool_def["function"]["parameters"].get("required", []),
            )
            functions.append(kb_schema)

        # Add custom tools from node.tool_uuids
        if node.tool_uuids and self._custom_tool_manager:
            custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
                node.tool_uuids
            )
            functions.extend(custom_tool_schemas)

        # Transition functions (schema only; registration handled elsewhere)
        for outgoing_edge in node.out_edges:
            function_schema = self._get_function_schema(
                outgoing_edge.get_function_name(), outgoing_edge.condition
            )
            functions.append(function_schema)

        formatted_node_prompt = self._format_prompt(node.prompt)

        system_message = {
            "role": "system",
            "content": "\n\n".join(
                p for p in (global_prompt, formatted_node_prompt) if p
            ),
        }

        return system_message, functions

    def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
        """
        This callback is called by STTMuteFilter to determine if the STT should be muted.
        """
        return engine_callbacks.create_should_mute_callback(self)

    def create_user_idle_handler(self):
        """
        Returns a UserIdleHandler that manages user-idle timeouts with state.
        The handler tracks retry count and handles escalating prompts.
        """
        return engine_callbacks.create_user_idle_handler(self)

    def create_max_duration_callback(self):
        """
        This callback is called when the call duration exceeds the max duration.
        We use this to end the call via send_end_task_frame.
        """
        return engine_callbacks.create_max_duration_callback(self)

    def create_generation_started_callback(self):
        """
        This callback is called when a new generation starts.
        This is used to reset the flags that control the flow of the engine.
        """
        return engine_callbacks.create_generation_started_callback(self)

    def create_aggregation_correction_callback(self) -> Callable[[str], str]:
        """Create a callback that corrects corrupted aggregation using reference text."""
        return engine_callbacks.create_aggregation_correction_callback(self)

    def set_context(self, context: LLMContext) -> None:
        """Set the LLM context.

        This allows setting the context after the engine has been created,
        which is useful when the context needs to be created after the engine.
        """
        self.context = context

    def set_task(self, task: PipelineTask) -> None:
        """Set the pipeline task.

        This allows setting the task after the engine has been created,
        which is useful when the task needs to be created after the engine.
        """
        self.task = task

    def set_stasis_connection(
        self, connection: Optional["StasisRTPConnection"]
    ) -> None:
        """Set the Stasis RTP connection for immediate transfers.

        This allows the engine to initiate transfers immediately when XFER
        disposition is detected, without waiting for pipeline shutdown.

        Args:
            connection: The StasisRTPConnection instance, or None for non-Stasis transports
        """
        self._stasis_connection = connection
        if connection:
            logger.debug(
                f"Stasis connection set for immediate transfers: {connection.channel_id}"
            )

    async def handle_llm_text_frame(self, text: str):
        """Accumulate LLM text frames to build reference text."""
        self._current_llm_generation_reference_text += text

    def handle_client_disconnected(self):
        """Record that the client disconnected; later ``send_end_task_frame`` calls become no-ops."""
        self._client_disconnected = True

    def is_call_disposed(self):
        """Check whether a call has been disposed by the engine"""
        return self._call_disposed

    async def get_call_disposition(self) -> Optional[str]:
        """Return the final call disposition.

        Returns, in order of preference:
        - the mapped disposition stored by ``send_end_task_frame``;
        - ``USER_HANGUP``, if the client disconnected;
        - ``UNKNOWN``.
        """
        if self._call_disposition:
            # We would have a _call_disposition variable set if we have initiated
            # a disconnect from the bot, i.e we have called send_end_task_frame.
            return self._call_disposition

        if self._client_disconnected:
            return EndTaskReason.USER_HANGUP.value
        return EndTaskReason.UNKNOWN.value

    async def get_gathered_context(self) -> dict:
        """Get a shallow copy of the gathered context, including extracted variables."""
        return self._gathered_context.copy()

    async def cleanup(self):
        """Clean up engine resources on disconnect."""
        # Cancel any pending timeout tasks
        if (
            self._user_response_timeout_task
            and not self._user_response_timeout_task.done()
        ):
            self._user_response_timeout_task.cancel()
|