mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
chore: refactor and add tests (#130)
* chore: add tests for end call * Update pipecat module * fix: allow interruptions from deepgram flux * Add VadUserTurnStrategy * chore: add test for voicemail detection
This commit is contained in:
parent
2aedb839ff
commit
033fde8946
15 changed files with 2106 additions and 542 deletions
|
|
@ -9,7 +9,6 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
|
@ -19,6 +18,7 @@ from pipecat.utils.enums import EndTaskReason
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
|
@ -49,7 +49,6 @@ from api.services.workflow.tools.timezone import (
|
|||
get_current_time,
|
||||
get_time_tools,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
||||
from pipecat.utils.tracing.context_registry import get_current_turn_context
|
||||
|
||||
|
||||
|
|
@ -79,12 +78,10 @@ class PipecatEngine:
|
|||
self._workflow_run_id = workflow_run_id
|
||||
self._node_transition_callback = node_transition_callback
|
||||
self._initialized = False
|
||||
self._client_disconnected = False
|
||||
self._call_disposed = False
|
||||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
self._user_response_timeout_task: Optional[asyncio.Task] = None
|
||||
self._call_disposition: Optional[str] = None
|
||||
|
||||
# Stasis connection for immediate transfers
|
||||
self._stasis_connection: Optional["StasisRTPConnection"] = None
|
||||
|
|
@ -99,6 +96,9 @@ class PipecatEngine:
|
|||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_generation_reference_text: str = ""
|
||||
|
||||
# Controls whether user input should be muted
|
||||
self._mute_pipeline: bool = False
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
||||
|
|
@ -215,9 +215,14 @@ class PipecatEngine:
|
|||
This way, when we do set_node from within this function, and go for LLM completion with updated
|
||||
system prompts, the context is updated with function call result.
|
||||
"""
|
||||
# FIXME: There is a potential race condition, when we generate LLM Completion from UserContextAggregator
|
||||
# with FunctionCallResultFrame and we call end_call_with_reason where we queue EndFrame or CancelFrame.
|
||||
# If EndFrame reaches the LLM Processor before the ContextFrame, we might never run generation which
|
||||
# might be intended
|
||||
|
||||
# Queue EndFrame if we just transitioned to EndNode
|
||||
if self._current_node.is_end:
|
||||
await self.send_end_task_frame(
|
||||
await self.end_call_with_reason(
|
||||
EndTaskReason.USER_QUALIFIED.value
|
||||
)
|
||||
|
||||
|
|
@ -356,44 +361,52 @@ class PipecatEngine:
|
|||
self.llm.register_function("retrieve_from_knowledge_base", retrieve_kb_func)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
self, node: Optional[Node], run_in_background: bool = True
|
||||
) -> None:
|
||||
"""Perform variable extraction if the previous node had extraction enabled."""
|
||||
if (
|
||||
previous_node
|
||||
and previous_node.extraction_enabled
|
||||
and previous_node.extraction_variables
|
||||
):
|
||||
"""Perform variable extraction if the node has extraction enabled.
|
||||
|
||||
Args:
|
||||
node: The node to extract variables from.
|
||||
run_in_background: If True, runs extraction as a fire-and-forget task.
|
||||
If False, awaits the extraction synchronously.
|
||||
"""
|
||||
if not (node and node.extraction_enabled and node.extraction_variables):
|
||||
return
|
||||
|
||||
# Capture the current turn context for otel tracing
|
||||
# before creating the background task
|
||||
parent_context = get_current_turn_context()
|
||||
|
||||
extraction_prompt = self._format_prompt(node.extraction_prompt)
|
||||
extraction_variables = node.extraction_variables
|
||||
|
||||
async def _do_extraction():
|
||||
try:
|
||||
extracted_data = (
|
||||
await self._variable_extraction_manager._perform_extraction(
|
||||
extraction_variables, parent_context, extraction_prompt
|
||||
)
|
||||
)
|
||||
self._gathered_context.update(extracted_data)
|
||||
logger.debug(
|
||||
f"Variable extraction completed. Extracted: {extracted_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during variable extraction: {str(e)}")
|
||||
|
||||
if run_in_background:
|
||||
logger.debug(
|
||||
f"Scheduling background variable extraction for node: {previous_node.name}"
|
||||
f"Scheduling background variable extraction for node: {node.name}"
|
||||
)
|
||||
asyncio.create_task(_do_extraction())
|
||||
else:
|
||||
logger.debug(
|
||||
f"Performing synchronous variable extraction for node: {node.name}"
|
||||
)
|
||||
await _do_extraction()
|
||||
|
||||
# Capture the current turn context before creating the background task
|
||||
parent_context = get_current_turn_context()
|
||||
extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
|
||||
extraction_variables = previous_node.extraction_variables
|
||||
|
||||
async def _background_extraction():
|
||||
try:
|
||||
extracted_data = (
|
||||
await self._variable_extraction_manager._perform_extraction(
|
||||
extraction_variables, parent_context, extraction_prompt
|
||||
)
|
||||
)
|
||||
self._gathered_context.update(extracted_data)
|
||||
logger.debug(
|
||||
f"Background variable extraction completed. Extracted: {extracted_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during background variable extraction: {str(e)}"
|
||||
)
|
||||
|
||||
# Fire and forget - extraction happens in background without blocking
|
||||
asyncio.create_task(_background_extraction())
|
||||
|
||||
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
|
||||
"""Common method to set up LLM context and queue context frame for non-static nodes."""
|
||||
async def _setup_llm_context(self, node: Node) -> None:
|
||||
"""Common method to set up LLM context"""
|
||||
# Set node name for tracing
|
||||
try:
|
||||
self.context.set_node_name(node.name)
|
||||
|
|
@ -470,61 +483,54 @@ class PipecatEngine:
|
|||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Start generation for non-static start node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def _handle_end_node(self, node: Node) -> None:
|
||||
"""Handle end node execution."""
|
||||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
# If this end node has extraction enabled, perform extraction immediately
|
||||
if node.extraction_enabled and node.extraction_variables:
|
||||
await self._perform_variable_extraction_if_needed(node)
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def _handle_agent_node(self, node: Node) -> None:
|
||||
"""Handle agent node execution."""
|
||||
if node.is_static:
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Set context and functions for non-static agent node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
async def send_end_task_frame(
|
||||
async def end_call_with_reason(
|
||||
self,
|
||||
reason: str,
|
||||
abort_immediately: bool = False,
|
||||
):
|
||||
"""
|
||||
Centralized method to send EndTaskFrame with metadata including
|
||||
call_transfer_context and call_context_vars
|
||||
Centralized method to end the call with disposition mapping
|
||||
"""
|
||||
if self._call_disposed or self._client_disconnected:
|
||||
# Call is already disposed and client disconnected
|
||||
logger.debug(
|
||||
f"Not sending EndFrame since call is already disposed: Call Disposed: {self._call_disposed} Client Disconnected: {self._client_disconnected}"
|
||||
)
|
||||
if self._call_disposed:
|
||||
logger.debug(f"Call already Disposed: {self._call_disposed}")
|
||||
return
|
||||
|
||||
self._call_disposed = True
|
||||
|
||||
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
|
||||
# Mute the pipeline
|
||||
self._mute_pipeline = True
|
||||
|
||||
# Customer disposition code using their mapping
|
||||
mapped_disposition = ""
|
||||
# Perform final variable extraction synchronously before ending
|
||||
await self._perform_variable_extraction_if_needed(
|
||||
self._current_node, run_in_background=False
|
||||
)
|
||||
|
||||
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
|
||||
|
||||
# Apply disposition mapping - first try call_disposition if it is,
|
||||
# extracted from the call conversation then fall back to reason
|
||||
call_disposition = self._gathered_context.get("call_disposition", "")
|
||||
organization_id = await self._get_organization_id()
|
||||
|
||||
# If client is disconnected before we get a chance to disconnect from
|
||||
# the bot, lets consider that as final disposition
|
||||
if self._client_disconnected:
|
||||
call_disposition = EndTaskReason.USER_HANGUP.value
|
||||
|
||||
if call_disposition:
|
||||
# If call_disposition exists, map it
|
||||
mapped_disposition = await apply_disposition_mapping(
|
||||
|
|
@ -532,90 +538,16 @@ class PipecatEngine:
|
|||
)
|
||||
# Store the original and mapped values
|
||||
self._gathered_context["extracted_call_disposition"] = call_disposition
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
self._gathered_context["call_disposition"] = call_disposition
|
||||
self._gathered_context["mapped_call_disposition"] = mapped_disposition
|
||||
else:
|
||||
# Otherwise, map the disconnect reason
|
||||
mapped_disposition = await apply_disposition_mapping(
|
||||
reason, organization_id
|
||||
)
|
||||
# Store the mapped disconnect reason
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
|
||||
# TODO: Generalise this
|
||||
self._gathered_context["address"] = ", ".join(
|
||||
[
|
||||
self._call_context_vars.get("address1", ""),
|
||||
self._call_context_vars.get("address2", ""),
|
||||
self._call_context_vars.get("address3", ""),
|
||||
self._call_context_vars.get("city", ""),
|
||||
self._call_context_vars.get("state", ""),
|
||||
self._call_context_vars.get("province", ""),
|
||||
self._call_context_vars.get("postal_code", ""),
|
||||
]
|
||||
)
|
||||
self._gathered_context["full_name"] = " ".join(
|
||||
[
|
||||
self._call_context_vars.get("first_name", ""),
|
||||
self._call_context_vars.get("middle_initial", ""),
|
||||
self._call_context_vars.get("last_name", ""),
|
||||
]
|
||||
)
|
||||
self._gathered_context["agent_name"] = "Alex"
|
||||
self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
|
||||
"phone", ""
|
||||
)
|
||||
self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
|
||||
self._gathered_context["vendor_id"] = self._call_context_vars.get(
|
||||
"vendor_lead_code", ""
|
||||
)
|
||||
|
||||
decision_maker = self._gathered_context.get("primary_cardholder", False)
|
||||
employment_status = self._gathered_context.get("employment_status", "N/A")
|
||||
call_transfer_context = {
|
||||
"first_name": self._call_context_vars.get("first_name", ""),
|
||||
"full_name": self._gathered_context.get("full_name", ""),
|
||||
"phone": self._call_context_vars.get("phone", ""),
|
||||
"lead_id": self._call_context_vars.get("lead_id"),
|
||||
"disposition": mapped_disposition,
|
||||
"agent_name": self._gathered_context.get("agent_name", "Alex"),
|
||||
"decision_maker": str(decision_maker),
|
||||
"employment": employment_status.title() if employment_status else "N/A",
|
||||
"debts": self._gathered_context.get("total_debt", "N/A"),
|
||||
"number_of_credit_cards": self._gathered_context.get(
|
||||
"number_of_credit_cards", "N/A"
|
||||
),
|
||||
"time": self._gathered_context.get("time"),
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
|
||||
)
|
||||
|
||||
# Initiate immediate transfer for Stasis connections when user is qualified
|
||||
if (
|
||||
reason == EndTaskReason.USER_QUALIFIED.value
|
||||
and self._stasis_connection is not None
|
||||
and not abort_immediately
|
||||
):
|
||||
try:
|
||||
logger.info(
|
||||
f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
|
||||
)
|
||||
await self._stasis_connection.transfer(call_transfer_context)
|
||||
logger.info("Immediate transfer initiated successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate immediate transfer: {e}")
|
||||
# Continue with normal flow even if immediate transfer fails
|
||||
|
||||
if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
|
||||
await self.task.queue_frame(
|
||||
TTSSpeakFrame(
|
||||
"Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
|
||||
)
|
||||
)
|
||||
|
||||
# Store the original reason for later retrieval in event handler
|
||||
self._call_disposition = mapped_disposition
|
||||
self._gathered_context["call_disposition"] = reason
|
||||
self._gathered_context["mapped_call_disposition"] = mapped_disposition
|
||||
|
||||
logger.debug(
|
||||
f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
|
||||
|
|
@ -678,11 +610,14 @@ class PipecatEngine:
|
|||
|
||||
return system_message, functions
|
||||
|
||||
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
|
||||
async def should_mute_user(self, frame: "Frame") -> bool:
|
||||
"""
|
||||
This callback is called by STTMuteFilter to determine if the STT should be muted.
|
||||
Callback for CallbackUserMuteStrategy to determine if the user should be muted.
|
||||
|
||||
Returns:
|
||||
True if the user should be muted, False otherwise.
|
||||
"""
|
||||
return engine_callbacks.create_should_mute_callback(self)
|
||||
return self._mute_pipeline
|
||||
|
||||
def create_user_idle_handler(self):
|
||||
"""
|
||||
|
|
@ -746,26 +681,10 @@ class PipecatEngine:
|
|||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_generation_reference_text += text
|
||||
|
||||
def handle_client_disconnected(self):
|
||||
"""Handle client disconnected event."""
|
||||
self._client_disconnected = True
|
||||
|
||||
def is_call_disposed(self):
|
||||
"""Check whether a call has been disposed by the engine"""
|
||||
return self._call_disposed
|
||||
|
||||
async def get_call_disposition(self) -> Optional[str]:
|
||||
"""Get the disconnect reason set by the engine."""
|
||||
if self._call_disposition:
|
||||
# We would have a _call_disposition variable set if we have initiated
|
||||
# a disconnect from the bot, i.e we have called send_end_task_frame.
|
||||
return self._call_disposition
|
||||
|
||||
if self._client_disconnected:
|
||||
return EndTaskReason.USER_HANGUP.value
|
||||
else:
|
||||
return EndTaskReason.UNKNOWN.value
|
||||
|
||||
async def get_gathered_context(self) -> dict:
|
||||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue