mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
0
api/services/workflow/__init__.py
Normal file
0
api/services/workflow/__init__.py
Normal file
77
api/services/workflow/disposition_mapper.py
Normal file
77
api/services/workflow/disposition_mapper.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Utility module for applying disposition code mapping."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
|
||||
|
||||
async def apply_disposition_mapping(value: str, organization_id: Optional[int]) -> str:
|
||||
"""Apply disposition code mapping if configured.
|
||||
|
||||
Args:
|
||||
value: The original disposition value to map
|
||||
organization_id: The organization ID
|
||||
|
||||
Returns:
|
||||
The mapped value if found in configuration, otherwise the original value
|
||||
"""
|
||||
if not organization_id or not value:
|
||||
return value
|
||||
|
||||
try:
|
||||
disposition_mapping = await db_client.get_configuration_value(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.DISPOSITION_CODE_MAPPING.value,
|
||||
default={},
|
||||
)
|
||||
|
||||
if not disposition_mapping:
|
||||
return value
|
||||
|
||||
# Return mapped value if exists, otherwise original
|
||||
# DISPOSITION_CODE_MAPPING looks like {"user_idle_max_duration_exceeded": "DAIR"} etc.
|
||||
mapped_value = disposition_mapping.get(value, value)
|
||||
|
||||
if mapped_value != value:
|
||||
logger.debug(
|
||||
f"Mapped disposition code from '{value}' to '{mapped_value}' "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
|
||||
return mapped_value
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying disposition mapping: {e}")
|
||||
return value
|
||||
|
||||
|
||||
async def get_organization_id_from_workflow_run(
|
||||
workflow_run_id: Optional[int],
|
||||
) -> Optional[int]:
|
||||
"""Get organization_id from workflow_run_id through the model relationships.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
The organization ID if found, otherwise None
|
||||
"""
|
||||
if not workflow_run_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run or not workflow_run.workflow:
|
||||
return None
|
||||
|
||||
workflow = workflow_run.workflow
|
||||
if not workflow.user:
|
||||
return None
|
||||
|
||||
return workflow.user.selected_organization_id
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization_id from workflow_run: {e}")
|
||||
return None
|
||||
96
api/services/workflow/dto.py
Normal file
96
api/services/workflow/dto.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||||
|
||||
|
||||
class NodeType(str, Enum):
|
||||
startNode = "startCall"
|
||||
endNode = "endCall"
|
||||
agentNode = "agentNode"
|
||||
globalNode = "globalNode"
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
|
||||
|
||||
class VariableType(str, Enum):
|
||||
string = "string"
|
||||
number = "number"
|
||||
boolean = "boolean"
|
||||
|
||||
|
||||
class ExtractionVariableDTO(BaseModel):
|
||||
name: str = Field(..., min_length=1)
|
||||
type: VariableType
|
||||
prompt: Optional[str] = None
|
||||
|
||||
|
||||
class NodeDataDTO(BaseModel):
|
||||
name: str = Field(..., min_length=1)
|
||||
prompt: str = Field(..., min_length=1)
|
||||
is_static: bool = False
|
||||
is_start: bool = False
|
||||
is_end: bool = False
|
||||
allow_interrupt: bool = False
|
||||
extraction_enabled: bool = False
|
||||
extraction_prompt: Optional[str] = None
|
||||
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
|
||||
add_global_prompt: bool = True
|
||||
wait_for_user_response: bool = False
|
||||
wait_for_user_response_timeout: Optional[float] = None
|
||||
detect_voicemail: bool = True
|
||||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
|
||||
|
||||
class RFNodeDTO(BaseModel):
|
||||
id: str
|
||||
type: NodeType = Field(default=NodeType.agentNode)
|
||||
position: Position
|
||||
data: NodeDataDTO
|
||||
|
||||
|
||||
class EdgeDataDTO(BaseModel):
|
||||
label: str = Field(..., min_length=1)
|
||||
condition: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class RFEdgeDTO(BaseModel):
|
||||
id: str
|
||||
source: str
|
||||
target: str
|
||||
data: EdgeDataDTO
|
||||
|
||||
|
||||
class ReactFlowDTO(BaseModel):
|
||||
nodes: List[RFNodeDTO]
|
||||
edges: List[RFEdgeDTO]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _referential_integrity(self):
|
||||
node_ids = {n.id for n in self.nodes}
|
||||
line_errors: list[dict[str, str]] = []
|
||||
|
||||
for idx, edge in enumerate(self.edges):
|
||||
for endpoint in (edge.source, edge.target):
|
||||
if endpoint not in node_ids:
|
||||
line_errors.append(
|
||||
dict(
|
||||
loc=("edges", idx),
|
||||
type="missing_node",
|
||||
msg="Edge references missing node",
|
||||
input=edge.model_dump(mode="python"),
|
||||
ctx={"edge_id": edge.id, "endpoint": endpoint},
|
||||
)
|
||||
)
|
||||
|
||||
if line_errors:
|
||||
raise ValidationError.from_exception_data(
|
||||
title="ReactFlowDTO validation failed",
|
||||
line_errors=line_errors,
|
||||
)
|
||||
|
||||
return self
|
||||
16
api/services/workflow/errors.py
Normal file
16
api/services/workflow/errors.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# api/services/workflow/errors.py
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class ItemKind(str, Enum):
|
||||
node = "node"
|
||||
edge = "edge"
|
||||
workflow = "workflow"
|
||||
|
||||
|
||||
class WorkflowError(TypedDict):
|
||||
kind: ItemKind # "node" | "edge"
|
||||
id: str | None # nodeId or edgeId
|
||||
field: str | None # “data.prompt”, “position.x”, … (optional)
|
||||
message: str # human-readable text
|
||||
939
api/services/workflow/pipecat_engine.py
Normal file
939
api/services/workflow/pipecat_engine.py
Normal file
|
|
@ -0,0 +1,939 @@
|
|||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.constants import VOICEMAIL_RECORDING_DURATION
|
||||
from api.services.gender.gender_service import GenderService
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_voicemail_detector import (
|
||||
VoicemailDetector,
|
||||
)
|
||||
from api.services.workflow.workflow import Node, WorkflowGraph
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
|
||||
LLMService = Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService]
|
||||
|
||||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
||||
from pipecat.utils.tracing.context_registry import get_current_turn_context
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
render_template,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.timezone import (
|
||||
convert_time,
|
||||
get_current_time,
|
||||
get_time_tools,
|
||||
)
|
||||
|
||||
|
||||
class PipecatEngine:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
task: Optional[PipelineTask] = None,
|
||||
llm: Optional["LLMService"] = None,
|
||||
context: Optional[OpenAILLMContext] = None,
|
||||
tts: Optional[Any] = None,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
workflow: WorkflowGraph,
|
||||
call_context_vars: dict,
|
||||
audio_buffer: Optional["AudioBuffer"] = None,
|
||||
workflow_run_id: Optional[int] = None,
|
||||
):
|
||||
self.task = task
|
||||
self.llm = llm
|
||||
self.context = context
|
||||
self.tts = tts
|
||||
self.transport = transport
|
||||
self.workflow = workflow
|
||||
self._call_context_vars = call_context_vars
|
||||
self._audio_buffer = audio_buffer
|
||||
self._workflow_run_id = workflow_run_id
|
||||
self._initialized = False
|
||||
self._pending_function_calls = 0
|
||||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
self._user_response_timeout_task: Optional[asyncio.Task] = None
|
||||
self._call_disposition: Optional[str] = None
|
||||
|
||||
# Stasis connection for immediate transfers
|
||||
self._stasis_connection: Optional["StasisRTPConnection"] = None
|
||||
|
||||
# Will be set later in initialize() when we have
|
||||
# access to _context
|
||||
self._variable_extraction_manager = None
|
||||
|
||||
self._gender_service = GenderService(confidence_threshold=0.5)
|
||||
|
||||
# Voicemail detection state
|
||||
self._detect_voicemail = False
|
||||
self._voicemail_detector = None
|
||||
self._voicemail_detection_task: Optional[asyncio.Task] = None
|
||||
|
||||
# This transition is generated by the llm as part of tool call. This can
|
||||
# also be accompanied with some content which can be played using TTS. If the
|
||||
# bot is interrupted, we would cancel this transition (we do cancel this currently when
|
||||
# the next generation starts in handle_generation_started callback handler.)
|
||||
self._pending_generated_transition_after_context_push: Optional[
|
||||
Callable[[], Awaitable[None]]
|
||||
] = None
|
||||
|
||||
# This is the transtion which is typically programmatic transition, and not goes as
|
||||
# tool call to LLM. This is not interrupted by the user and is done on context push
|
||||
self._pending_control_transition_after_context_push: Optional[
|
||||
Callable[[], Awaitable[None]]
|
||||
] = None
|
||||
|
||||
# Flag to determine if the current llm generation has a text completion
|
||||
self._defer_context_push: bool = False
|
||||
|
||||
# Lazy loaded built-in function schemas
|
||||
self._builtin_function_schemas: Optional[list[dict]] = None
|
||||
|
||||
# Flag to control whether to queue context frame
|
||||
self._queue_context_frame: bool = True
|
||||
|
||||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
|
||||
@property
|
||||
def builtin_function_schemas(self) -> list[dict]:
|
||||
"""Get built-in function schemas (calculator and timezone tools)."""
|
||||
if self._builtin_function_schemas is None:
|
||||
self._builtin_function_schemas = []
|
||||
|
||||
# Transform calculator tools to get_function_schema format
|
||||
for tool in get_calculator_tools():
|
||||
func = tool["function"]
|
||||
schema = get_function_schema(
|
||||
func["name"],
|
||||
func["description"],
|
||||
properties=func["parameters"]["properties"],
|
||||
required=func["parameters"]["required"],
|
||||
)
|
||||
self._builtin_function_schemas.append(schema)
|
||||
|
||||
# Transform timezone tools to get_function_schema format
|
||||
for tool in get_time_tools():
|
||||
func = tool["function"]
|
||||
schema = get_function_schema(
|
||||
func["name"],
|
||||
func["description"],
|
||||
properties=func["parameters"]["properties"],
|
||||
required=func["parameters"]["required"],
|
||||
)
|
||||
self._builtin_function_schemas.append(schema)
|
||||
|
||||
return self._builtin_function_schemas
|
||||
|
||||
async def initialize(self):
|
||||
# TODO: May be set_node in a separate task so that we return from initialize immediately
|
||||
if self._initialized:
|
||||
logger.warning(f"{self.__class__.__name__} already initialized")
|
||||
return
|
||||
try:
|
||||
self._initialized = True
|
||||
|
||||
# Helper that encapsulates variable extraction logic
|
||||
self._variable_extraction_manager = VariableExtractionManager(self)
|
||||
|
||||
# Add current time in EST (America/New_York) to gathered context
|
||||
try:
|
||||
est_time_result = get_current_time("America/New_York")
|
||||
# The get_current_time utility returns a dict with 'datetime' field
|
||||
# Store the ISO formatted datetime string under the key 'time'
|
||||
self._gathered_context["time"] = est_time_result.get("datetime")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch current EST time: {e}")
|
||||
|
||||
# Register built-in functions with the LLM
|
||||
await self._register_builtin_functions()
|
||||
|
||||
# Set gender in initial context predicted from first name
|
||||
if "first_name" in self._call_context_vars:
|
||||
salutation = await self._gender_service.get_salutation(
|
||||
self._call_context_vars["first_name"]
|
||||
)
|
||||
self._call_context_vars["salutation"] = salutation
|
||||
|
||||
await self.set_node(self.workflow.start_node_id)
|
||||
logger.debug(f"{self.__class__.__name__} initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
|
||||
raise
|
||||
|
||||
def _get_function_schema(self, function_name: str, description: str):
|
||||
"""Thin wrapper around utils.get_function_schema for backwards compatibility."""
|
||||
|
||||
return get_function_schema(function_name, description)
|
||||
|
||||
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
|
||||
"""Delegate context update to the shared workflow.utils implementation."""
|
||||
|
||||
update_llm_context(self.context, system_message, functions)
|
||||
|
||||
def _format_prompt(self, prompt: str) -> str:
|
||||
"""Delegate prompt formatting to the shared workflow.utils implementation."""
|
||||
|
||||
return render_template(prompt, self._call_context_vars)
|
||||
|
||||
async def _create_transition_func(self, name: str, transition_to_node: str):
|
||||
async def transition_func(function_call_params: FunctionCallParams) -> None:
|
||||
"""Inner function that handles the actual tool invocation."""
|
||||
try:
|
||||
# Track pending function call
|
||||
self._pending_function_calls += 1
|
||||
logger.debug(
|
||||
f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
|
||||
)
|
||||
|
||||
# For edge functions, prevent LLM completion until transition (run_llm=False)
|
||||
# For node functions, allow immediate completion (run_llm=True)
|
||||
async def on_context_updated() -> None:
|
||||
"""
|
||||
Framework will run this function after the function call result has been updated in the context.
|
||||
This way, when we do set_node from within this function, and go for LLM completion with updated
|
||||
system prompts, the context is updated with function call result.
|
||||
"""
|
||||
self._pending_function_calls -= 1
|
||||
# Perform variable extraction before transitioning to new node
|
||||
await self._perform_variable_extraction_if_needed(
|
||||
self._current_node
|
||||
)
|
||||
await self.set_node(transition_to_node)
|
||||
|
||||
result = {"status": "done"}
|
||||
|
||||
properties = FunctionCallResultProperties(
|
||||
run_llm=False,
|
||||
on_context_updated=on_context_updated,
|
||||
)
|
||||
|
||||
async def _invoke_result_callback():
|
||||
"""
|
||||
Functions are executed immediately when they come from LLM as part of text completion.
|
||||
But, if the LLM completion also has some text, we would want to not call the function if the user interrupts the speech.
|
||||
We would also not want the function to be added to context, so that the LLM can call the function again. Hence, we
|
||||
defer the function invocation until we receive on_context_updated callback, i.e the bot has finished speaking
|
||||
the text that was generated.
|
||||
"""
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
|
||||
if self._defer_context_push:
|
||||
"""
|
||||
We set the flag to _defer_context_push when we receive text in the current generation from LLM.
|
||||
This is set in the handle_llm_generated_text callback handler.
|
||||
"""
|
||||
logger.debug(
|
||||
"Deferring transition function result until context push"
|
||||
)
|
||||
# Only one deferred transition should exist at any time.
|
||||
# Overwrite if one is somehow already set (unexpected).
|
||||
self._pending_generated_transition_after_context_push = (
|
||||
_invoke_result_callback
|
||||
)
|
||||
else:
|
||||
"""
|
||||
If there was no text in the current generation, and we only had function call,
|
||||
lets invoke the result callback, so that framework can call on_context_updated and
|
||||
we can do switch node.
|
||||
"""
|
||||
await _invoke_result_callback()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in transition function {name}: {str(e)}")
|
||||
self._pending_function_calls = 0
|
||||
error_result = {"status": "error", "error": str(e)}
|
||||
await function_call_params.result_callback(error_result)
|
||||
|
||||
return transition_func
|
||||
|
||||
async def _register_transition_function_with_llm(
|
||||
self, name: str, transition_to_node: str
|
||||
):
|
||||
logger.debug(
|
||||
f"Registering function {name} to transition to node {transition_to_node} with LLM"
|
||||
)
|
||||
|
||||
# Create transition function
|
||||
transition_func = await self._create_transition_func(name, transition_to_node)
|
||||
|
||||
# Register function with LLM
|
||||
self.llm.register_function(
|
||||
name,
|
||||
transition_func,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
|
||||
async def _register_builtin_functions(self):
|
||||
"""Register built-in functions (calculator and timezone) with the LLM."""
|
||||
logger.debug("Registering built-in functions with LLM")
|
||||
|
||||
properties = FunctionCallResultProperties(run_llm=True)
|
||||
|
||||
# Register calculator function
|
||||
async def calculate_func(function_call_params: FunctionCallParams) -> None:
|
||||
try:
|
||||
expr = function_call_params.arguments.get("expression", "")
|
||||
result = safe_calculator(expr)
|
||||
await function_call_params.result_callback(
|
||||
{"expression": expr, "result": result}, properties=properties
|
||||
)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
|
||||
# Register timezone functions
|
||||
async def get_current_time_func(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
try:
|
||||
timezone = function_call_params.arguments.get("timezone", "UTC")
|
||||
result = get_current_time(timezone)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
|
||||
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
|
||||
try:
|
||||
result = convert_time(
|
||||
function_call_params.arguments.get("source_timezone"),
|
||||
function_call_params.arguments.get("time"),
|
||||
function_call_params.arguments.get("target_timezone"),
|
||||
)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
|
||||
# Register all built-in functions
|
||||
self.llm.register_function("safe_calculator", calculate_func)
|
||||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _queue_tts_response(self, text: str) -> None:
|
||||
"""Queue TTS frames for static text response."""
|
||||
await self.task.queue_frames(
|
||||
[
|
||||
LLMFullResponseStartFrame(),
|
||||
TTSSpeakFrame(text=text),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
async def _setup_static_start_node_transition(self, node: Node) -> None:
|
||||
"""Set up the deferred transition for static start nodes."""
|
||||
if not node.out_edges:
|
||||
return
|
||||
|
||||
next_node_id = node.out_edges[0].target
|
||||
|
||||
if not node.wait_for_user_response:
|
||||
# Normal static start node - transition immediately after context push
|
||||
async def _deferred_static_transition():
|
||||
try:
|
||||
await self.set_node(next_node_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error executing deferred static node transition to {next_node_id}: {exc}"
|
||||
)
|
||||
|
||||
self._pending_control_transition_after_context_push = (
|
||||
_deferred_static_transition
|
||||
)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
"""Perform variable extraction if the previous node had extraction enabled."""
|
||||
if (
|
||||
previous_node
|
||||
and previous_node.extraction_enabled
|
||||
and previous_node.extraction_variables
|
||||
):
|
||||
logger.debug(
|
||||
f"Scheduling background variable extraction for node: {previous_node.name}"
|
||||
)
|
||||
|
||||
# Capture the current turn context before creating the background task
|
||||
parent_context = get_current_turn_context()
|
||||
extraction_prompt = self._format_prompt(previous_node.extraction_prompt)
|
||||
extraction_variables = previous_node.extraction_variables
|
||||
|
||||
async def _background_extraction():
|
||||
try:
|
||||
extracted_data = (
|
||||
await self._variable_extraction_manager._perform_extraction(
|
||||
extraction_variables, parent_context, extraction_prompt
|
||||
)
|
||||
)
|
||||
self._gathered_context.update(extracted_data)
|
||||
logger.debug(
|
||||
f"Background variable extraction completed. Extracted: {extracted_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during background variable extraction: {str(e)}"
|
||||
)
|
||||
|
||||
# Fire and forget - extraction happens in background without blocking
|
||||
asyncio.create_task(_background_extraction())
|
||||
|
||||
async def _setup_llm_context_and_start_generation(self, node: Node) -> None:
|
||||
"""Common method to set up LLM context and queue context frame for non-static nodes."""
|
||||
# Set node name for tracing
|
||||
try:
|
||||
self.context.set_node_name(node.name)
|
||||
except AttributeError:
|
||||
logger.warning(f"context has no set_node_name method")
|
||||
|
||||
# Register transition functions if not an end node
|
||||
if not node.is_end:
|
||||
for outgoing_edge in node.out_edges:
|
||||
await self._register_transition_function_with_llm(
|
||||
outgoing_edge.get_function_name(), outgoing_edge.target
|
||||
)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
|
||||
# Queue context frame if needed
|
||||
if self._queue_context_frame:
|
||||
await self.task.queue_frame(OpenAILLMContextFrame(self.context))
|
||||
else:
|
||||
logger.debug(
|
||||
f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
|
||||
)
|
||||
|
||||
# Reset _queue_context_frame as default behavior
|
||||
self._queue_context_frame = True
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
Simplified set_node implementation according to v2 PRD.
|
||||
"""
|
||||
node = self.workflow.nodes[node_id]
|
||||
|
||||
logger.debug(
|
||||
f"Executing node: name: {node.name} is_static: {node.is_static} allow_interrupt: {node.allow_interrupt} is_end: {node.is_end}"
|
||||
)
|
||||
|
||||
# Set current node for all nodes (including static ones) so STT mute filter works
|
||||
self._current_node = node
|
||||
|
||||
# Handle start nodes
|
||||
if node.is_start:
|
||||
await self._handle_start_node(node)
|
||||
# Handle end nodes
|
||||
elif node.is_end:
|
||||
await self._handle_end_node(node)
|
||||
# Handle normal agent nodes
|
||||
else:
|
||||
await self._handle_agent_node(node)
|
||||
|
||||
async def _handle_start_node(self, node: Node) -> None:
|
||||
"""Handle start node execution."""
|
||||
# Handle voicemail detection setup (before any returns)
|
||||
if node.detect_voicemail:
|
||||
if not self._audio_buffer:
|
||||
logger.warning(
|
||||
"Voicemail detection enabled but no audio buffer available - skipping detection"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Start node has detect_voicemail enabled - setting up audio-based detector"
|
||||
)
|
||||
self._detect_voicemail = True
|
||||
|
||||
self._voicemail_detector = VoicemailDetector(
|
||||
detection_duration=VOICEMAIL_RECORDING_DURATION,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
)
|
||||
|
||||
# Register audio handler on the audio buffer input processor
|
||||
audio_input = self._audio_buffer.input()
|
||||
|
||||
@audio_input.event_handler("on_input_audio_data")
|
||||
async def handle_voicemail_audio(
|
||||
processor, pcm, sample_rate, num_channels
|
||||
):
|
||||
if (
|
||||
self._voicemail_detector
|
||||
and self._voicemail_detector.is_detecting
|
||||
):
|
||||
await self._voicemail_detector.handle_audio_data(
|
||||
processor, pcm, sample_rate, num_channels
|
||||
)
|
||||
|
||||
# Start detection
|
||||
await self._voicemail_detector.start_detection(self)
|
||||
|
||||
# Check if delayed start is enabled
|
||||
if node.delayed_start:
|
||||
# Use configured duration or default to 3 seconds
|
||||
delay_duration = node.delayed_start_duration or 2.0
|
||||
logger.debug(
|
||||
f"Delayed start enabled - waiting {delay_duration} seconds before speaking"
|
||||
)
|
||||
await asyncio.sleep(delay_duration)
|
||||
|
||||
if node.is_static:
|
||||
# Queue TTS for static start node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
|
||||
# Set up deferred transition for static start nodes
|
||||
await self._setup_static_start_node_transition(node)
|
||||
else:
|
||||
# Start generation for non-static start node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
async def _handle_end_node(self, node: Node) -> None:
|
||||
"""Handle end node execution."""
|
||||
if node.is_static:
|
||||
# Queue TTS for static end node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
else:
|
||||
# Start generation for non-static end node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
# If this end node has extraction enabled, perform extraction immediately
|
||||
if node.extraction_enabled and node.extraction_variables:
|
||||
await self._perform_variable_extraction_if_needed(node)
|
||||
|
||||
# TODO: Extract disposition code from extracted variables
|
||||
# Defer send_end_task_frame using _pending_control_transition_after_context_push
|
||||
|
||||
# Decide the end-task reason dynamically depending on call_disposition.
|
||||
async def _deferred_end_task():
|
||||
# call_disposition is the disposition which is generated from
|
||||
# llm call based on the conversation so far.
|
||||
# TODO: Make this more generic based on configuration or llm prompting
|
||||
disposition = self._gathered_context.get("call_disposition")
|
||||
if disposition == "XFER":
|
||||
reason = EndTaskReason.USER_QUALIFIED.value
|
||||
else:
|
||||
reason = EndTaskReason.USER_DISQUALIFIED.value
|
||||
await self.send_end_task_frame(reason)
|
||||
|
||||
self._pending_control_transition_after_context_push = _deferred_end_task
|
||||
|
||||
async def _handle_agent_node(self, node: Node) -> None:
|
||||
"""Handle agent node execution."""
|
||||
if node.is_static:
|
||||
# Queue TTS for static agent node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
|
||||
# Set up deferred transition for static agent nodes
|
||||
await self._setup_agent_node_transition(node)
|
||||
else:
|
||||
# Set context and functions for non-static agent node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
async def _setup_agent_node_transition(self, node: Node) -> None:
|
||||
"""Set up the deferred transition for static agent nodes."""
|
||||
if not node.out_edges:
|
||||
return
|
||||
|
||||
next_node_id = node.out_edges[0].target
|
||||
|
||||
async def _deferred_static_transition():
|
||||
try:
|
||||
await self.set_node(next_node_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error executing deferred static node transition to {next_node_id}: {exc}"
|
||||
)
|
||||
|
||||
self._pending_control_transition_after_context_push = (
|
||||
_deferred_static_transition
|
||||
)
|
||||
|
||||
async def send_end_task_frame(
|
||||
self,
|
||||
reason: str,
|
||||
additional_metadata: dict = None,
|
||||
abort_immediately: bool = False,
|
||||
):
|
||||
"""
|
||||
Centralized method to send EndTaskFrame with metadata including
|
||||
call_transfer_context and call_context_vars
|
||||
"""
|
||||
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
|
||||
|
||||
# Customer disposition code using their mapping
|
||||
mapped_disposition = ""
|
||||
|
||||
# Apply disposition mapping - first try call_disposition if it is,
|
||||
# extracted from the call conversation then fall back to reason
|
||||
call_disposition = self._gathered_context.get("call_disposition", "")
|
||||
organization_id = await get_organization_id_from_workflow_run(
|
||||
self._workflow_run_id
|
||||
)
|
||||
|
||||
if call_disposition:
|
||||
# If call_disposition exists, map it
|
||||
mapped_disposition = await apply_disposition_mapping(
|
||||
call_disposition, organization_id
|
||||
)
|
||||
# Store the original and mapped values
|
||||
self._gathered_context["extracted_call_disposition"] = call_disposition
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
else:
|
||||
# Otherwise, map the disconnect reason
|
||||
mapped_disposition = await apply_disposition_mapping(
|
||||
reason, organization_id
|
||||
)
|
||||
# Store the mapped disconnect reason
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
|
||||
# TODO: Generalise this, currently tailored to Kapil's use case
|
||||
self._gathered_context["address"] = ", ".join(
|
||||
[
|
||||
self._call_context_vars.get("address1", ""),
|
||||
self._call_context_vars.get("address2", ""),
|
||||
self._call_context_vars.get("address3", ""),
|
||||
self._call_context_vars.get("city", ""),
|
||||
self._call_context_vars.get("state", ""),
|
||||
self._call_context_vars.get("province", ""),
|
||||
self._call_context_vars.get("postal_code", ""),
|
||||
]
|
||||
)
|
||||
self._gathered_context["full_name"] = " ".join(
|
||||
[
|
||||
self._call_context_vars.get("first_name", ""),
|
||||
self._call_context_vars.get("middle_initial", ""),
|
||||
self._call_context_vars.get("last_name", ""),
|
||||
]
|
||||
)
|
||||
self._gathered_context["agent_name"] = "Alex"
|
||||
self._gathered_context["customer_phone_number"] = self._call_context_vars.get(
|
||||
"phone", ""
|
||||
)
|
||||
self._gathered_context["timezone"] = self._call_context_vars.get("province", "")
|
||||
self._gathered_context["vendor_id"] = self._call_context_vars.get(
|
||||
"vendor_lead_code", ""
|
||||
)
|
||||
|
||||
decision_maker = self._gathered_context.get("primary_cardholder", False)
|
||||
employment_status = self._gathered_context.get("employment_status", "N/A")
|
||||
call_transfer_context = {
|
||||
"first_name": self._call_context_vars.get("first_name", ""),
|
||||
"full_name": self._gathered_context.get("full_name", ""),
|
||||
"phone": self._call_context_vars.get("phone", ""),
|
||||
"lead_id": self._call_context_vars.get("lead_id"),
|
||||
"disposition": mapped_disposition,
|
||||
"agent_name": self._gathered_context.get("agent_name", "Alex"),
|
||||
"decision_maker": str(decision_maker),
|
||||
"employment": employment_status.title() if employment_status else "N/A",
|
||||
"debts": self._gathered_context.get("total_debt", "N/A"),
|
||||
"number_of_credit_cards": self._gathered_context.get(
|
||||
"number_of_credit_cards", "N/A"
|
||||
),
|
||||
"time": self._gathered_context.get("time"),
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"gathered_context: {self._gathered_context} call_transfer_context: {call_transfer_context}"
|
||||
)
|
||||
|
||||
# Initiate immediate transfer for Stasis connections when user is qualified
|
||||
if (
|
||||
reason == EndTaskReason.USER_QUALIFIED.value
|
||||
and self._stasis_connection is not None
|
||||
and not abort_immediately
|
||||
):
|
||||
try:
|
||||
logger.info(
|
||||
f"Initiating immediate Stasis transfer for channel {self._stasis_connection.channel_id}"
|
||||
)
|
||||
await self._stasis_connection.transfer(call_transfer_context)
|
||||
logger.info("Immediate transfer initiated successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate immediate transfer: {e}")
|
||||
# Continue with normal flow even if immediate transfer fails
|
||||
|
||||
if reason == EndTaskReason.CALL_DURATION_EXCEEDED.value:
|
||||
await self.task.queue_frame(
|
||||
TTSSpeakFrame(
|
||||
"Sorry! It seems like our time has exceeded. Someone from our team will reach out to you soon. Thank you!"
|
||||
)
|
||||
)
|
||||
|
||||
metadata = {
|
||||
# Keep original reason in metadata, which would be used to decide
|
||||
# whether to disconnect or to transfer the call in the transport
|
||||
"reason": reason,
|
||||
"call_transfer_context": call_transfer_context,
|
||||
}
|
||||
|
||||
# Add any additional metadata
|
||||
if additional_metadata:
|
||||
metadata.update(additional_metadata)
|
||||
|
||||
frame_to_push.metadata = metadata
|
||||
|
||||
# Store the original reason for later retrieval in event handler
|
||||
self._call_disposition = mapped_disposition
|
||||
|
||||
logger.debug(
|
||||
f"Finishing run with reason: {reason}, disposition: {mapped_disposition} queueing frame {frame_to_push}"
|
||||
)
|
||||
await self.task.queue_frame(frame_to_push)
|
||||
|
||||
async def _compose_system_message_functions_for_node(
|
||||
self, node: "Node"
|
||||
) -> tuple[list[dict], list[dict]]:
|
||||
"""Generate the system messages and function schemas for the given node.
|
||||
|
||||
This performs the same formatting logic used when entering a node but
|
||||
does **not** register the functions with the LLM; callers are
|
||||
responsible for that.
|
||||
"""
|
||||
|
||||
global_prompt = ""
|
||||
if self.workflow.global_node_id and node.add_global_prompt:
|
||||
global_node = self.workflow.nodes[self.workflow.global_node_id]
|
||||
global_prompt = self._format_prompt(global_node.prompt)
|
||||
|
||||
functions: list[dict] = []
|
||||
|
||||
# Add built-in function schemas (calculator and timezone tools)
|
||||
functions.extend(self.builtin_function_schemas)
|
||||
|
||||
# Transition functions (schema only; registration handled elsewhere)
|
||||
for outgoing_edge in node.out_edges:
|
||||
function_schema = self._get_function_schema(
|
||||
outgoing_edge.get_function_name(), outgoing_edge.condition
|
||||
)
|
||||
functions.append(function_schema)
|
||||
|
||||
formatted_node_prompt = self._format_prompt(node.prompt)
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": "\n\n".join(
|
||||
p for p in (global_prompt, formatted_node_prompt) if p
|
||||
),
|
||||
}
|
||||
|
||||
return system_message, functions
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pending transition handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def flush_pending_transitions(self, *, source: str = "context_push"):
|
||||
"""Execute and clear any pending transitions.
|
||||
|
||||
Args:
|
||||
source: Indicates the trigger that caused this flush:
|
||||
- "context_push": the assistant context aggregator completed a push.
|
||||
"""
|
||||
|
||||
if source != "context_push":
|
||||
raise ValueError("Invalid flush source – expected 'context_push'")
|
||||
|
||||
len_pending_functions = 0
|
||||
|
||||
if self._pending_generated_transition_after_context_push is not None:
|
||||
len_pending_functions += 1
|
||||
if self._pending_control_transition_after_context_push is not None:
|
||||
len_pending_functions += 1
|
||||
|
||||
# Nothing to do
|
||||
if len_pending_functions == 0:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
|
||||
)
|
||||
|
||||
# Generated transition
|
||||
if self._pending_generated_transition_after_context_push is not None:
|
||||
pending_cb = self._pending_generated_transition_after_context_push
|
||||
self._pending_generated_transition_after_context_push = None
|
||||
try:
|
||||
await pending_cb()
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.error(f"Error executing deferred transition: {exc}")
|
||||
|
||||
# Control transition (context push)
|
||||
if self._pending_control_transition_after_context_push is not None:
|
||||
logger.debug("Executing control transition after context push")
|
||||
static_cb = self._pending_control_transition_after_context_push
|
||||
self._pending_control_transition_after_context_push = None
|
||||
try:
|
||||
await static_cb()
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.error(f"Error executing deferred static node transition: {exc}")
|
||||
|
||||
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
|
||||
"""
|
||||
This callback is called by STTMuteFilter to determine if the STT should be muted.
|
||||
"""
|
||||
return engine_callbacks.create_should_mute_callback(self)
|
||||
|
||||
def create_user_idle_callback(self):
|
||||
"""
|
||||
This callback is called when the user is idle for a certain duration.
|
||||
We use this to either play the static text or end the call
|
||||
"""
|
||||
return engine_callbacks.create_user_idle_callback(self)
|
||||
|
||||
def create_max_duration_callback(self):
|
||||
"""
|
||||
This callback is called when the call duration exceeds the max duration.
|
||||
We use this to send the EndTaskFrame.
|
||||
"""
|
||||
return engine_callbacks.create_max_duration_callback(self)
|
||||
|
||||
def create_llm_generated_text_callback(self):
|
||||
"""
|
||||
This callback is called when some text is generated by the LLM.
|
||||
We use this to defer the result_callback of the node transition functions if
|
||||
there is set_node called along with some text generated. This way, we will
|
||||
have the context sent in the next generation from new node.
|
||||
"""
|
||||
return engine_callbacks.create_llm_generated_text_callback(self)
|
||||
|
||||
def create_generation_started_callback(self):
|
||||
"""
|
||||
This callback is called when a new generation starts.
|
||||
This is used to reset the flags that control the flow of the engine.
|
||||
"""
|
||||
return engine_callbacks.create_generation_started_callback(self)
|
||||
|
||||
def create_user_stopped_speaking_callback(self):
|
||||
"""
|
||||
This callback is called when the user stops speaking.
|
||||
We use this to handle transitions when wait_for_user_response is enabled.
|
||||
"""
|
||||
return engine_callbacks.create_user_stopped_speaking_callback(self)
|
||||
|
||||
def create_user_started_speaking_callback(self):
|
||||
"""
|
||||
This callback is called when the user starts speaking.
|
||||
We use this to handle wait_for_user_greeting functionality.
|
||||
"""
|
||||
return engine_callbacks.create_user_started_speaking_callback(self)
|
||||
|
||||
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
|
||||
"""Create a callback that corrects corrupted aggregation using reference text."""
|
||||
return engine_callbacks.create_aggregation_correction_callback(self)
|
||||
|
||||
def get_call_disposition(self) -> Optional[str]:
|
||||
"""Get the disconnect reason set by the engine."""
|
||||
return self._call_disposition
|
||||
|
||||
def get_gathered_context(self) -> dict:
|
||||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
||||
def set_context(self, context: OpenAILLMContext) -> None:
|
||||
"""Set the OpenAI LLM context.
|
||||
|
||||
This allows setting the context after the engine has been created,
|
||||
which is useful when the context needs to be created after the engine.
|
||||
"""
|
||||
self.context = context
|
||||
|
||||
def set_task(self, task: PipelineTask) -> None:
|
||||
"""Set the pipeline task.
|
||||
|
||||
This allows setting the task after the engine has been created,
|
||||
which is useful when the task needs to be created after the engine.
|
||||
"""
|
||||
self.task = task
|
||||
|
||||
def set_audio_buffer(self, audio_buffer: "AudioBuffer") -> None:
|
||||
"""Set the audio buffer.
|
||||
|
||||
This allows setting the audio buffer after the engine has been created,
|
||||
which is useful when the audio buffer needs to be created after the engine.
|
||||
"""
|
||||
self._audio_buffer = audio_buffer
|
||||
|
||||
def set_stasis_connection(
|
||||
self, connection: Optional["StasisRTPConnection"]
|
||||
) -> None:
|
||||
"""Set the Stasis RTP connection for immediate transfers.
|
||||
|
||||
This allows the engine to initiate transfers immediately when XFER
|
||||
disposition is detected, without waiting for pipeline shutdown.
|
||||
|
||||
Args:
|
||||
connection: The StasisRTPConnection instance, or None for non-Stasis transports
|
||||
"""
|
||||
self._stasis_connection = connection
|
||||
if connection:
|
||||
logger.debug(
|
||||
f"Stasis connection set for immediate transfers: {connection.channel_id}"
|
||||
)
|
||||
|
||||
async def handle_llm_text_frame(self, text: str):
|
||||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_reference_text += text
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up engine resources on disconnect."""
|
||||
# Cancel any pending timeout tasks
|
||||
if (
|
||||
self._user_response_timeout_task
|
||||
and not self._user_response_timeout_task.done()
|
||||
):
|
||||
self._user_response_timeout_task.cancel()
|
||||
|
||||
# Stop voicemail detection if active
|
||||
if self._voicemail_detector and hasattr(
|
||||
self._voicemail_detector, "stop_detection"
|
||||
):
|
||||
await self._voicemail_detector.stop_detection()
|
||||
305
api/services/workflow/pipecat_engine_callbacks.py
Normal file
305
api/services/workflow/pipecat_engine_callbacks.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
from __future__ import annotations
|
||||
|
||||
"""Callback factory helpers for :pyclass:`~api.services.workflow.pipecat_engine.PipecatEngine`.
|
||||
|
||||
Each helper takes a :class:`PipecatEngine` instance and returns an async
|
||||
callback function suitable for passing to the various pipeline processors.
|
||||
Separating these helpers into their own module keeps
|
||||
``pipecat_engine.py`` focused on high-level engine orchestration logic while
|
||||
encapsulating the callback implementations here for easier maintenance and
|
||||
unit-testing.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# STT mute handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_should_mute_callback(
|
||||
engine: "PipecatEngine",
|
||||
) -> Callable[[STTMuteFilter], Awaitable[bool]]:
|
||||
"""Return a callback indicating whether STT should be muted.
|
||||
|
||||
STT is muted when *interruptions are **not*** allowed on the current node.
|
||||
"""
|
||||
|
||||
async def callback(_: STTMuteFilter) -> bool: # noqa: D401
|
||||
if engine._current_node is None:
|
||||
# Default to not muting if we have no active node yet.
|
||||
return False
|
||||
|
||||
logger.debug(
|
||||
f"STT mute callback: allow_interrupt={engine._current_node.allow_interrupt}"
|
||||
)
|
||||
return not engine._current_node.allow_interrupt
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-idle handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_idle_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that handles user-idle timeouts."""
|
||||
|
||||
async def handle_user_idle(
|
||||
user_idle: "UserIdleProcessor", retry_count: int
|
||||
) -> bool:
|
||||
logger.debug(f"Handling user_idle, attempt: {retry_count}")
|
||||
|
||||
# Check if we're on a StartNode - if yes, directly disconnect
|
||||
if engine._current_node and engine._current_node.is_start:
|
||||
logger.debug("User idle on StartNode - disconnecting immediately")
|
||||
await engine.send_end_task_frame(
|
||||
EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value
|
||||
)
|
||||
return False
|
||||
|
||||
if retry_count == 1:
|
||||
# Simulate an LLM generation, so that we can have the LLM context
|
||||
# updated with the new message
|
||||
await engine.task.queue_frames(
|
||||
[
|
||||
LLMFullResponseStartFrame(),
|
||||
TTSSpeakFrame("Just checking in to see if you're still there."),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
)
|
||||
return True
|
||||
|
||||
# Second attempt: terminate the call due to inactivity.
|
||||
await user_idle.push_frame(
|
||||
TTSSpeakFrame("It seems like you're busy right now. Have a nice day!")
|
||||
)
|
||||
await engine.send_end_task_frame(
|
||||
EndTaskReason.USER_IDLE_MAX_DURATION_EXCEEDED.value
|
||||
)
|
||||
return False
|
||||
|
||||
return handle_user_idle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Max-duration handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_max_duration_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that ends the task when the max call duration is exceeded."""
|
||||
|
||||
async def handle_max_duration():
|
||||
logger.debug("Max call duration exceeded. Terminating call")
|
||||
await engine.send_end_task_frame(EndTaskReason.CALL_DURATION_EXCEEDED.value)
|
||||
|
||||
return handle_max_duration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-generated-text handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_llm_generated_text_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback invoked when the LLM emits text (not only tool calls)."""
|
||||
|
||||
async def handle_llm_generated_text(): # noqa: D401
|
||||
logger.debug(
|
||||
"Generation has text content in current response - deferring context push from set_node"
|
||||
)
|
||||
engine._defer_context_push = True
|
||||
|
||||
return handle_llm_generated_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generation-started handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_generation_started_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that resets flags at the start of each LLM generation."""
|
||||
|
||||
async def handle_generation_started(): # noqa: D401
|
||||
logger.debug("LLM generation started - resetting defer flags and tool counters")
|
||||
engine._defer_context_push = False
|
||||
engine._pending_function_calls = 0
|
||||
engine._pending_generated_transition_after_context_push = None
|
||||
# Clear reference text from previous generation
|
||||
engine._current_llm_reference_text = ""
|
||||
|
||||
return handle_generation_started
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-stopped-speaking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_stopped_speaking_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that handles when the user stops speaking.
|
||||
|
||||
According to simplified flow:
|
||||
- For start nodes with wait_for_user_response=True:
|
||||
- Cancel timeout task if still active
|
||||
- Transition to next node with _queue_context_frame=False
|
||||
"""
|
||||
|
||||
async def handle_user_stopped_speaking():
|
||||
# Only handle if current node is a start node with wait_for_user_response
|
||||
if (
|
||||
engine._current_node
|
||||
and engine._current_node.is_start
|
||||
and engine._current_node.wait_for_user_response
|
||||
and engine._current_node.out_edges
|
||||
):
|
||||
# Cancel timeout task if it's still active
|
||||
if (
|
||||
engine._user_response_timeout_task
|
||||
and not engine._user_response_timeout_task.done()
|
||||
):
|
||||
logger.debug("Cancelling user response timeout - user responded")
|
||||
engine._user_response_timeout_task.cancel()
|
||||
engine._user_response_timeout_task = None
|
||||
|
||||
# Transition to next node
|
||||
next_node_id = engine._current_node.out_edges[0].target
|
||||
logger.debug(
|
||||
f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}"
|
||||
)
|
||||
|
||||
# Set flag to not queue context frame since
|
||||
# it will be pushed by user context aggregator
|
||||
# we are just setting the context with next node's
|
||||
# functions and prompts
|
||||
engine._queue_context_frame = False
|
||||
|
||||
# Transition to next node
|
||||
await engine.set_node(next_node_id)
|
||||
|
||||
return handle_user_stopped_speaking
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-started-speaking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_started_speaking_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that handles when the user starts speaking.
|
||||
|
||||
According to simplified flow:
|
||||
- For start nodes with wait_for_user_response=True:
|
||||
- Cancel the timeout timer if it exists (but don't set to None)
|
||||
"""
|
||||
|
||||
async def handle_user_started_speaking():
|
||||
# Only handle if current node is a start node with wait_for_user_response
|
||||
if (
|
||||
engine._current_node
|
||||
and engine._current_node.is_start
|
||||
and engine._current_node.wait_for_user_response
|
||||
and engine._user_response_timeout_task
|
||||
and not engine._user_response_timeout_task.done()
|
||||
):
|
||||
logger.debug(
|
||||
"User started speaking during wait_for_user_response - cancelling timeout timer"
|
||||
)
|
||||
engine._user_response_timeout_task.cancel()
|
||||
# Don't set to None here - let user_stopped_speaking handle the transition
|
||||
|
||||
return handle_user_started_speaking
|
||||
|
||||
|
||||
def create_aggregation_correction_callback(engine: "PipecatEngine"):
|
||||
"""Create a callback that uses engine's reference text to correct corrupted aggregation."""
|
||||
|
||||
def correct_corrupted_aggregation(ref: str, corrupted: str) -> str:
|
||||
"""Correct corrupted text by aligning it with reference text.
|
||||
|
||||
This is a pure function that doesn't depend on engine instance.
|
||||
"""
|
||||
# 1) Safety check: if ref (minus spaces) is shorter than corrupted, bail out
|
||||
# also if corrupted is less than 10 characters, lets also return that since most likely
|
||||
# Elevenlabs returned the right alignment
|
||||
alnum_corr = "".join(ch for ch in corrupted if ch.isalnum())
|
||||
alnum_ref = "".join(ch for ch in ref if ch.isalnum())
|
||||
|
||||
if corrupted in ref or len(alnum_ref) < len(alnum_corr) or len(alnum_corr) < 10:
|
||||
return corrupted
|
||||
|
||||
# 2) Find where in `ref` we should start aligning.
|
||||
# We take the first N (N=10) characters of `corrupted`
|
||||
# and look for all their occurrences in `ref`.
|
||||
# We pick the *last* one
|
||||
prefix = corrupted[:10]
|
||||
|
||||
# find all start‐indices of that prefix in ref
|
||||
starts = [m.start() for m in re.finditer(re.escape(prefix), ref)]
|
||||
start_idx = starts[-1] if starts else 0
|
||||
|
||||
# 3) Now run the same two‑pointer scan from start_idx
|
||||
i, j = start_idx, 0
|
||||
out_chars = []
|
||||
while i < len(ref) and j < len(corrupted):
|
||||
r_ch, c_ch = ref[i], corrupted[j]
|
||||
if r_ch == c_ch:
|
||||
out_chars.append(r_ch)
|
||||
i += 1
|
||||
j += 1
|
||||
|
||||
elif c_ch == " ":
|
||||
# extra space in corrupted → skip it
|
||||
j += 1
|
||||
|
||||
elif r_ch == " " or r_ch in ".,;:!?":
|
||||
# missing structural char in corrupted → emit from ref
|
||||
out_chars.append(r_ch)
|
||||
i += 1
|
||||
|
||||
else:
|
||||
# letter mismatch → best‑effort copy from ref
|
||||
out_chars.append(r_ch)
|
||||
i += 1
|
||||
j += 1
|
||||
|
||||
# 4) A final check - the final created output should be exactly
|
||||
# as corrupted sentence sans whitespace.
|
||||
alnum_out = "".join([ch for ch in out_chars if ch.isalnum()])
|
||||
if alnum_out != alnum_corr:
|
||||
return corrupted
|
||||
|
||||
# 5) Join and return exactly what we built
|
||||
return "".join(out_chars)
|
||||
|
||||
def correct_aggregation(corrupted: str) -> str:
|
||||
reference = engine._current_llm_reference_text
|
||||
|
||||
if not reference:
|
||||
logger.warning("No reference text available for aggregation correction")
|
||||
return corrupted
|
||||
|
||||
# Apply the correction algorithm
|
||||
corrected = correct_corrupted_aggregation(reference, corrupted)
|
||||
return corrected
|
||||
|
||||
return correct_aggregation
|
||||
90
api/services/workflow/pipecat_engine_utils.py
Normal file
90
api/services/workflow/pipecat_engine_utils.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from google.genai.types import (
|
||||
Content,
|
||||
Part,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.services.google.llm import GoogleLLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.utils.template_renderer import render_template
|
||||
|
||||
__all__ = [
|
||||
"get_function_schema",
|
||||
"update_llm_context",
|
||||
"render_template",
|
||||
]
|
||||
|
||||
|
||||
def get_function_schema(
|
||||
function_name: str,
|
||||
description: str,
|
||||
*,
|
||||
properties: Dict[str, Any] | None = None,
|
||||
required: List[str] | None = None,
|
||||
) -> FunctionSchema:
|
||||
"""Create a FunctionSchema definition that can later be transformed into
|
||||
the provider-specific format (OpenAI, Gemini, etc.).
|
||||
|
||||
The helper keeps the public signature backward-compatible – callers that
|
||||
only pass ``function_name`` and ``description`` continue to work and will
|
||||
define a parameter-less function.
|
||||
"""
|
||||
|
||||
return FunctionSchema(
|
||||
name=function_name,
|
||||
description=description,
|
||||
properties=properties or {},
|
||||
required=required or [],
|
||||
)
|
||||
|
||||
|
||||
def update_llm_context(
|
||||
context: OpenAILLMContext,
|
||||
system_message: Dict[str, Any],
|
||||
functions: List[FunctionSchema],
|
||||
) -> None:
|
||||
"""Update *context* with an up-to-date system message and tool list.
|
||||
|
||||
This helper removes any previous system messages before inserting the new
|
||||
*system_message* at the top of the conversation history and then instructs
|
||||
the LLM which *functions* (a.k.a. tools) are currently available.
|
||||
"""
|
||||
|
||||
# Wrap the provided function schemas in a ToolsSchema so that the adapter
|
||||
# associated with the current LLM service can convert them to the correct
|
||||
# provider-specific representation when required.
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
|
||||
if isinstance(context, GoogleLLMContext):
|
||||
context.system_message = system_message["content"]
|
||||
|
||||
if functions:
|
||||
# Lets only call set_tools if we have functions, else Gemini will
|
||||
# throw an exception
|
||||
context.set_tools(tools_schema)
|
||||
|
||||
if context.messages[-1].role != "user":
|
||||
# Google expects the last message should end with user message
|
||||
context.add_message(Content(role="user", parts=[Part(text="...")]))
|
||||
return
|
||||
|
||||
# In case of OpenAILLMContext, replace the system message with incoming system message
|
||||
previous_interactions = context.messages
|
||||
|
||||
# Filter out old system messages but keep user/assistant/function content.
|
||||
messages: List[Dict[str, Any]] = [system_message]
|
||||
messages.extend(
|
||||
interaction
|
||||
for interaction in previous_interactions
|
||||
if interaction["role"] != "system"
|
||||
)
|
||||
|
||||
context.set_messages(messages)
|
||||
|
||||
if functions:
|
||||
context.set_tools(tools_schema)
|
||||
192
api/services/workflow/pipecat_engine_variable_extractor.py
Normal file
192
api/services/workflow/pipecat_engine_variable_extractor.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from opentelemetry import trace
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
|
||||
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from api.services.workflow.dto import ExtractionVariableDTO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
class VariableExtractionManager:
|
||||
"""Helper that registers and executes the \"extract_variables\" tool.
|
||||
|
||||
The manager is responsible for two things:
|
||||
1. Registering a callable with the LLM service so that the tool can be
|
||||
invoked from within the model.
|
||||
2. Executing the extraction in a background task while maintaining
|
||||
correct bookkeeping and optional OpenTelemetry tracing.
|
||||
"""
|
||||
|
||||
def __init__(self, engine: "PipecatEngine") -> None: # noqa: F821
|
||||
# We keep a reference to the engine so we can reuse its context
|
||||
# and update internal counters / extracted variable state.
|
||||
self._engine = engine
|
||||
self._context = engine.context
|
||||
self._model = "gpt-4o"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
async def _perform_extraction(
|
||||
self,
|
||||
extraction_variables: List[ExtractionVariableDTO],
|
||||
parent_ctx: Any,
|
||||
extraction_prompt: str = "",
|
||||
) -> dict:
|
||||
"""Run the actual extraction chat completion and post-process the result."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build the prompt that instructs the model to extract the variables.
|
||||
# ------------------------------------------------------------------
|
||||
vars_description = "\n".join(
|
||||
f"- {v.name} ({v.type}): {v.prompt}" for v in extraction_variables
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build a normalized representation of the existing conversation so the
|
||||
# extractor works with both OpenAI-style (dict) messages and Google
|
||||
# Gemini `Content` objects.
|
||||
# ------------------------------------------------------------------
|
||||
def _get_role_and_content(msg: Any) -> tuple[str | None, str | None]:
|
||||
"""Return a pair of (role, content) for the given message.
|
||||
|
||||
The logic supports both OpenAI-style dict messages and Google
|
||||
`Content` objects that expose ``role`` and ``parts`` attributes.
|
||||
Only plain textual content is extracted – image parts, tool call
|
||||
placeholders, etc. are ignored for the purpose of variable
|
||||
extraction.
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# OpenAI format → simple dict with ``role`` and ``content`` keys
|
||||
# --------------------------------------------------------------
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role")
|
||||
content_field = msg.get("content")
|
||||
|
||||
# Content can be a str, list of segments, or None.
|
||||
if isinstance(content_field, str):
|
||||
content = content_field
|
||||
elif isinstance(content_field, list):
|
||||
# Collapse all text parts into a single string.
|
||||
texts = [
|
||||
segment.get("text", "")
|
||||
for segment in content_field
|
||||
if isinstance(segment, dict) and segment.get("type") == "text"
|
||||
]
|
||||
content = " ".join(texts) if texts else None
|
||||
else:
|
||||
content = None
|
||||
|
||||
return role, content
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# Google Gemini format → ``Content`` object with ``parts`` list
|
||||
# --------------------------------------------------------------
|
||||
role_attr = getattr(msg, "role", None)
|
||||
parts_attr = getattr(msg, "parts", None)
|
||||
|
||||
if role_attr is None or parts_attr is None:
|
||||
return None, None # Unrecognised message format
|
||||
|
||||
role = (
|
||||
"assistant" if role_attr == "model" else role_attr
|
||||
) # Normalise role name
|
||||
|
||||
# Collect textual parts only (ignore images, function calls, etc.)
|
||||
texts: list[str] = []
|
||||
for part in parts_attr:
|
||||
text_val = getattr(part, "text", None)
|
||||
if text_val:
|
||||
texts.append(text_val)
|
||||
|
||||
content = " ".join(texts) if texts else None
|
||||
return role, content
|
||||
|
||||
conversation_lines: list[str] = []
|
||||
for msg in self._context.messages:
|
||||
role, content = _get_role_and_content(msg)
|
||||
if role in ("assistant", "user") and content:
|
||||
conversation_lines.append(f"{role}: {content}")
|
||||
|
||||
conversation_history = "\n".join(conversation_lines)
|
||||
|
||||
system_prompt = (
|
||||
"You are an assistant tasked with extracting structured data from the conversation. "
|
||||
"Return ONLY a valid JSON object with the requested variables as top-level keys. Do not wrap the JSON in markdown." # noqa: E501
|
||||
)
|
||||
# Use provided extraction_prompt as system prompt, or default
|
||||
system_prompt = (
|
||||
system_prompt + "\n\n" + extraction_prompt
|
||||
if extraction_prompt
|
||||
else system_prompt
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
"\n\nVariables to extract:\n"
|
||||
f"{vars_description}"
|
||||
"\n\nConversation history:\n"
|
||||
f"{conversation_history}"
|
||||
)
|
||||
|
||||
extraction_context = OpenAILLMContext()
|
||||
extraction_messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
extraction_context.set_messages(extraction_messages)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Use independent OpenAI client for LLM call
|
||||
# ------------------------------------------------------------------
|
||||
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
# Direct API call - no pipeline involvement
|
||||
response = await client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=extraction_messages,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
llm_response = response.choices[0].message.content
|
||||
|
||||
if is_tracing_enabled():
|
||||
tracer = trace.get_tracer("pipecat")
|
||||
with tracer.start_as_current_span(
|
||||
"variable_extraction", context=parent_ctx
|
||||
) as span:
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=self._model,
|
||||
operation_name="variable_extraction",
|
||||
messages=json.dumps(extraction_messages),
|
||||
output=llm_response,
|
||||
stream=False,
|
||||
parameters={"temperature": 0.0, "response_format": "json_object"},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Parse the assistant output – fall back to raw text if it is not valid JSON.
|
||||
# ------------------------------------------------------------------
|
||||
try:
|
||||
extracted = json.loads(llm_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Extractor returned invalid JSON; storing raw content instead."
|
||||
)
|
||||
extracted = {"raw": llm_response}
|
||||
|
||||
logger.debug(f"Extracted variables: {extracted}")
|
||||
return extracted
|
||||
448
api/services/workflow/pipecat_engine_voicemail_detector.py
Normal file
448
api/services/workflow/pipecat_engine_voicemail_detector.py
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import wave
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langfuse import get_client
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from opentelemetry import context as otel_context
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pipecat.utils.tracing.context_registry import get_current_turn_context
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
DEFAULT_VOICEMAIL_PROMPT = """
|
||||
You are analyzing the beginning of a phone call to determine if it's a voicemail greeting.
|
||||
|
||||
Common voicemail indicators:
|
||||
- "You've reached the voicemail of..."
|
||||
- "Please leave a message after the beep"
|
||||
- "I'm not available right now"
|
||||
- "Press 1 to leave a message"
|
||||
- Robotic or pre-recorded voice quality mentioned
|
||||
- Background music or hold music references
|
||||
|
||||
Transcript: {transcript}
|
||||
|
||||
Respond with a JSON object:
|
||||
{
|
||||
"is_voicemail": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"reasoning": "Brief explanation"
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class VoicemailDetector:
|
||||
"""
|
||||
Autonomous voicemail detection system that operates independently of the main pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, detection_duration: float = 15.0, workflow_run_id: int = None):
|
||||
self.detection_duration = detection_duration
|
||||
self.audio_buffer = bytearray()
|
||||
self.is_detecting = False
|
||||
self.workflow_run_id = workflow_run_id
|
||||
self._langfuse_client = get_client()
|
||||
|
||||
# We will set the sample rate when we receive the audio packet
|
||||
self._sample_rate = None
|
||||
|
||||
# Task management
|
||||
self._detection_task: Optional[asyncio.Task] = None
|
||||
self._is_cancelled = False
|
||||
self._engine: Optional[PipecatEngine] = None
|
||||
|
||||
# Event for audio collection completion
|
||||
self._audio_collected_event = asyncio.Event()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Utility helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _current_duration_seconds(self) -> float:
|
||||
"""Return the duration (in seconds) of the audio currently in the buffer."""
|
||||
if self._sample_rate:
|
||||
return len(self.audio_buffer) / (self._sample_rate * 2)
|
||||
return 0.0
|
||||
|
||||
async def handle_audio_data(
|
||||
self, processor, pcm: bytes, sample_rate: int, num_channels: int
|
||||
):
|
||||
"""Handle incoming audio data without affecting pipeline."""
|
||||
if not self.is_detecting or self._is_cancelled:
|
||||
return
|
||||
|
||||
# Store the actual sample rate from the first audio packet
|
||||
if self._sample_rate is None:
|
||||
self._sample_rate = sample_rate
|
||||
logger.debug(f"Voicemail detector using sample rate: {sample_rate}")
|
||||
|
||||
# Add to buffer without resampling
|
||||
self.audio_buffer.extend(pcm)
|
||||
|
||||
# Check if we've collected enough audio
|
||||
current_duration = self._current_duration_seconds()
|
||||
if current_duration >= self.detection_duration:
|
||||
self._audio_collected_event.set()
|
||||
|
||||
async def start_detection(self, engine: PipecatEngine):
|
||||
"""Start voicemail detection process."""
|
||||
logger.info("Starting voicemail detection")
|
||||
self.is_detecting = True
|
||||
self._is_cancelled = False
|
||||
self._engine = engine
|
||||
self._audio_collected_event.clear()
|
||||
|
||||
# Start detection in background
|
||||
self._detection_task = asyncio.create_task(self._run_detection_with_timeout())
|
||||
|
||||
async def stop_detection(self):
|
||||
"""Stop detection immediately (called on disconnect)."""
|
||||
logger.info("Stopping voicemail detection due to disconnect")
|
||||
self._is_cancelled = True
|
||||
self.is_detecting = False
|
||||
|
||||
# Set the event to unblock any waiting tasks
|
||||
self._audio_collected_event.set()
|
||||
|
||||
# Cancel ongoing detection task
|
||||
if self._detection_task and not self._detection_task.done():
|
||||
self._detection_task.cancel()
|
||||
|
||||
# Clear audio buffer
|
||||
self.audio_buffer.clear()
|
||||
|
||||
# Wait for tasks to complete cancellation
|
||||
if self._detection_task:
|
||||
try:
|
||||
await self._detection_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _run_detection_with_timeout(self):
|
||||
"""Run detection with proper timeout and cancellation handling."""
|
||||
try:
|
||||
# Wait for audio collection or cancellation directly
|
||||
await self._wait_for_audio_collection()
|
||||
|
||||
# Check if cancelled during collection
|
||||
if self._is_cancelled:
|
||||
logger.info("Detection cancelled during audio collection")
|
||||
return
|
||||
|
||||
# Process detection
|
||||
await self._process_detection()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Voicemail detection task cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in voicemail detection: {e}")
|
||||
finally:
|
||||
self.is_detecting = False
|
||||
|
||||
async def _wait_for_audio_collection(self):
|
||||
"""Wait for audio buffer to fill or timeout."""
|
||||
try:
|
||||
# Wait for either audio collection completion or timeout
|
||||
await asyncio.wait_for(
|
||||
self._audio_collected_event.wait(),
|
||||
timeout=self.detection_duration + 2.0,
|
||||
)
|
||||
|
||||
if not self._is_cancelled:
|
||||
current_duration = self._current_duration_seconds()
|
||||
logger.info(
|
||||
f"Collected {current_duration:.1f}s of audio for voicemail detection (sample rate: {self._sample_rate}Hz)"
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if not self._is_cancelled:
|
||||
current_duration = self._current_duration_seconds()
|
||||
logger.warning("Audio collection timeout exceeded")
|
||||
logger.info(
|
||||
f"Proceeding with {current_duration:.1f}s of audio (sample rate: {self._sample_rate}Hz)"
|
||||
)
|
||||
|
||||
async def _process_detection(self):
|
||||
"""Process the collected audio to detect voicemail."""
|
||||
if not self.audio_buffer or not self._engine:
|
||||
logger.warning("No audio buffer or engine available for detection")
|
||||
return
|
||||
|
||||
try:
|
||||
# Convert PCM to WAV once for both transcription and storage
|
||||
wav_data = self._create_wav_from_pcm(bytes(self.audio_buffer))
|
||||
|
||||
# Transcribe audio
|
||||
logger.info("Transcribing audio for voicemail detection")
|
||||
transcript = await self._transcribe_audio(wav_data)
|
||||
|
||||
if not transcript:
|
||||
logger.warning("No transcript obtained from audio")
|
||||
|
||||
# Still upload the raw recording so data pipeline has it
|
||||
if self.workflow_run_id:
|
||||
await self._save_voicemail_audio(wav_data, 0.0, False)
|
||||
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Voicemail detection transcript obtained: {transcript[:100]}..."
|
||||
)
|
||||
|
||||
# Analyze transcript
|
||||
result = await self._analyze_transcript(transcript)
|
||||
|
||||
# Extract common fields
|
||||
confidence = result.get("confidence", 0.0)
|
||||
reasoning = result.get("reasoning", "No reasoning provided")
|
||||
|
||||
# Save voicemail audio to S3 once for data pipeline (include duration in filename)
|
||||
s3_path = None
|
||||
if self.workflow_run_id:
|
||||
s3_path = await self._save_voicemail_audio(
|
||||
wav_data, confidence, result.get("is_voicemail")
|
||||
)
|
||||
|
||||
# Take action based on result
|
||||
if result.get("is_voicemail", False):
|
||||
logger.info(
|
||||
f"Voicemail detected with confidence {confidence}: {reasoning}"
|
||||
)
|
||||
|
||||
# Update workflow run with voicemail tags
|
||||
if self.workflow_run_id:
|
||||
# Fetch the workflow run from database
|
||||
workflow_run = await db_client.get_workflow_run_by_id(
|
||||
self.workflow_run_id
|
||||
)
|
||||
if workflow_run:
|
||||
call_tags = workflow_run.gathered_context.get("call_tags", [])
|
||||
call_tags.extend(["voicemail_detected", "not_connected"])
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id,
|
||||
gathered_context={
|
||||
"call_tags": call_tags,
|
||||
"voicemail_transcript": transcript,
|
||||
"voicemail_confidence": confidence,
|
||||
},
|
||||
)
|
||||
|
||||
# Send end task frame with metadata (including optional S3 path)
|
||||
await self._engine.send_end_task_frame(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
additional_metadata={
|
||||
"voicemail_transcript": transcript,
|
||||
"voicemail_confidence": confidence,
|
||||
"voicemail_reasoning": reasoning,
|
||||
"voicemail_detection_duration": self.detection_duration,
|
||||
"voicemail_audio_s3_path": s3_path,
|
||||
},
|
||||
abort_immediately=True,
|
||||
)
|
||||
else:
|
||||
logger.info("No voicemail detected, continuing normal conversation")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing voicemail detection: {e}")
|
||||
|
||||
async def _transcribe_audio(self, wav_data: bytes) -> str:
|
||||
"""Transcribe audio using OpenAI API directly.
|
||||
|
||||
Args:
|
||||
wav_data: WAV formatted audio data
|
||||
"""
|
||||
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
# Direct API call - no pipeline involvement
|
||||
response = await client.audio.transcriptions.create(
|
||||
file=("audio.wav", wav_data, "audio/wav"),
|
||||
model="whisper-1", # Using whisper-1 as it's more stable for transcription
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
return response.text.strip()
|
||||
|
||||
def _create_wav_from_pcm(self, pcm_data: bytes) -> bytes:
|
||||
"""Convert raw PCM data to WAV format."""
|
||||
wav_buffer = io.BytesIO()
|
||||
with wave.open(wav_buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # Mono
|
||||
wav_file.setsampwidth(2) # 16-bit
|
||||
wav_file.setframerate(self._sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
|
||||
wav_buffer.seek(0)
|
||||
return wav_buffer.read()
|
||||
|
||||
async def _analyze_transcript(self, transcript: str) -> dict:
|
||||
"""Analyze transcript using independent OpenAI client."""
|
||||
# Capture the current turn context for proper span nesting
|
||||
parent_context = get_current_turn_context()
|
||||
|
||||
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
langfuse_prompt = None
|
||||
try:
|
||||
langfuse_prompt = self._langfuse_client.get_prompt(
|
||||
"production/voicemail_detection"
|
||||
)
|
||||
prompt = langfuse_prompt.compile(transcript=transcript)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting prompt from Langfuse: {e}")
|
||||
prompt = DEFAULT_VOICEMAIL_PROMPT.replace("{transcript}", transcript)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": prompt,
|
||||
}
|
||||
]
|
||||
|
||||
# When we have a parent OpenTelemetry context, we need to activate it
|
||||
# so that Langfuse's OTEL tracer will automatically pick it up
|
||||
if parent_context and is_tracing_enabled():
|
||||
# Activate the parent context for this scope
|
||||
token = otel_context.attach(parent_context)
|
||||
try:
|
||||
# Start Langfuse generation - it will automatically use the active OTEL context
|
||||
langfuse_generation = None
|
||||
try:
|
||||
langfuse_generation = self._langfuse_client.start_generation(
|
||||
name="voicemail_detection",
|
||||
model="gpt-4o",
|
||||
input=messages,
|
||||
metadata={
|
||||
"temperature": 0.0,
|
||||
"detection_duration": self.detection_duration,
|
||||
"transcript_length": len(transcript),
|
||||
},
|
||||
prompt=langfuse_prompt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error starting Langfuse generation: {e}")
|
||||
|
||||
# Direct API call
|
||||
response = await client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
llm_response = response.choices[0].message.content
|
||||
|
||||
# Update and end Langfuse generation
|
||||
if langfuse_generation:
|
||||
try:
|
||||
langfuse_generation.update(
|
||||
output=llm_response,
|
||||
usage_details={
|
||||
"prompt_tokens": response.usage.prompt_tokens
|
||||
if response.usage
|
||||
else 0,
|
||||
"completion_tokens": response.usage.completion_tokens
|
||||
if response.usage
|
||||
else 0,
|
||||
"total_tokens": response.usage.total_tokens
|
||||
if response.usage
|
||||
else 0,
|
||||
},
|
||||
)
|
||||
langfuse_generation.end()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error updating Langfuse generation: {e}")
|
||||
finally:
|
||||
# Detach the context
|
||||
otel_context.detach(token)
|
||||
else:
|
||||
# No parent context or tracing disabled - just make the API call
|
||||
response = await client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
llm_response = response.choices[0].message.content
|
||||
|
||||
# Parse response
|
||||
try:
|
||||
return json.loads(llm_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON response from voicemail detection")
|
||||
return {
|
||||
"is_voicemail": False,
|
||||
"confidence": 0.0,
|
||||
"reasoning": "Invalid response",
|
||||
}
|
||||
|
||||
async def _save_voicemail_audio(
|
||||
self, wav_data: bytes, confidence: float, is_voicemail: bool
|
||||
) -> Optional[str]:
|
||||
"""Save voicemail audio to temp file and enqueue task to upload to S3.
|
||||
|
||||
Args:
|
||||
wav_data: WAV formatted audio data
|
||||
confidence: Detection confidence score
|
||||
is_voicemail: Whether it was detected as voicemail
|
||||
|
||||
Returns:
|
||||
The expected S3 object key (bucket path). The actual upload happens asynchronously.
|
||||
"""
|
||||
try:
|
||||
# Create filename with prediction, confidence and duration
|
||||
duration_seconds = self._current_duration_seconds()
|
||||
prediction = "voicemail" if is_voicemail else "not_voicemail"
|
||||
confidence_int = int(confidence * 100)
|
||||
duration_int = int(duration_seconds)
|
||||
s3_key = f"voicemail_detections/{self.workflow_run_id}_{prediction}_{confidence_int}_{duration_int}.wav"
|
||||
|
||||
# Write WAV data to temp file - DO NOT delete it here, the async task will handle cleanup
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".wav",
|
||||
delete=False, # Important: don't delete immediately
|
||||
prefix=f"voicemail_{self.workflow_run_id}_",
|
||||
) as tmp_file:
|
||||
tmp_file.write(wav_data)
|
||||
tmp_file.flush()
|
||||
temp_file_path = tmp_file.name
|
||||
|
||||
logger.info(f"Saved voicemail audio to temp file: {temp_file_path}")
|
||||
|
||||
# Enqueue async task to upload to S3
|
||||
await enqueue_job(
|
||||
FunctionNames.UPLOAD_VOICEMAIL_AUDIO_TO_S3,
|
||||
self.workflow_run_id,
|
||||
temp_file_path,
|
||||
s3_key,
|
||||
)
|
||||
|
||||
logger.info(f"Enqueued voicemail audio upload task for: {s3_key}")
|
||||
return s3_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save voicemail audio: {e}")
|
||||
# Clean up temp file if task enqueue failed
|
||||
if "temp_file_path" in locals() and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(
|
||||
f"Failed to cleanup temp file after error: {cleanup_error}"
|
||||
)
|
||||
return None
|
||||
0
api/services/workflow/test/__init__.py
Normal file
0
api/services/workflow/test/__init__.py
Normal file
164
api/services/workflow/test/definitions/rf-1.json
Normal file
164
api/services/workflow/test/definitions/rf-1.json
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "915",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 633,
|
||||
"y": 324
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "7598",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 460.1247806640531,
|
||||
"y": 610.3714977079578
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6919",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 914.666735413607,
|
||||
"y": 642.9800281289787
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6581",
|
||||
"type": "startCall",
|
||||
"position": {
|
||||
"x": 648,
|
||||
"y": 35
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": true,
|
||||
"name": "Start Call",
|
||||
"is_start": true
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "1802",
|
||||
"type": "endCall",
|
||||
"position": {
|
||||
"x": 666.7733431033548,
|
||||
"y": 987.4345801025363
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Thank you for calling Dograh. Have a great day!",
|
||||
"is_static": true,
|
||||
"name": "End Call"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "7598",
|
||||
"id": "xy-edge__915-7598",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "The customer wants to talk to a customer service agent",
|
||||
"label": "customer service agent"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "6919",
|
||||
"id": "xy-edge__915-6919",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "customer wants to talk to a sales representative",
|
||||
"label": "sales representative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6581",
|
||||
"target": "915",
|
||||
"id": "xy-edge__6581-915",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "7598",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__7598-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6919",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
}
|
||||
],
|
||||
"viewport": {
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"zoom": 1
|
||||
}
|
||||
}
|
||||
192
api/services/workflow/test/test_aggregation_fix.py
Normal file
192
api/services/workflow/test/test_aggregation_fix.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_aggregation_correction_callback,
|
||||
)
|
||||
|
||||
|
||||
def test_aggregation_fixer():
|
||||
"""Validate the aggregation correction algorithm using a helper that
|
||||
creates a fresh callback for every (reference, corrupted) pair.
|
||||
|
||||
The production callback now needs a PipecatEngine instance with the
|
||||
`_current_llm_reference_text` set. For test-friendliness we mock a bare
|
||||
object providing just that attribute for each assertion so the original
|
||||
two-argument test cases remain unchanged.
|
||||
"""
|
||||
|
||||
def fixer(reference: str, corrupted: str) -> str: # noqa: D401
|
||||
mock_engine = Mock()
|
||||
mock_engine._current_llm_reference_text = reference
|
||||
return create_aggregation_correction_callback(mock_engine)(corrupted)
|
||||
|
||||
##### Trailing extra Chars #####
|
||||
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
|
||||
), "leading_whole_sentence"
|
||||
|
||||
# Whole sentences
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
|
||||
), "whole_sentences"
|
||||
|
||||
# With a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services.",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services."
|
||||
), "period_end"
|
||||
|
||||
# without a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
|
||||
), "without_period_end"
|
||||
|
||||
# Extra space in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services ",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
|
||||
), "extra_space"
|
||||
|
||||
# Multiple spaces in corruption
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces ",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
|
||||
), "multiple_space"
|
||||
|
||||
# Multiple spaces in corruption ending in a whitespace
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces. ",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. "
|
||||
), "multiple_space_end_ws"
|
||||
|
||||
##### Leading extra Chars #####
|
||||
|
||||
# Whole sentences
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
|
||||
), "leading_whole_sentence"
|
||||
|
||||
# With a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services.",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services."
|
||||
), "leading_period_end"
|
||||
|
||||
# without a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services"
|
||||
), "leading_without_period_end"
|
||||
|
||||
# Extra space in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services ",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services"
|
||||
), "leading_extra_space"
|
||||
|
||||
# Multiple spaces in corruption
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces ",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services"
|
||||
), "leading_multiple_space"
|
||||
|
||||
# Multiple spaces in corruption ending in a whitespace
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces. ",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services. "
|
||||
), "leading_multiple_space_end_ws"
|
||||
|
||||
# Whitespace
|
||||
assert fixer("", "") == ""
|
||||
|
||||
# Missing reference
|
||||
assert (
|
||||
fixer("", "My name is Alex and I am calling you from Cons umer Servi ces.")
|
||||
== "My name is Alex and I am calling you from Cons umer Servi ces."
|
||||
), "missing_reference"
|
||||
|
||||
# Smaller reference
|
||||
assert (
|
||||
fixer(
|
||||
"My name is Alex",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces.",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Cons umer Servi ces."
|
||||
), "smaller_reference"
|
||||
|
||||
# Unrelated reference
|
||||
assert (
|
||||
fixer(
|
||||
"Hello Hello",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces.",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Cons umer Servi ces."
|
||||
), "unrelated_reference"
|
||||
|
||||
|
||||
def test_create_aggregation_correction_callback():
|
||||
"""Test the new aggregation correction callback creator."""
|
||||
# Mock engine with reference text
|
||||
mock_engine = Mock()
|
||||
mock_engine._current_llm_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
|
||||
|
||||
# Create callback
|
||||
callback = create_aggregation_correction_callback(mock_engine)
|
||||
|
||||
# Test correction
|
||||
corrected = callback(
|
||||
"Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."
|
||||
)
|
||||
assert (
|
||||
corrected
|
||||
== "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
|
||||
)
|
||||
|
||||
# Test with no reference text
|
||||
mock_engine._current_llm_reference_text = ""
|
||||
corrected = callback("Some corrupted text")
|
||||
assert corrected == "Some corrupted text" # Should return as-is when no reference
|
||||
128
api/services/workflow/test/test_aggregation_integration.py
Normal file
128
api/services/workflow/test/test_aggregation_integration.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_generation_started_callback,
|
||||
)
|
||||
|
||||
|
||||
class TestAggregationIntegration:
|
||||
"""Integration tests for the TTS aggregation correction flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_reference_text_tracking(self):
|
||||
"""Test that the engine properly tracks LLM reference text."""
|
||||
# Create mock dependencies
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_tts = Mock()
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.start_node_id = "start"
|
||||
mock_workflow.nodes = {
|
||||
"start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
|
||||
}
|
||||
|
||||
# Create engine
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
tts=mock_tts,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Test initial state
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
# Test accumulating LLM text
|
||||
await engine.handle_llm_text_frame("Hello ")
|
||||
assert engine._current_llm_reference_text == "Hello "
|
||||
|
||||
await engine.handle_llm_text_frame("world!")
|
||||
assert engine._current_llm_reference_text == "Hello world!"
|
||||
|
||||
# Test generation started callback clears reference text
|
||||
callback = create_generation_started_callback(engine)
|
||||
await callback()
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregation_correction_callback_creation(self):
|
||||
"""Test creating the aggregation correction callback."""
|
||||
# Create mock engine
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_workflow = Mock()
|
||||
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Set reference text
|
||||
engine._current_llm_reference_text = "Hello, world! How are you?"
|
||||
|
||||
# Create correction callback
|
||||
callback = engine.create_aggregation_correction_callback()
|
||||
|
||||
# Test correction - note that trailing punctuation might be stripped if not in corrupted text
|
||||
corrected = callback("Hello world How are you")
|
||||
assert corrected == "Hello, world! How are you"
|
||||
|
||||
def test_llm_assistant_aggregator_params_with_callback(self):
|
||||
"""Test that LLMAssistantAggregatorParams accepts correction callback."""
|
||||
|
||||
def mock_callback(text: str) -> str:
|
||||
return text.upper()
|
||||
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=mock_callback
|
||||
)
|
||||
|
||||
assert params.expect_stripped_words is True
|
||||
assert params.correct_aggregation_callback is not None
|
||||
assert params.correct_aggregation_callback("hello") == "HELLO"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_callbacks_processor_llm_text_frame(self):
|
||||
"""Test that PipelineEngineCallbacksProcessor handles LLMTextFrame."""
|
||||
from pipecat.frames.frames import LLMTextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
|
||||
# Track callback invocations
|
||||
callback_invoked = False
|
||||
callback_text = None
|
||||
|
||||
async def mock_llm_text_callback(text: str):
|
||||
nonlocal callback_invoked, callback_text
|
||||
callback_invoked = True
|
||||
callback_text = text
|
||||
|
||||
# Create processor with callback
|
||||
processor = PipelineEngineCallbacksProcessor(
|
||||
llm_text_frame_callback=mock_llm_text_callback
|
||||
)
|
||||
|
||||
# Process LLMTextFrame
|
||||
frame = LLMTextFrame(text="Hello world")
|
||||
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Verify callback was invoked
|
||||
assert callback_invoked is True
|
||||
assert callback_text == "Hello world"
|
||||
31
api/services/workflow/test/test_cost_calculator.py
Normal file
31
api/services/workflow/test/test_cost_calculator.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
|
||||
"""Test function to verify cost calculation works"""
|
||||
sample_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 45380,
|
||||
"completion_tokens": 496,
|
||||
"total_tokens": 45876,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
|
||||
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
|
||||
"call_duration_seconds": 179,
|
||||
}
|
||||
|
||||
result = cost_calculator.calculate_total_cost(sample_usage)
|
||||
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
|
||||
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
|
||||
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
|
||||
assert (
|
||||
abs(
|
||||
result["total"]
|
||||
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
|
||||
)
|
||||
< 1e-10
|
||||
)
|
||||
11
api/services/workflow/test/test_dto.py
Normal file
11
api/services/workflow/test/test_dto.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dto():
|
||||
# assert no exceptions are raised
|
||||
with open("services/workflow/test/definitions/rf-1.json", "r") as f:
|
||||
dto = ReactFlowDTO.model_validate_json(f.read())
|
||||
assert dto is not None
|
||||
159
api/services/workflow/test/test_interruption_correction.py
Normal file
159
api/services/workflow/test/test_interruption_correction.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import StartInterruptionFrame
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMContext,
|
||||
)
|
||||
|
||||
|
||||
class TestInterruptionCorrection:
|
||||
"""Test that TTS aggregation correction works during interruptions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_interruption_with_correction(self):
|
||||
"""Test OpenAI assistant context aggregator applies correction during interruption."""
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback
|
||||
def correction_callback(text: str) -> str:
|
||||
# Simulate fixing corrupted text
|
||||
if text == "Hello world how are you":
|
||||
return "Hello world, how are you"
|
||||
return text
|
||||
|
||||
# Create aggregator with correction callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=correction_callback
|
||||
)
|
||||
|
||||
aggregator = OpenAIAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "Hello world how are you"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the corrected text was added to context
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_message = mock_context.add_message.call_args[0][0]
|
||||
assert added_message["role"] == "assistant"
|
||||
assert (
|
||||
added_message["content"]
|
||||
== "Hello world, how are you <<interrupted_by_user>>"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_interruption_with_correction(self):
|
||||
"""Test Google assistant context aggregator applies correction during interruption."""
|
||||
from pipecat.services.google.llm import (
|
||||
Content,
|
||||
GoogleAssistantContextAggregator,
|
||||
)
|
||||
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback
|
||||
def correction_callback(text: str) -> str:
|
||||
# Simulate fixing corrupted text
|
||||
if text == "I am here to help":
|
||||
return "I am here to help"
|
||||
return text
|
||||
|
||||
# Create aggregator with correction callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=correction_callback
|
||||
)
|
||||
|
||||
aggregator = GoogleAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "I am here to help"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the corrected text was added to context
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_content = mock_context.add_message.call_args[0][0]
|
||||
|
||||
# Google uses Content objects
|
||||
assert isinstance(added_content, Content)
|
||||
assert added_content.role == "model"
|
||||
assert len(added_content.parts) == 1
|
||||
assert (
|
||||
added_content.parts[0].text == "I am here to help <<interrupted_by_user>>"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interruption_correction_error_handling(self):
|
||||
"""Test that interruption handling continues even if correction callback fails."""
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback that raises error
|
||||
def failing_callback(text: str) -> str:
|
||||
raise ValueError("Correction failed")
|
||||
|
||||
# Create aggregator with failing callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=failing_callback
|
||||
)
|
||||
|
||||
aggregator = OpenAIAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "Some text"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption - should not raise
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the original text was still added (fallback behavior)
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_message = mock_context.add_message.call_args[0][0]
|
||||
assert added_message["role"] == "assistant"
|
||||
assert added_message["content"] == "Some text <<interrupted_by_user>>"
|
||||
0
api/services/workflow/tools/__init__.py
Normal file
0
api/services/workflow/tools/__init__.py
Normal file
51
api/services/workflow/tools/calculator.py
Normal file
51
api/services/workflow/tools/calculator.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
import ast
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def safe_calculator(expr: str) -> float:
|
||||
"""
|
||||
Parse arithmetic expressions using ast and support + - * / ** and parentheses.
|
||||
"""
|
||||
allowed_nodes = {
|
||||
ast.Expression,
|
||||
ast.BinOp,
|
||||
ast.UnaryOp,
|
||||
ast.Add,
|
||||
ast.Sub,
|
||||
ast.Mult,
|
||||
ast.Div,
|
||||
ast.Pow,
|
||||
ast.USub,
|
||||
ast.UAdd,
|
||||
ast.Constant,
|
||||
ast.Load,
|
||||
ast.Mod,
|
||||
}
|
||||
|
||||
node = ast.parse(expr, mode="eval")
|
||||
if not all(isinstance(n, tuple(allowed_nodes)) for n in ast.walk(node)):
|
||||
raise ValueError("Unsupported expression")
|
||||
return eval(compile(node, "<safe_calculator>", mode="eval"))
|
||||
|
||||
|
||||
def get_calculator_tools() -> list[Dict[str, Any]]:
|
||||
"""Get calculator tool definitions for LLM function calling."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "safe_calculator",
|
||||
"description": "Perform simple arithmetic calculations",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Arithmetic expression to evaluate (supports +, -, *, /, **, %, and parentheses). Example: 2000 + 5000",
|
||||
}
|
||||
},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
196
api/services/workflow/tools/timezone.py
Normal file
196
api/services/workflow/tools/timezone.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"""Time tools for LLM function calling - timezone and time conversion utilities."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TimeResult(BaseModel):
|
||||
"""Result model for time queries."""
|
||||
|
||||
timezone: str
|
||||
datetime: str
|
||||
is_dst: bool
|
||||
|
||||
|
||||
class TimeConversionResult(BaseModel):
|
||||
"""Result model for time conversions."""
|
||||
|
||||
source: TimeResult
|
||||
target: TimeResult
|
||||
time_difference: str
|
||||
|
||||
|
||||
def get_local_timezone(local_tz_override: Optional[str] = None) -> str:
|
||||
"""
|
||||
Get the local timezone name using system timezone.
|
||||
Falls back to UTC if cannot determine.
|
||||
"""
|
||||
if local_tz_override:
|
||||
return local_tz_override
|
||||
|
||||
try:
|
||||
# Try to get timezone from datetime
|
||||
local_tz = datetime.now().astimezone().tzinfo
|
||||
if hasattr(local_tz, "key"):
|
||||
return local_tz.key
|
||||
|
||||
# Try to parse from string representation
|
||||
tz_str = str(local_tz)
|
||||
if tz_str and not tz_str.startswith("UTC"):
|
||||
return tz_str
|
||||
|
||||
# Default to UTC
|
||||
return "UTC"
|
||||
except:
|
||||
return "UTC"
|
||||
|
||||
|
||||
def get_current_time(timezone: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current time in specified timezone.
|
||||
|
||||
Args:
|
||||
timezone: IANA timezone name (e.g., 'America/New_York', 'Europe/London')
|
||||
|
||||
Returns:
|
||||
Dict containing timezone, datetime, and DST status
|
||||
"""
|
||||
try:
|
||||
tz = ZoneInfo(timezone)
|
||||
current_time = datetime.now(tz)
|
||||
|
||||
result = TimeResult(
|
||||
timezone=timezone,
|
||||
datetime=current_time.isoformat(timespec="seconds"),
|
||||
is_dst=bool(current_time.dst()),
|
||||
)
|
||||
return result.model_dump()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid timezone '{timezone}': {str(e)}")
|
||||
|
||||
|
||||
def convert_time(
|
||||
source_timezone: str, time: str, target_timezone: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert time between timezones.
|
||||
|
||||
Args:
|
||||
source_timezone: Source IANA timezone name
|
||||
time: Time to convert in 24-hour format (HH:MM)
|
||||
target_timezone: Target IANA timezone name
|
||||
|
||||
Returns:
|
||||
Dict containing source time, target time, and time difference
|
||||
"""
|
||||
try:
|
||||
source_tz = ZoneInfo(source_timezone)
|
||||
target_tz = ZoneInfo(target_timezone)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid timezone: {str(e)}")
|
||||
|
||||
# Parse time
|
||||
try:
|
||||
parsed_time = datetime.strptime(time, "%H:%M").time()
|
||||
except ValueError:
|
||||
raise ValueError("Invalid time format. Expected HH:MM in 24-hour format")
|
||||
|
||||
# Create datetime objects
|
||||
now = datetime.now(source_tz)
|
||||
source_time = datetime(
|
||||
now.year,
|
||||
now.month,
|
||||
now.day,
|
||||
parsed_time.hour,
|
||||
parsed_time.minute,
|
||||
tzinfo=source_tz,
|
||||
)
|
||||
|
||||
# Convert to target timezone
|
||||
target_time = source_time.astimezone(target_tz)
|
||||
|
||||
# Calculate time difference
|
||||
source_offset = source_time.utcoffset() or timedelta()
|
||||
target_offset = target_time.utcoffset() or timedelta()
|
||||
hours_difference = (target_offset - source_offset).total_seconds() / 3600
|
||||
|
||||
# Format time difference
|
||||
if hours_difference.is_integer():
|
||||
time_diff_str = f"{int(hours_difference):+d}h"
|
||||
else:
|
||||
# For fractional hours like Nepal's UTC+5:45
|
||||
hours = int(hours_difference)
|
||||
minutes = int(abs(hours_difference - hours) * 60)
|
||||
if hours_difference >= 0:
|
||||
time_diff_str = f"+{hours}h{minutes:02d}m"
|
||||
else:
|
||||
time_diff_str = f"{hours}h{minutes:02d}m"
|
||||
|
||||
result = TimeConversionResult(
|
||||
source=TimeResult(
|
||||
timezone=source_timezone,
|
||||
datetime=source_time.isoformat(timespec="seconds"),
|
||||
is_dst=bool(source_time.dst()),
|
||||
),
|
||||
target=TimeResult(
|
||||
timezone=target_timezone,
|
||||
datetime=target_time.isoformat(timespec="seconds"),
|
||||
is_dst=bool(target_time.dst()),
|
||||
),
|
||||
time_difference=time_diff_str,
|
||||
)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
# Tool definitions for LLM function calling
|
||||
def get_time_tools(local_tz_override: Optional[str] = None) -> list[Dict[str, Any]]:
|
||||
"""Get tool definitions with dynamic local timezone."""
|
||||
local_tz = local_tz_override or get_local_timezone()
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_time",
|
||||
"description": "Get current time in a specific timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": f"IANA timezone name (e.g., 'America/New_York', 'Europe/London'). Use '{local_tz}' as local timezone if no timezone provided by the user.",
|
||||
}
|
||||
},
|
||||
"required": ["timezone"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "convert_time",
|
||||
"description": "Convert time between timezones",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source_timezone": {
|
||||
"type": "string",
|
||||
"description": f"Source IANA timezone name (e.g., 'America/New_York', 'Europe/London'). Use '{local_tz}' as local timezone if no source timezone provided by the user.",
|
||||
},
|
||||
"time": {
|
||||
"type": "string",
|
||||
"description": "Time to convert in 24-hour format (HH:MM)",
|
||||
},
|
||||
"target_timezone": {
|
||||
"type": "string",
|
||||
"description": f"Target IANA timezone name (e.g., 'Asia/Tokyo', 'America/San_Francisco'). Use '{local_tz}' as local timezone if no target timezone provided by the user.",
|
||||
},
|
||||
},
|
||||
"required": ["source_timezone", "time", "target_timezone"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
261
api/services/workflow/workflow.py
Normal file
261
api/services/workflow/workflow.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
import re
|
||||
from collections import Counter
|
||||
from typing import Dict, List
|
||||
|
||||
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO, NodeType, ReactFlowDTO
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, source: str, target: str, data: EdgeDataDTO):
|
||||
self.source = source
|
||||
self.target = target
|
||||
|
||||
self.label = data.label
|
||||
self.condition = data.condition
|
||||
|
||||
self.data = data
|
||||
|
||||
def get_function_name(self):
|
||||
return re.sub(r"[^a-z0-9]", "_", self.label.lower())
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Edge):
|
||||
return False
|
||||
return self.source == other.source and self.target == other.target
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.source, self.target))
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self, id: str, node_type: NodeType, data: NodeDataDTO):
|
||||
self.id, self.node_type, self.data = id, node_type, data
|
||||
self.out: Dict[str, "Node"] = {} # forward nodes
|
||||
self.out_edges: List[Edge] = [] # forward edges with properties
|
||||
|
||||
self.name = data.name
|
||||
self.prompt = data.prompt
|
||||
self.is_static = data.is_static
|
||||
self.is_start = data.is_start
|
||||
self.is_end = data.is_end
|
||||
self.allow_interrupt = data.allow_interrupt
|
||||
self.extraction_enabled = data.extraction_enabled
|
||||
self.extraction_prompt = data.extraction_prompt
|
||||
self.extraction_variables = data.extraction_variables
|
||||
self.add_global_prompt = data.add_global_prompt
|
||||
self.wait_for_user_response = data.wait_for_user_response
|
||||
self.wait_for_user_response_timeout = data.wait_for_user_response_timeout
|
||||
self.detect_voicemail = data.detect_voicemail
|
||||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
|
||||
self.data = data
|
||||
|
||||
|
||||
class WorkflowGraph:
|
||||
"""
|
||||
*All* business invariants (acyclic, cardinality, etc.) are verified here.
|
||||
The constructor accepts a validated ReactFlowDTO.
|
||||
"""
|
||||
|
||||
def __init__(self, dto: ReactFlowDTO):
|
||||
# build adjacency list
|
||||
self.nodes: Dict[str, Node] = {
|
||||
n.id: Node(n.id, n.type, n.data) for n in dto.nodes
|
||||
}
|
||||
|
||||
# Store all edges
|
||||
self.edges: List[Edge] = []
|
||||
|
||||
for e in dto.edges:
|
||||
source_node = self.nodes[e.source]
|
||||
target_node = self.nodes[e.target]
|
||||
|
||||
# Create the edge with properties from dto
|
||||
edge = Edge(source=e.source, target=e.target, data=e.data)
|
||||
|
||||
# Add to the edge list
|
||||
self.edges.append(edge)
|
||||
|
||||
# Add to the source node's outgoing edges
|
||||
source_node.out_edges.append(edge)
|
||||
|
||||
# Set up the node references for backward compatibility
|
||||
source_node.out[target_node.id] = target_node
|
||||
|
||||
self._validate_graph()
|
||||
|
||||
# Get a reference to the start node
|
||||
self.start_node_id = [n.id for n in dto.nodes if n.data.is_start][0]
|
||||
|
||||
# Get a reference to the global node
|
||||
try:
|
||||
self.global_node_id = [
|
||||
n.id for n in dto.nodes if n.type == NodeType.globalNode
|
||||
][0]
|
||||
except IndexError:
|
||||
self.global_node_id = None
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# validators
|
||||
# -----------------------------------------------------------
|
||||
def _validate_graph(self) -> None:
|
||||
errors: list[WorkflowError] = []
|
||||
|
||||
# TODO: Figure out what kind of cyclic contraints can be applied, since there can be a cycle in the graph
|
||||
# try:
|
||||
# self._assert_acyclic()
|
||||
# except ValueError as e:
|
||||
# errors.append(
|
||||
# WorkflowError(
|
||||
# kind=ItemKind.workflow, id=None, field=None, message=str(e)
|
||||
# )
|
||||
# )
|
||||
|
||||
errors.extend(self._assert_start_node())
|
||||
errors.extend(self._assert_end_node())
|
||||
errors.extend(self._assert_connection_counts())
|
||||
errors.extend(self._assert_global_node())
|
||||
errors.extend(self._assert_node_configs())
|
||||
if errors:
|
||||
raise ValueError(errors)
|
||||
|
||||
def _assert_acyclic(self):
|
||||
color: Dict[str, str] = {} # white / gray / black
|
||||
|
||||
def dfs(n: Node):
|
||||
if color.get(n.id) == "gray": # back-edge
|
||||
raise ValueError("workflow contains a cycle")
|
||||
if color.get(n.id) != "black":
|
||||
color[n.id] = "gray"
|
||||
for m in n.out.values():
|
||||
dfs(m)
|
||||
color[n.id] = "black"
|
||||
|
||||
for n in self.nodes.values():
|
||||
dfs(n)
|
||||
|
||||
def _assert_start_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
start_node = [n for n in self.nodes.values() if n.data.is_start]
|
||||
if not start_node:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have exactly one start node",
|
||||
)
|
||||
)
|
||||
elif len(start_node) > 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have exactly one start node",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_global_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
global_node = [
|
||||
n for n in self.nodes.values() if n.node_type == NodeType.globalNode
|
||||
]
|
||||
if not len(global_node) <= 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have at most one global node",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_end_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
end_node = [n for n in self.nodes.values() if n.data.is_end]
|
||||
if not end_node:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have exactly one end node",
|
||||
)
|
||||
)
|
||||
|
||||
elif len(end_node) > 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have exactly one end node",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_connection_counts(self):
|
||||
errors: list[WorkflowError] = []
|
||||
|
||||
out_deg = Counter()
|
||||
in_deg = Counter()
|
||||
for n in self.nodes.values(): # init counters
|
||||
out_deg[n.id] = in_deg[n.id] = 0
|
||||
for src, n in self.nodes.items(): # compute degrees
|
||||
for m in n.out.values():
|
||||
out_deg[src] += 1
|
||||
in_deg[m.id] += 1
|
||||
|
||||
for n in self.nodes.values():
|
||||
in_d, out_d = in_deg[n.id], out_deg[n.id]
|
||||
|
||||
match n.node_type:
|
||||
case NodeType.startNode:
|
||||
if in_d != 0 or out_d < 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=n.id,
|
||||
field=None,
|
||||
message=f"StartNode must have at least 1 outgoing edge",
|
||||
)
|
||||
)
|
||||
case NodeType.endNode:
|
||||
if in_d < 1 or out_d != 0:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=n.id,
|
||||
field=None,
|
||||
message=f"EndNode must have at least 1 incoming edge",
|
||||
)
|
||||
)
|
||||
case NodeType.agentNode:
|
||||
if in_d < 1 or out_d < 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=n.id,
|
||||
field=None,
|
||||
message=f"Worker must have at least 1 incoming and 1 outgoing edge",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
def _assert_node_configs(self):
|
||||
"""Validate node-specific configuration constraints."""
|
||||
errors: list[WorkflowError] = []
|
||||
|
||||
for node in self.nodes.values():
|
||||
# Validate StartNode constraints
|
||||
if node.node_type == NodeType.startNode:
|
||||
# No specific validations for start node at this time
|
||||
pass
|
||||
|
||||
return errors
|
||||
Loading…
Add table
Add a link
Reference in a new issue