feat: simplify pipecat engine execution (#54)

This commit is contained in:
Abhishek 2025-11-15 17:38:27 +05:30 committed by GitHub
parent 99a768f291
commit 6ce25a589c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 52 additions and 1405 deletions

View file

@ -1,4 +1,4 @@
langfuse==3.4.0
langfuse==3.9.3
fastapi==0.116.2
asyncpg==0.30.0
alembic==1.16.5

View file

@ -24,6 +24,9 @@ from api.services.workflow.dto import ReactFlowDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.processors.filters.stt_mute_filter import (
STTMuteConfig,
STTMuteFilter,
@ -83,7 +86,8 @@ class LoopTalkPipelineBuilder:
audio_buffer, audio_synchronizer, transcript, context = (
create_pipeline_components(audio_config)
)
context_aggregator = llm.create_context_aggregator(context)
context_aggregator = LLMContextAggregatorPair(context)
# Get workflow graph
workflow_graph = WorkflowGraph(
@ -113,7 +117,6 @@ class LoopTalkPipelineBuilder:
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
max_call_duration_seconds=300,
max_duration_end_task_callback=engine.create_max_duration_callback(),
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
generation_started_callback=engine.create_generation_started_callback(),
)

View file

@ -272,14 +272,6 @@ class LoopTalkTestOrchestrator:
await task.cancel()
# Connect the context aggregator events to engine
@assistant_context_aggregator.event_handler("on_push_aggregation")
async def on_assistant_aggregator_push_context(_aggregator):
logger.debug(
"Assistant aggregator push context flushing pending transitions"
)
await engine.flush_pending_transitions()
# Register custom audio and transcript handlers for LoopTalk
await self._register_looptalk_handlers(
audio_synchronizer, transcript, test_session_id, role

View file

@ -1,69 +0,0 @@
"""Engine Pre-Aggregator Processor
This processor sits before the user context aggregator in the pipeline and handles
engine-specific callbacks for frames that need to be processed before aggregation.
This ensures the engine can update context before the aggregator generates LLM frames.
"""
from typing import Awaitable, Callable, Optional
from loguru import logger
from api.services.pipecat.exceptions import VoicemailDetectedException
from pipecat.frames.frames import (
Frame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class EnginePreAggregatorProcessor(FrameProcessor):
"""
Processor that handles engine callbacks before user context aggregation.
This processor is positioned before the user context aggregator to ensure
the engine can update LLM context before aggregation occurs.
"""
def __init__(
self,
user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self._user_started_speaking_callback = user_started_speaking_callback
self._user_stopped_speaking_callback = user_stopped_speaking_callback
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# Handle frames that need engine processing before aggregation
if isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking()
elif isinstance(frame, UserStoppedSpeakingFrame):
try:
await self._handle_user_stopped_speaking()
except VoicemailDetectedException:
# We have detected voicemail, lets not
# forward the UserStoppedSpeakingFrame, so that
# we don't issue an llm call from user context
# aggregator
logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame")
return
# Always push the frame downstream
await self.push_frame(frame, direction)
async def _handle_user_started_speaking(self):
"""Handle UserStartedSpeakingFrame before aggregation."""
if self._user_started_speaking_callback:
# logger.debug("Engine pre-aggregator: User started speaking")
await self._user_started_speaking_callback()
async def _handle_user_stopped_speaking(self):
"""Handle UserStoppedSpeakingFrame before aggregation."""
if self._user_stopped_speaking_callback:
# logger.debug("Engine pre-aggregator: User stopped speaking")
await self._user_stopped_speaking_callback()

View file

@ -9,7 +9,7 @@ from api.constants import (
from api.services.pipecat.audio_config import AudioConfig
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
from pipecat.processors.transcript_processor import TranscriptProcessor
@ -39,7 +39,7 @@ def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine
assistant_correct_aggregation_callback=engine.create_aggregation_correction_callback()
)
context = OpenAILLMContext()
context = LLMContext()
return audio_buffer, audio_synchronizer, transcript, context
@ -58,7 +58,6 @@ def build_pipeline(
stt_mute_filter,
pipeline_metrics_aggregator,
user_idle_disconnect,
engine_pre_aggregator_processor=None,
):
"""Build the main pipeline with all components"""
# Register processors with synchronizer for merged audio
@ -69,16 +68,12 @@ def build_pipeline(
processors = [
transport.input(), # Transport user input
audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame)
stt_mute_filter,
stt, # STT can now have audio_passthrough=False
stt_mute_filter, # STTMuteFilters don't let VAD related events pass through if muted
user_idle_disconnect,
transcript.user(),
]
# Insert engine pre-aggregator processor if provided (before user aggregator)
if engine_pre_aggregator_processor:
processors.append(engine_pre_aggregator_processor)
processors.extend(
[
user_context_aggregator,

View file

@ -7,7 +7,6 @@ from pipecat.frames.frames import (
Frame,
HeartbeatFrame,
LLMFullResponseStartFrame,
LLMGeneratedTextFrame,
LLMTextFrame,
StartFrame,
TTSSpeakFrame,
@ -26,7 +25,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
self,
max_call_duration_seconds: int = 300,
max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None,
llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None,
generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None,
llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None,
):
@ -34,7 +32,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
self._start_time = None
self._max_call_duration_seconds = max_call_duration_seconds
self._max_duration_end_task_callback = max_duration_end_task_callback
self._llm_generated_text_callback = llm_generated_text_callback
self._generation_started_callback = generation_started_callback
self._llm_text_frame_callback = llm_text_frame_callback
self._end_task_frame_pushed = False
@ -46,8 +43,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
await self._start(frame)
elif isinstance(frame, HeartbeatFrame):
await self._check_call_duration()
elif isinstance(frame, LLMGeneratedTextFrame):
await self._generated_text_frame(frame)
elif isinstance(frame, LLMFullResponseStartFrame):
await self._generation_started()
elif (
@ -74,11 +69,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
"Max call duration exceeded. Skipping EndTaskFrame since already sent"
)
async def _generated_text_frame(self, _: LLMGeneratedTextFrame):
"""Handle LLMGeneratedTextFrame."""
if self._llm_generated_text_callback is not None:
await self._llm_generated_text_callback()
async def _generation_started(self):
if self._generation_started_callback:
await self._generation_started_callback()

View file

@ -7,9 +7,6 @@ from api.db import db_client
from api.db.models import WorkflowModel
from api.enums import WorkflowRunMode
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
from api.services.pipecat.engine_pre_aggregator_processor import (
EnginePreAggregatorProcessor,
)
from api.services.pipecat.event_handlers import (
register_audio_data_handler,
register_task_event_handler,
@ -43,6 +40,9 @@ from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.pipeline.runner import PipelineRunner
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.processors.filters.stt_mute_filter import (
STTMuteConfig,
STTMuteFilter,
@ -357,21 +357,14 @@ async def _run_pipeline(
expect_stripped_words=True,
correct_aggregation_callback=engine.create_aggregation_correction_callback(),
)
context_aggregator = llm.create_context_aggregator(
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
# Create engine pre-aggregator processor for speaking events
engine_pre_aggregator_processor = EnginePreAggregatorProcessor(
user_started_speaking_callback=engine.create_user_started_speaking_callback(),
user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(),
)
# Create usage metrics aggregator with engine's callback
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
max_call_duration_seconds=max_call_duration_seconds,
max_duration_end_task_callback=engine.create_max_duration_callback(),
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
generation_started_callback=engine.create_generation_started_callback(),
llm_text_frame_callback=engine.handle_llm_text_frame,
# Note: speaking event callbacks are now handled by pre-aggregator processor
@ -398,11 +391,6 @@ async def _run_pipeline(
user_context_aggregator = context_aggregator.user()
assistant_context_aggregator = context_aggregator.assistant()
@assistant_context_aggregator.event_handler("on_push_aggregation")
async def on_assistant_aggregator_push_context(_aggregator):
logger.debug("Assistant aggregator push context flushing pending transitions")
await engine.flush_pending_transitions(source="context_push")
# Build the pipeline with the STT mute filter and context controller
pipeline = build_pipeline(
transport,
@ -418,7 +406,6 @@ async def _run_pipeline(
stt_mute_filter,
pipeline_metrics_aggregator,
user_idle_disconnect,
engine_pre_aggregator_processor=engine_pre_aggregator_processor,
)
# Create pipeline task with audio configuration

View file

@ -14,14 +14,14 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.enums import EndTaskReason
@ -63,7 +63,7 @@ class PipecatEngine:
*,
task: Optional[PipelineTask] = None,
llm: Optional["LLMService"] = None,
context: Optional[OpenAILLMContext] = None,
context: Optional[LLMContext] = None,
tts: Optional[Any] = None,
transport: Optional[BaseTransport] = None,
workflow: WorkflowGraph,
@ -82,7 +82,6 @@ class PipecatEngine:
self._workflow_run_id = workflow_run_id
self._initialized = False
self._client_disconnected = False
self._pending_function_calls = 0
self._current_node: Optional[Node] = None
self._gathered_context: dict = {}
self._user_response_timeout_task: Optional[asyncio.Task] = None
@ -102,29 +101,9 @@ class PipecatEngine:
self._voicemail_detector = None
self._voicemail_detection_task: Optional[asyncio.Task] = None
# This transition is generated by the llm as part of tool call. This can
# also be accompanied with some content which can be played using TTS. If the
# bot is interrupted, we would cancel this transition (we do cancel this currently when
# the next generation starts in handle_generation_started callback handler.)
self._pending_generated_transition_after_context_push: Optional[
Callable[[], Awaitable[None]]
] = None
# This is the transtion which is typically programmatic transition, and not goes as
# tool call to LLM. This is not interrupted by the user and is done on context push
self._pending_control_transition_after_context_push: Optional[
Callable[[], Awaitable[None]]
] = None
# Flag to determine if the current llm generation has a text completion
self._defer_context_push: bool = False
# Lazy loaded built-in function schemas
self._builtin_function_schemas: Optional[list[dict]] = None
# Flag to control whether to queue context frame
self._queue_context_frame: bool = True
# Track current LLM reference text for TTS aggregation correction
self._current_llm_reference_text: str = ""
@ -211,23 +190,15 @@ class PipecatEngine:
async def _create_transition_func(self, name: str, transition_to_node: str):
async def transition_func(function_call_params: FunctionCallParams) -> None:
"""Inner function that handles the actual tool invocation."""
"""Inner function that handles the node change tool calls"""
try:
# Track pending function call
self._pending_function_calls += 1
logger.debug(
f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
)
# For edge functions, prevent LLM completion until transition (run_llm=False)
# For node functions, allow immediate completion (run_llm=True)
async def on_context_updated() -> None:
"""
Framework will run this function after the function call result has been updated in the context.
pipecat framework will run this function after the function call result has been updated in the context.
This way, when we do set_node from within this function, and go for LLM completion with updated
system prompts, the context is updated with function call result.
"""
self._pending_function_calls -= 1
# Perform variable extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(
self._current_node
@ -241,41 +212,14 @@ class PipecatEngine:
on_context_updated=on_context_updated,
)
async def _invoke_result_callback():
"""
Functions are executed immediately when they come from LLM as part of text completion.
But, if the LLM completion also has some text, we would want to not call the function if the user interrupts the speech.
We would also not want the function to be added to context, so that the LLM can call the function again. Hence, we
defer the function invocation until we receive on_context_updated callback, i.e the bot has finished speaking
the text that was generated.
"""
await function_call_params.result_callback(
result, properties=properties
)
if self._defer_context_push:
"""
We set the flag to _defer_context_push when we receive text in the current generation from LLM.
This is set in the handle_llm_generated_text callback handler.
"""
logger.debug(
"Deferring transition function result until context push"
)
# Only one deferred transition should exist at any time.
# Overwrite if one is somehow already set (unexpected).
self._pending_generated_transition_after_context_push = (
_invoke_result_callback
)
else:
"""
If there was no text in the current generation, and we only had function call,
lets invoke the result callback, so that framework can call on_context_updated and
we can do switch node.
"""
await _invoke_result_callback()
# Call results callback from the pipecat framework
# so that a new llm generation can be triggred if
# required
await function_call_params.result_callback(
result, properties=properties
)
except Exception as e:
logger.error(f"Error in transition function {name}: {str(e)}")
self._pending_function_calls = 0
error_result = {"status": "error", "error": str(e)}
await function_call_params.result_callback(error_result)
@ -362,27 +306,6 @@ class PipecatEngine:
]
)
async def _setup_static_start_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static start nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
if not node.wait_for_user_response:
# Normal static start node - transition immediately after context push
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -441,17 +364,7 @@ class PipecatEngine:
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
# Queue context frame if needed
if self._queue_context_frame:
await self.task.queue_frame(OpenAILLMContextFrame(self.context))
else:
logger.debug(
f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
)
# Reset _queue_context_frame as default behavior
self._queue_context_frame = True
await self.task.queue_frame(LLMContextFrame(self.context))
async def set_node(self, node_id: str):
"""
@ -525,12 +438,7 @@ class PipecatEngine:
await asyncio.sleep(delay_duration)
if node.is_static:
# Queue TTS for static start node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static start nodes
await self._setup_static_start_node_transition(node)
raise ValueError("Static nodes are not supported!")
else:
# Start generation for non-static start node
await self._setup_llm_context_and_start_generation(node)
@ -538,66 +446,24 @@ class PipecatEngine:
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:
# Queue TTS for static end node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
raise ValueError("Static nodes are not supported!")
else:
# Start generation for non-static end node
await self._setup_llm_context_and_start_generation(node)
# If this end node has extraction enabled, perform extraction immediately
if node.extraction_enabled and node.extraction_variables:
await self._perform_variable_extraction_if_needed(node)
# TODO: Extract disposition code from extracted variables
# Defer send_end_task_frame using _pending_control_transition_after_context_push
# Decide the end-task reason dynamically depending on call_disposition.
async def _deferred_end_task():
# call_disposition is the disposition which is generated from
# llm call based on the conversation so far.
# TODO: Make this more generic based on configuration or llm prompting
disposition = self._gathered_context.get("call_disposition")
if disposition == "XFER":
reason = EndTaskReason.USER_QUALIFIED.value
else:
reason = EndTaskReason.USER_DISQUALIFIED.value
await self.send_end_task_frame(reason)
self._pending_control_transition_after_context_push = _deferred_end_task
await self.send_end_task_frame(EndTaskReason.USER_QUALIFIED.value)
async def _handle_agent_node(self, node: Node) -> None:
"""Handle agent node execution."""
if node.is_static:
# Queue TTS for static agent node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static agent nodes
await self._setup_agent_node_transition(node)
raise ValueError("Static nodes are not supported!")
else:
# Set context and functions for non-static agent node
await self._setup_llm_context_and_start_generation(node)
async def _setup_agent_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static agent nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def send_end_task_frame(
self,
reason: str,
@ -640,7 +506,7 @@ class PipecatEngine:
# Store the mapped disconnect reason
self._gathered_context["call_disposition"] = mapped_disposition
# TODO: Generalise this, currently tailored to Kapil's use case
# TODO: Generalise this
self._gathered_context["address"] = ", ".join(
[
self._call_context_vars.get("address1", ""),
@ -759,55 +625,6 @@ class PipecatEngine:
return system_message, functions
# ------------------------------------------------------------------
# Pending transition handling
# ------------------------------------------------------------------
async def flush_pending_transitions(self, *, source: str = "context_push"):
"""Execute and clear any pending transitions.
Args:
source: Indicates the trigger that caused this flush:
- "context_push": the assistant context aggregator completed a push.
"""
if source != "context_push":
raise ValueError("Invalid flush source expected 'context_push'")
len_pending_functions = 0
if self._pending_generated_transition_after_context_push is not None:
len_pending_functions += 1
if self._pending_control_transition_after_context_push is not None:
len_pending_functions += 1
# Nothing to do
if len_pending_functions == 0:
return
logger.debug(
f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
)
# Generated transition
if self._pending_generated_transition_after_context_push is not None:
pending_cb = self._pending_generated_transition_after_context_push
self._pending_generated_transition_after_context_push = None
try:
await pending_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred transition: {exc}")
# Control transition (context push)
if self._pending_control_transition_after_context_push is not None:
logger.debug("Executing control transition after context push")
static_cb = self._pending_control_transition_after_context_push
self._pending_control_transition_after_context_push = None
try:
await static_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred static node transition: {exc}")
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
"""
This callback is called by STTMuteFilter to determine if the STT should be muted.
@ -828,15 +645,6 @@ class PipecatEngine:
"""
return engine_callbacks.create_max_duration_callback(self)
def create_llm_generated_text_callback(self):
"""
This callback is called when some text is generated by the LLM.
We use this to defer the result_callback of the node transition functions if
there is set_node called along with some text generated. This way, we will
have the context sent in the next generation from new node.
"""
return engine_callbacks.create_llm_generated_text_callback(self)
def create_generation_started_callback(self):
"""
This callback is called when a new generation starts.
@ -844,26 +652,12 @@ class PipecatEngine:
"""
return engine_callbacks.create_generation_started_callback(self)
def create_user_stopped_speaking_callback(self):
"""
This callback is called when the user stops speaking.
We use this to handle transitions when wait_for_user_response is enabled.
"""
return engine_callbacks.create_user_stopped_speaking_callback(self)
def create_user_started_speaking_callback(self):
"""
This callback is called when the user starts speaking.
We use this to handle wait_for_user_greeting functionality.
"""
return engine_callbacks.create_user_started_speaking_callback(self)
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
"""Create a callback that corrects corrupted aggregation using reference text."""
return engine_callbacks.create_aggregation_correction_callback(self)
def set_context(self, context: OpenAILLMContext) -> None:
"""Set the OpenAI LLM context.
def set_context(self, context: LLMContext) -> None:
"""Set the LLM context.
This allows setting the context after the engine has been created,
which is useful when the context needs to be created after the engine.

View file

@ -14,6 +14,7 @@ import re
from typing import TYPE_CHECKING, Awaitable, Callable
from loguru import logger
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@ -23,9 +24,8 @@ from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from pipecat.processors.user_idle_processor import UserIdleProcessor
from api.services.workflow.pipecat_engine import PipecatEngine
from pipecat.processors.user_idle_processor import UserIdleProcessor
# ---------------------------------------------------------------------------
@ -114,23 +114,6 @@ def create_max_duration_callback(engine: "PipecatEngine"):
return handle_max_duration
# ---------------------------------------------------------------------------
# LLM-generated-text handling
# ---------------------------------------------------------------------------
def create_llm_generated_text_callback(engine: "PipecatEngine"):
"""Return a callback invoked when the LLM emits text (not only tool calls)."""
async def handle_llm_generated_text(): # noqa: D401
logger.debug(
"Generation has text content in current response - deferring context push from set_node"
)
engine._defer_context_push = True
return handle_llm_generated_text
# ---------------------------------------------------------------------------
# Generation-started handling
# ---------------------------------------------------------------------------
@ -140,96 +123,13 @@ def create_generation_started_callback(engine: "PipecatEngine"):
"""Return a callback that resets flags at the start of each LLM generation."""
async def handle_generation_started(): # noqa: D401
logger.debug("LLM generation started - resetting defer flags and tool counters")
engine._defer_context_push = False
engine._pending_function_calls = 0
engine._pending_generated_transition_after_context_push = None
logger.debug("LLM generation started in callback processor")
# Clear reference text from previous generation
engine._current_llm_reference_text = ""
return handle_generation_started
# ---------------------------------------------------------------------------
# User-stopped-speaking handling
# ---------------------------------------------------------------------------
def create_user_stopped_speaking_callback(engine: "PipecatEngine"):
"""Return a callback that handles when the user stops speaking.
According to simplified flow:
- For start nodes with wait_for_user_response=True:
- Cancel timeout task if still active
- Transition to next node with _queue_context_frame=False
"""
async def handle_user_stopped_speaking():
# Only handle if current node is a start node with wait_for_user_response
if (
engine._current_node
and engine._current_node.is_start
and engine._current_node.wait_for_user_response
and engine._current_node.out_edges
):
# Cancel timeout task if it's still active
if (
engine._user_response_timeout_task
and not engine._user_response_timeout_task.done()
):
logger.debug("Cancelling user response timeout - user responded")
engine._user_response_timeout_task.cancel()
engine._user_response_timeout_task = None
# Transition to next node
next_node_id = engine._current_node.out_edges[0].target
logger.debug(
f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}"
)
# Set flag to not queue context frame since
# it will be pushed by user context aggregator
# we are just setting the context with next node's
# functions and prompts
engine._queue_context_frame = False
# Transition to next node
await engine.set_node(next_node_id)
return handle_user_stopped_speaking
# ---------------------------------------------------------------------------
# User-started-speaking handling
# ---------------------------------------------------------------------------
def create_user_started_speaking_callback(engine: "PipecatEngine"):
"""Return a callback that handles when the user starts speaking.
According to simplified flow:
- For start nodes with wait_for_user_response=True:
- Cancel the timeout timer if it exists (but don't set to None)
"""
async def handle_user_started_speaking():
# Only handle if current node is a start node with wait_for_user_response
if (
engine._current_node
and engine._current_node.is_start
and engine._current_node.wait_for_user_response
and engine._user_response_timeout_task
and not engine._user_response_timeout_task.done()
):
logger.debug(
"User started speaking during wait_for_user_response - cancelling timeout timer"
)
engine._user_response_timeout_task.cancel()
# Don't set to None here - let user_stopped_speaking handle the transition
return handle_user_started_speaking
def create_aggregation_correction_callback(engine: "PipecatEngine"):
"""Create a callback that uses engine's reference text to correct corrupted aggregation."""

View file

@ -2,16 +2,10 @@ from __future__ import annotations
from typing import Any, Dict, List
from google.genai.types import (
Content,
Part,
)
from api.utils.template_renderer import render_template
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.services.google.llm import GoogleLLMContext
from pipecat.services.openai.llm import OpenAILLMContext
from api.utils.template_renderer import render_template
from pipecat.processors.aggregators.llm_context import LLMContext
__all__ = [
"get_function_schema",
@ -44,7 +38,7 @@ def get_function_schema(
def update_llm_context(
context: OpenAILLMContext,
context: LLMContext,
system_message: Dict[str, Any],
functions: List[FunctionSchema],
) -> None:
@ -59,21 +53,6 @@ def update_llm_context(
# associated with the current LLM service can convert them to the correct
# provider-specific representation when required.
tools_schema = ToolsSchema(standard_tools=functions)
if isinstance(context, GoogleLLMContext):
context.system_message = system_message["content"]
if functions:
# Lets only call set_tools if we have functions, else Gemini will
# throw an exception
context.set_tools(tools_schema)
if context.messages[-1].role != "user":
# Google expects the last message should end with user message
context.add_message(Content(role="user", parts=[Part(text="...")]))
return
# In case of OpenAILLMContext, replace the system message with incoming system message
previous_interactions = context.messages
# Filter out old system messages but keep user/assistant/function content.

View file

@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Any, List
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import trace
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.services.workflow.dto import ExtractionVariableDTO
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
@ -139,7 +139,7 @@ class VariableExtractionManager:
f"{conversation_history}"
)
extraction_context = OpenAILLMContext()
extraction_context = LLMContext()
extraction_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
@ -171,7 +171,7 @@ class VariableExtractionManager:
service_name="OpenAILLMService",
model=self._model,
operation_name="variable_extraction",
messages=json.dumps(extraction_messages),
messages=extraction_messages,
output=llm_response,
stream=False,
parameters={"temperature": 0.0, "response_format": "json_object"},

View file

@ -44,8 +44,6 @@ class Node:
self.extraction_prompt = data.extraction_prompt
self.extraction_variables = data.extraction_variables
self.add_global_prompt = data.add_global_prompt
self.wait_for_user_response = data.wait_for_user_response
self.wait_for_user_response_timeout = data.wait_for_user_response_timeout
self.detect_voicemail = data.detect_voicemail
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration

View file

@ -3,12 +3,12 @@ import os
import aiohttp
import httpx
from loguru import logger
from pipecat.utils.context import set_current_run_id
from api.db import db_client
from api.db.models import IntegrationModel
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
from api.utils.template_renderer import render_template
from pipecat.utils.context import set_current_run_id
async def run_integrations_post_workflow_run(ctx, workflow_run_id: int):
@ -162,7 +162,7 @@ async def _process_slack_integration(
"""
logger.info(f"Processing Slack integration {integration.id}")
# TODO: Generalise this, currently tailored to Kapil's use case
# TODO: Generalise this
if gathered_context.get("mapped_call_disposition") != "XFER":
logger.debug(
f"Not sending message on slack since not XFER: {gathered_context.get('mapped_call_disposition')}"

View file

@ -1,179 +0,0 @@
### - This test has some weird loop which keeps on increasing the context size
# import asyncio
# import json
# import unittest
# from types import SimpleNamespace
# from unittest import mock
# from loguru import logger
# from pipecat.frames.frames import (
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# FunctionCallsStartedFrame,
# LLMFullResponseEndFrame,
# LLMFullResponseStartFrame,
# LLMGeneratedTextFrame,
# LLMTextFrame,
# )
# from pipecat.pipeline.pipeline import Pipeline
# from pipecat.processors.aggregators.openai_llm_context import (
# OpenAILLMContext,
# OpenAILLMContextFrame,
# )
# from pipecat.services.llm_service import (
# FunctionCallParams,
# FunctionCallResultProperties,
# )
# from pipecat.services.openai.llm import OpenAILLMService
# from pipecat.tests.utils import run_test
# class _MockAsyncStream:
# """A minimal async-stream wrapper that mimics ``openai.AsyncStream``."""
# def __init__(self, chunks):
# self._chunks = chunks
# def __aiter__(self):
# self._idx = 0
# return self
# async def __anext__(self):
# if self._idx >= len(self._chunks):
# raise StopAsyncIteration
# item = self._chunks[self._idx]
# self._idx += 1
# await asyncio.sleep(0) # Yield control
# return item
# # ------------------------------------------------------------------
# # Factories for mock chunks
# # ------------------------------------------------------------------
# def _make_tool_call(tool_name: str, args_json: str, *, idx: int = 0):
# function = SimpleNamespace(name=tool_name, arguments=args_json)
# return SimpleNamespace(index=idx, id=f"call-{idx}", function=function)
# def _make_chunk(*, content: str | None = None, tool_calls=None, usage=None):
# delta = SimpleNamespace()
# # When we are asked to simulate multiple tool calls in parallel, OpenAI
# # sends *separate* chunks for every tool-call index. To mimic that behaviour
# # in tests we split a list of tool calls (>1) into individual chunks one
# # for each tool call while keeping the original single-chunk behaviour
# # when zero or one tool calls are supplied. This enables us to write
# # concise tests such as ``_make_chunk(tool_calls=[call_1, call_2])`` that
# # accurately reflect the streaming protocol.
# # No special handling needed if there is textual content or 0/1 tool calls.
# if content is not None or tool_calls is None or len(tool_calls) <= 1:
# if content is not None:
# delta.content = content
# # Always set tool_calls so downstream code can safely access it
# delta.tool_calls = tool_calls if tool_calls is not None else None
# return SimpleNamespace(choices=[SimpleNamespace(delta=delta)], usage=usage)
# # --- Multiple tool calls (len(tool_calls) > 1) ---
# # Create a list of chunks, each containing a single tool call. This is the
# # format produced by the OpenAI client when several tools are invoked in a
# # single assistant response.
# chunks = []
# for tc in tool_calls:
# delta_tc = SimpleNamespace(tool_calls=[tc])
# chunks.append(SimpleNamespace(choices=[SimpleNamespace(delta=delta_tc)], usage=usage))
# return chunks
# class TestBaseOpenAILLMService(unittest.IsolatedAsyncioTestCase):
# async def test_process_context_with_patch(self):
# streamed_text = "Hello from OpenAI!"
# tool_name = "echo"
# tool_name_2 = "echo_2"
# tool_args = {"text": "hello"}
# tool_args_2 = {"text": "hello_2"}
# # Build mocked stream (tool call first, then text)
# chunks = [
# _make_chunk(content=streamed_text),
# _make_chunk(tool_calls=[_make_tool_call(tool_name, json.dumps(tool_args))]),
# _make_chunk(tool_calls=[_make_tool_call(tool_name_2, json.dumps(tool_args_2), idx=1)]),
# ]
# # Instantiate real OpenAILLMService (no need for actual API key)
# llm = OpenAILLMService(model="gpt-4o-mini", api_key="test")
# # Patch get_chat_completions to return our mocked async stream
# async def fake_get_chat_completions(self, context, messages): # noqa: D401
# return _MockAsyncStream(chunks)
# with mock.patch.object(llm.__class__, "get_chat_completions", fake_get_chat_completions):
# # Register echo tool
# executed = False
# async def echo_handler(params: FunctionCallParams):
# nonlocal executed
# executed = True
# # sleep for 1 second
# logger.info("echo_handler: sleeping for 5 second")
# await asyncio.sleep(5)
# await params.result_callback(
# {"ok": True},
# properties=FunctionCallResultProperties(run_llm=True),
# )
# async def echo_2_handler(params: FunctionCallParams):
# nonlocal executed
# executed = True
# # sleep for 1 second
# logger.info("echo_2_handler: sleeping for 5 second")
# await asyncio.sleep(5)
# await params.result_callback(
# {"ok": True},
# properties=FunctionCallResultProperties(run_llm=True),
# )
# llm.register_function(tool_name, echo_handler)
# llm.register_function(tool_name_2, echo_2_handler)
# # Prepare context and send
# context = OpenAILLMContext()
# context.add_message({"role": "user", "content": "Hi"})
# frames_to_send = [OpenAILLMContextFrame(context)]
# expected_down_frames = [
# LLMFullResponseStartFrame,
# FunctionCallsStartedFrame,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# LLMGeneratedTextFrame,
# LLMTextFrame,
# LLMFullResponseEndFrame,
# ]
# context_aggregator = llm.create_context_aggregator(context)
# pipeline = Pipeline([llm, context_aggregator.assistant()])
# down_frames, _ = await run_test(
# pipeline,
# frames_to_send=frames_to_send,
# expected_down_frames=expected_down_frames,
# send_end_frame=False,
# )
# # Assertions
# self.assertTrue(executed)
# for fr in down_frames:
# if isinstance(fr, FunctionCallResultFrame):
# self.assertTrue(fr.run_llm)
# if isinstance(fr, LLMTextFrame):
# self.assertEqual(fr.text, streamed_text)
# if __name__ == "__main__":
# unittest.main()

View file

@ -1,143 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify that LLMGeneratedTextFrame signaling works correctly
with the new local variable approach.
"""
def test_local_variable_logic():
"""Test the core logic using the same pattern as the implementation"""
print("=== Testing Local Variable Logic ===")
# Simulate the logic from _process_context
text_generation_signaled = False
frames_sent = []
# Simulate chunks with text content
chunks_with_content = ["Hello", " world", "!"]
for content in chunks_with_content:
# This is the exact logic from our implementation
if content: # equivalent to chunk.choices[0].delta.content
if not text_generation_signaled:
frames_sent.append("LLMGeneratedTextFrame")
text_generation_signaled = True
frames_sent.append(f"LLMTextFrame({content})")
print(f"Frames sent: {frames_sent}")
# Verify behavior
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
text_frames = [f for f in frames_sent if f.startswith("LLMTextFrame")]
assert len(generated_signals) == 1, (
f"Expected 1 signal, got {len(generated_signals)}"
)
assert len(text_frames) == 3, f"Expected 3 text frames, got {len(text_frames)}"
assert frames_sent[0] == "LLMGeneratedTextFrame", "Signal should be first"
print("✅ Local variable logic works correctly")
return True
def test_no_text_logic():
"""Test that no signal is sent when there's no text"""
print("\n=== Testing No Text Logic ===")
text_generation_signaled = False
frames_sent = []
# Simulate chunks with no text content (function calls only)
chunks_with_content = [None, None, None] # No text content
for content in chunks_with_content:
if content: # This will be False for all chunks
if not text_generation_signaled:
frames_sent.append("LLMGeneratedTextFrame")
text_generation_signaled = True
frames_sent.append(f"LLMTextFrame({content})")
print(f"Frames sent: {frames_sent}")
assert len(frames_sent) == 0, f"Expected no frames, got {frames_sent}"
print("✅ No signal sent when no text content")
return True
def test_mixed_content_logic():
"""Test behavior with mixed function calls and text"""
print("\n=== Testing Mixed Content Logic ===")
text_generation_signaled = False
frames_sent = []
# Simulate chunks: function call, text, function call, text
chunks = [
{"type": "function", "content": None},
{"type": "text", "content": "Hello"},
{"type": "function", "content": None},
{"type": "text", "content": " world"},
]
for chunk in chunks:
if chunk["type"] == "function":
frames_sent.append("FunctionCallFrame")
elif chunk["content"]: # text content
if not text_generation_signaled:
frames_sent.append("LLMGeneratedTextFrame")
text_generation_signaled = True
frames_sent.append(f"LLMTextFrame({chunk['content']})")
print(f"Frames sent: {frames_sent}")
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
assert len(generated_signals) == 1, (
f"Expected 1 signal, got {len(generated_signals)}"
)
# Signal should come before first text frame but after any function frames
signal_index = frames_sent.index("LLMGeneratedTextFrame")
first_text_index = next(
i for i, f in enumerate(frames_sent) if f.startswith("LLMTextFrame")
)
assert signal_index == first_text_index - 1, (
"Signal should come right before first text"
)
print("✅ Mixed content logic works correctly")
return True
def main():
try:
test1_result = test_local_variable_logic()
test2_result = test_no_text_logic()
test3_result = test_mixed_content_logic()
print(f"\n=== Test Results ===")
print(f"Local variable test: {'✅ PASS' if test1_result else '❌ FAIL'}")
print(f"No text test: {'✅ PASS' if test2_result else '❌ FAIL'}")
print(f"Mixed content test: {'✅ PASS' if test3_result else '❌ FAIL'}")
if test1_result and test2_result and test3_result:
print("\n🎉 All LLMGeneratedTextFrame signaling logic tests passed!")
print(
"✅ Implementation correctly signals text generation once, as early as possible"
)
else:
print("\n❌ Some tests failed.")
except Exception as e:
print(f"❌ Test failed with error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View file

@ -1,536 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pipecat.frames.frames import (
EndFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import Edge, Node, WorkflowGraph
class TestPipecatEngineSetNode:
"""Test cases for PipecatEngine.set_node method refactoring."""
@pytest.fixture
def mock_workflow(self):
"""Create a mock workflow with various node types."""
workflow = Mock(spec=WorkflowGraph)
workflow.nodes = {}
workflow.start_node_id = "start_node"
workflow.global_node_id = None
return workflow
@pytest.fixture
def mock_dependencies(self, mock_workflow):
"""Create mock dependencies for PipecatEngine initialization."""
task = AsyncMock()
task.queue_frames = AsyncMock()
task.queue_frame = AsyncMock()
llm = AsyncMock()
llm.register_function = Mock()
llm.push_frame = AsyncMock()
context = Mock(spec=OpenAILLMContext)
context.set_node_name = Mock()
return {
"task": task,
"llm": llm,
"context": context,
"tts": Mock(),
"transport": Mock(),
"workflow": mock_workflow,
"call_context_vars": {"test_var": "test_value"},
}
@pytest.fixture
def engine(self, mock_dependencies):
"""Create a PipecatEngine instance."""
# Add audio_buffer and workflow_run_id to dependencies
mock_dependencies["audio_buffer"] = None
mock_dependencies["workflow_run_id"] = 123
engine = PipecatEngine(**mock_dependencies)
# Mock the builtin function registration
engine._register_builtin_functions = AsyncMock()
return engine
def create_node(self, node_id, **kwargs):
"""Helper to create a node with default values."""
defaults = {
"name": f"Node {node_id}",
"prompt": f"Prompt for {node_id}",
"is_static": False,
"is_start": False,
"is_end": False,
"allow_interrupt": True,
"extraction_enabled": False,
"extraction_prompt": "",
"extraction_variables": [],
"add_global_prompt": True,
"wait_for_user_response": False,
"detect_voicemail": False,
}
defaults.update(kwargs)
data = Mock(spec=NodeDataDTO)
for key, value in defaults.items():
setattr(data, key, value)
node = Mock(spec=Node)
node.id = node_id
node.data = data
node.out_edges = []
# Copy attributes from data to node
for key, value in defaults.items():
setattr(node, key, value)
return node
def create_edge(
self, source, target, label="Continue", condition="Always continue"
):
"""Helper to create an edge."""
data = Mock(spec=EdgeDataDTO)
data.label = label
data.condition = condition
edge = Mock(spec=Edge)
edge.source = source
edge.target = target
edge.data = data
edge.get_function_name = Mock(return_value=label.lower().replace(" ", "_"))
return edge
# ===== START NODE TESTS =====
@pytest.mark.asyncio
async def test_start_node_static_immediate_execution(self, engine, mock_workflow):
"""Test: Basic static start node executes immediately."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
prompt="Welcome to our service!",
)
next_node = self.create_node("next_node", is_static=False)
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Execute
await engine.set_node("start_node")
# Verify
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert len(frames) == 3
assert isinstance(frames[0], LLMFullResponseStartFrame)
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Welcome to our service!"
assert isinstance(frames[2], LLMFullResponseEndFrame)
# Static start nodes now set pending transition after context push
assert engine._pending_control_transition_after_context_push is not None
# Should not have set detect_voicemail for static start without it
assert not engine._detect_voicemail
@pytest.mark.asyncio
async def test_start_node_with_detect_voicemail_no_audio_buffer(
self, engine, mock_workflow
):
"""Test: Start node with voicemail detection but no audio buffer logs warning."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
detect_voicemail=True,
prompt="Hello, this is a business call.",
)
mock_workflow.nodes = {"start_node": start_node}
# Engine has no audio buffer (None)
assert engine._audio_buffer is None
# Execute
await engine.set_node("start_node")
# Verify
# Should NOT set voicemail detection flag since no audio buffer
assert engine._detect_voicemail is False
assert engine._voicemail_detector is None
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Hello, this is a business call."
@pytest.mark.asyncio
async def test_start_node_non_static_with_detect_voicemail(
self, engine, mock_workflow
):
"""Test: Non-static start node with voicemail detection without audio buffer."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=False, # Non-static
detect_voicemail=True,
prompt="You are an AI assistant. Start the conversation.",
)
mock_workflow.nodes = {"start_node": start_node}
# Mock the context update method
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test prompt"}, [])
)
# Execute
await engine.set_node("start_node")
# Verify
# Should NOT set voicemail detection flags (no audio buffer)
assert engine._detect_voicemail is False
assert engine._voicemail_detector is None
# Should update LLM context for non-static node
engine._update_llm_context.assert_called_once()
# Should queue context frame
engine.task.queue_frame.assert_called_once()
frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(frame, OpenAILLMContextFrame)
@pytest.mark.asyncio
async def test_start_node_static_with_wait_for_user_response(
self, engine, mock_workflow
):
"""Test: Static start node with wait_for_user_response."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
wait_for_user_response=True,
prompt="Please tell me your name.",
)
next_node = self.create_node("next_node")
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Execute
await engine.set_node("start_node")
# Verify
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
# Should have a pending control transition that will start the timer
assert engine._pending_control_transition_after_context_push is not None
# Timer task should not exist yet
assert (
not hasattr(engine, "_user_response_timeout_task")
or engine._user_response_timeout_task is None
)
# Simulate context push to start the timer
await engine.flush_pending_transitions(source="context_push")
# Now the timeout task should be created
assert engine._user_response_timeout_task is not None
assert not engine._user_response_timeout_task.done()
# Clean up the task
engine._user_response_timeout_task.cancel()
@pytest.mark.asyncio
async def test_start_node_non_static(self, engine, mock_workflow):
"""Test: Non-static start node sends context to LLM."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=False,
prompt="You are a helpful assistant. Greet the user.",
)
mock_workflow.nodes = {"start_node": start_node}
# Mock the context update method
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test prompt"}, [])
)
# Execute
await engine.set_node("start_node")
# Verify
# Should set context name
engine.context.set_node_name.assert_called_once_with("Node start_node")
# Should update LLM context
engine._update_llm_context.assert_called_once()
# Should queue context frame
engine.task.queue_frame.assert_called_once()
frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(frame, OpenAILLMContextFrame)
# ===== AGENT NODE TESTS =====
@pytest.mark.asyncio
async def test_agent_node_static(self, engine, mock_workflow):
"""Test: Static agent node plays TTS and transitions."""
# Setup
agent_node = self.create_node(
"agent_node", is_static=True, prompt="Processing your request..."
)
next_node = self.create_node("next_node")
edge = self.create_edge("agent_node", "next_node")
agent_node.out_edges = [edge]
mock_workflow.nodes = {"agent_node": agent_node, "next_node": next_node}
# Execute
await engine.set_node("agent_node")
# Verify
# Should queue TTS
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Processing your request..."
# Should have pending transition
assert engine._pending_control_transition_after_context_push is not None
@pytest.mark.asyncio
async def test_agent_node_non_static(self, engine, mock_workflow):
"""Test: Non-static agent node sends context to LLM."""
# Setup
agent_node = self.create_node(
"agent_node",
is_static=False,
prompt="Analyze the user's request and respond appropriately.",
)
decision_node = self.create_node("decision_node")
edge = self.create_edge("agent_node", "decision_node", "analyze_complete")
agent_node.out_edges = [edge]
mock_workflow.nodes = {"agent_node": agent_node, "decision_node": decision_node}
# Mock methods
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=(
{"role": "system", "content": "Test"},
[{"name": "test_func"}],
)
)
# Execute
await engine.set_node("agent_node")
# Verify
# Should register transition function
engine.llm.register_function.assert_called_once()
call_args = engine.llm.register_function.call_args
assert call_args[0][0] == "analyze_complete"
assert callable(call_args[0][1]) # Check it's a function
assert call_args[1]["cancel_on_interruption"] is True
# Should update context and send frame
engine._update_llm_context.assert_called_once()
engine.task.queue_frame.assert_called_once()
@pytest.mark.asyncio
async def test_agent_node_with_interruption_control(self, engine, mock_workflow):
"""Test: Agent node respects allow_interrupt flag."""
# Setup
no_interrupt_node = self.create_node(
"no_interrupt",
is_static=True,
allow_interrupt=False,
prompt="Please wait while I process...",
)
mock_workflow.nodes = {"no_interrupt": no_interrupt_node}
# Execute
await engine.set_node("no_interrupt")
# Verify current node is set (for STT mute callback)
assert engine._current_node == no_interrupt_node
assert engine._current_node.allow_interrupt is False
# ===== END NODE TESTS =====
@pytest.mark.asyncio
async def test_end_node_static(self, engine, mock_workflow):
"""Test: Static end node plays final message and schedules end task."""
# Setup
end_node = self.create_node(
"end_node",
is_static=True,
is_end=True,
prompt="Thank you for calling. Goodbye!",
)
mock_workflow.nodes = {"end_node": end_node}
# Execute
await engine.set_node("end_node")
# Verify
# Should queue TTS
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert frames[1].text == "Thank you for calling. Goodbye!"
# Should have pending end task
assert engine._pending_control_transition_after_context_push is not None
# Execute the pending transition
await engine._pending_control_transition_after_context_push()
# Should have sent EndFrame via task.queue_frame
# The second call should be the EndFrame (first was TTS frames)
assert engine.task.queue_frame.call_count >= 1
end_frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(end_frame, EndFrame)
@pytest.mark.asyncio
async def test_end_node_with_extraction(self, engine, mock_workflow):
"""Test: End node with variable extraction."""
# Setup
end_node = self.create_node(
"end_node",
is_end=True,
is_static=False,
extraction_enabled=True,
extraction_variables=["user_name", "satisfaction_level"],
extraction_prompt="Extract user name and satisfaction",
)
mock_workflow.nodes = {"end_node": end_node}
# Mock the extraction manager
engine._variable_extraction_manager = Mock()
engine._perform_variable_extraction_if_needed = AsyncMock()
# Mock context update and composition methods
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test"}, [])
)
# Execute
await engine.set_node("end_node")
# Verify
# Should trigger extraction
engine._perform_variable_extraction_if_needed.assert_called_once_with(end_node)
# Should have pending end task
assert engine._pending_control_transition_after_context_push is not None
# ===== CALLBACK INTEGRATION TESTS =====
@pytest.mark.asyncio
async def test_user_stopped_speaking_during_response_wait(
self, engine, mock_workflow
):
"""Test: User stops speaking triggers transition during wait_for_response."""
# Setup
start_node = self.create_node(
"start_node", is_start=True, is_static=True, wait_for_user_response=True
)
next_node = self.create_node("next_node")
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Set current node to start node
engine._current_node = start_node
engine._user_response_timeout_task = asyncio.create_task(asyncio.sleep(3))
# Create callback and execute
callback = engine.create_user_stopped_speaking_callback()
# Mock set_node to avoid recursion
with patch.object(engine, "set_node", new=AsyncMock()) as mock_set_node:
await callback()
# Verify
mock_set_node.assert_called_once_with("next_node")
assert engine._queue_context_frame is False # Should be set to False
@pytest.mark.asyncio
async def test_context_push_callback_executes_pending_transitions(self, engine):
"""Test: flush_pending_transitions executes deferred transitions."""
# Setup pending transitions
mock_generated_transition = AsyncMock()
mock_control_transition = AsyncMock()
engine._pending_generated_transition_after_context_push = (
mock_generated_transition
)
engine._pending_control_transition_after_context_push = mock_control_transition
# Execute
await engine.flush_pending_transitions(source="context_push")
# Verify both transitions were executed
mock_generated_transition.assert_called_once()
mock_control_transition.assert_called_once()
# Verify they were cleared
assert engine._pending_generated_transition_after_context_push is None
assert engine._pending_control_transition_after_context_push is None
# ===== COMPLEX SCENARIO TESTS =====
# Add helper for testing with real async behavior
def ANY(cls=None):
"""Helper for matching any argument in mock calls."""
class AnyMatcher:
def __init__(self, cls):
self.cls = cls
def __eq__(self, other):
if self.cls:
return isinstance(other, self.cls)
return True
return AnyMatcher(cls)

@ -1 +1 @@
Subproject commit fa68d2ce261544398013307d2c6a69e0556b4449
Subproject commit 53653657d851e8052f9cc5b73b6f675a44c86fe7

View file

@ -18,8 +18,6 @@ interface EndCallEditFormProps {
nodeData: FlowNodeData;
prompt: string;
setPrompt: (value: string) => void;
isStatic: boolean;
setIsStatic: (value: boolean) => void;
name: string;
setName: (value: string) => void;
extractionEnabled: boolean;
@ -45,7 +43,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
// Form state
const [prompt, setPrompt] = useState(data.prompt);
const [isStatic, setIsStatic] = useState(data.is_static ?? true);
const [name, setName] = useState(data.name);
// Variable Extraction state
@ -58,7 +55,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
handleSaveNodeData({
...data,
prompt,
is_static: isStatic,
name,
allow_interrupt: false, // Always set to false for end nodes
extraction_enabled: extractionEnabled,
@ -77,7 +73,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
const handleOpenChange = (newOpen: boolean) => {
if (newOpen) {
setPrompt(data.prompt);
setIsStatic(data.is_static ?? true);
setName(data.name);
setExtractionEnabled(data.extraction_enabled ?? false);
setExtractionPrompt(data.extraction_prompt ?? "");
@ -91,7 +86,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
useEffect(() => {
if (open) {
setPrompt(data.prompt);
setIsStatic(data.is_static ?? true);
setName(data.name);
setExtractionEnabled(data.extraction_enabled ?? false);
setExtractionPrompt(data.extraction_prompt ?? "");
@ -137,8 +131,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
nodeData={data}
prompt={prompt}
setPrompt={setPrompt}
isStatic={isStatic}
setIsStatic={setIsStatic}
name={name}
setName={setName}
extractionEnabled={extractionEnabled}
@ -159,8 +151,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
const EndCallEditForm = ({
prompt,
setPrompt,
isStatic,
setIsStatic,
name,
setName,
extractionEnabled,
@ -206,14 +196,10 @@ const EndCallEditForm = ({
</Label>
<Input value={name} onChange={(e) => setName(e.target.value)} />
<Label>{isStatic ? "Text" : "Prompt"}</Label>
<Label>Prompt</Label>
<Label className="text-xs text-gray-500">
What would you like the agent to say when the call ends? Its a good idea to have a static goodbye message.
Enter the prompt for the agent. This will be used to generate the agent&apos;s response. Prompt engineering&apos;s best practices apply.
</Label>
<div className="flex items-center space-x-2">
<Switch id="static-text" checked={isStatic} onCheckedChange={setIsStatic} />
<Label htmlFor="static-text">Static Text</Label>
</div>
<Textarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
@ -221,7 +207,7 @@ const EndCallEditForm = ({
style={{
overflowY: 'auto'
}}
placeholder={isStatic ? "Thank you for calling Dograh. Have a great day!" : "Enter a dynamic prompt"}
placeholder="Enter a dynamic prompt"
/>
<div className="flex items-center space-x-2">
<Switch id="add-global-prompt" checked={addGlobalPrompt} onCheckedChange={setAddGlobalPrompt} />

View file

@ -19,16 +19,12 @@ interface StartCallEditFormProps {
nodeData: FlowNodeData;
prompt: string;
setPrompt: (value: string) => void;
isStatic: boolean;
setIsStatic: (value: boolean) => void;
name: string;
setName: (value: string) => void;
allowInterrupt: boolean;
setAllowInterrupt: (value: boolean) => void;
addGlobalPrompt: boolean;
setAddGlobalPrompt: (value: boolean) => void;
waitForUserResponse: boolean;
setWaitForUserResponse: (value: boolean) => void;
detectVoicemail: boolean;
setDetectVoicemail: (value: boolean) => void;
delayedStart: boolean;
@ -50,11 +46,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
// Form state
const [prompt, setPrompt] = useState(data.prompt ?? "");
const [isStatic, setIsStatic] = useState(data.is_static ?? true);
const [name, setName] = useState(data.name);
const [allowInterrupt, setAllowInterrupt] = useState(data.allow_interrupt ?? true);
const [addGlobalPrompt, setAddGlobalPrompt] = useState(data.add_global_prompt ?? true);
const [waitForUserResponse, setWaitForUserResponse] = useState(data.wait_for_user_response ?? false);
const [detectVoicemail, setDetectVoicemail] = useState(data.detect_voicemail ?? true);
const [delayedStart, setDelayedStart] = useState(data.delayed_start ?? false);
const [delayedStartDuration, setDelayedStartDuration] = useState(data.delayed_start_duration ?? 2);
@ -63,11 +57,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
handleSaveNodeData({
...data,
prompt,
is_static: isStatic,
name,
allow_interrupt: allowInterrupt,
add_global_prompt: addGlobalPrompt,
wait_for_user_response: waitForUserResponse,
detect_voicemail: detectVoicemail,
delayed_start: delayedStart,
delayed_start_duration: delayedStart ? delayedStartDuration : undefined
@ -83,11 +75,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
const handleOpenChange = (newOpen: boolean) => {
if (newOpen) {
setPrompt(data.prompt ?? "");
setIsStatic(data.is_static ?? true);
setName(data.name);
setAllowInterrupt(data.allow_interrupt ?? true);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setWaitForUserResponse(data.wait_for_user_response ?? false);
setDetectVoicemail(data.detect_voicemail ?? true);
setDelayedStart(data.delayed_start ?? false);
setDelayedStartDuration(data.delayed_start_duration ?? 3);
@ -99,11 +89,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
useEffect(() => {
if (open) {
setPrompt(data.prompt ?? "");
setIsStatic(data.is_static ?? true);
setName(data.name);
setAllowInterrupt(data.allow_interrupt ?? true);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setWaitForUserResponse(data.wait_for_user_response ?? false);
setDetectVoicemail(data.detect_voicemail ?? true);
setDelayedStart(data.delayed_start ?? false);
setDelayedStartDuration(data.delayed_start_duration ?? 3);
@ -147,16 +135,12 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
nodeData={data}
prompt={prompt}
setPrompt={setPrompt}
isStatic={isStatic}
setIsStatic={setIsStatic}
name={name}
setName={setName}
allowInterrupt={allowInterrupt}
setAllowInterrupt={setAllowInterrupt}
addGlobalPrompt={addGlobalPrompt}
setAddGlobalPrompt={setAddGlobalPrompt}
waitForUserResponse={waitForUserResponse}
setWaitForUserResponse={setWaitForUserResponse}
detectVoicemail={detectVoicemail}
setDetectVoicemail={setDetectVoicemail}
delayedStart={delayedStart}
@ -173,16 +157,12 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
const StartCallEditForm = ({
prompt,
setPrompt,
isStatic,
setIsStatic,
name,
setName,
allowInterrupt,
setAllowInterrupt,
addGlobalPrompt,
setAddGlobalPrompt,
waitForUserResponse,
setWaitForUserResponse,
detectVoicemail,
setDetectVoicemail,
delayedStart,
@ -201,14 +181,10 @@ const StartCallEditForm = ({
onChange={(e) => setName(e.target.value)}
/>
<Label>{isStatic ? "Text" : "Prompt"}</Label>
<Label>Prompt</Label>
<Label className="text-xs text-gray-500">
What would you like the agent to say when the call starts? Its a good idea to have a static greeting that can be used to identify the call.
Enter the prompt for the agent. This will be used to generate the agent&apos;s response. Prompt engineering&apos;s best practices apply.
</Label>
<div className="flex items-center space-x-2">
<Switch id="static-text" checked={isStatic} onCheckedChange={setIsStatic} />
<Label htmlFor="static-text">Static Text</Label>
</div>
<Textarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
@ -216,7 +192,7 @@ const StartCallEditForm = ({
style={{
overflowY: 'auto'
}}
placeholder={isStatic ? "Hello, welcome to Dograh. How can I help you today?" : "Enter a dynamic prompt"}
placeholder="Enter a prompt"
/>
<div className="flex items-center space-x-2">
<Switch id="allow-interrupt" checked={allowInterrupt} onCheckedChange={setAllowInterrupt} />
@ -230,34 +206,10 @@ const StartCallEditForm = ({
id="add-global-prompt"
checked={addGlobalPrompt}
onCheckedChange={setAddGlobalPrompt}
disabled={isStatic}
/>
<Label htmlFor="add-global-prompt" className={isStatic ? "opacity-50" : ""}>
<Label htmlFor="add-global-prompt">
Add Global Prompt
</Label>
<Label className={`text-xs text-gray-500 ${isStatic ? "opacity-50" : ""}`}>
{isStatic
? "Not applicable for static text"
: "Whether you want to add global prompt with this node's prompt."}
</Label>
</div>
<div className="flex flex-col space-y-2">
<div className="flex items-center space-x-2">
<Switch
id="wait-for-user-response"
checked={waitForUserResponse}
onCheckedChange={setWaitForUserResponse}
disabled={!isStatic}
/>
<Label htmlFor="wait-for-user-response" className={!isStatic ? "opacity-50" : ""}>
Wait for user&apos;s response
</Label>
<Label className={`text-xs text-gray-500 ${!isStatic ? "opacity-50" : ""}`}>
{!isStatic
? "Only applicable for static text"
: "Wait for user to respond before disconnecting the call."}
</Label>
</div>
</div>
{!isOSSMode() && (
<div className="flex items-center space-x-2">

View file

@ -20,8 +20,6 @@ export type FlowNodeData = {
extraction_prompt?: string;
extraction_variables?: ExtractionVariable[];
add_global_prompt?: boolean;
wait_for_user_response?: boolean;
wait_for_user_response_timeout?: number;
wait_for_user_greeting?: boolean;
detect_voicemail?: boolean;
delayed_start?: boolean;