mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
feat: simplify pipecat engine execution (#54)
This commit is contained in:
parent
99a768f291
commit
6ce25a589c
20 changed files with 52 additions and 1405 deletions
|
|
@ -14,14 +14,14 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ class PipecatEngine:
|
|||
*,
|
||||
task: Optional[PipelineTask] = None,
|
||||
llm: Optional["LLMService"] = None,
|
||||
context: Optional[OpenAILLMContext] = None,
|
||||
context: Optional[LLMContext] = None,
|
||||
tts: Optional[Any] = None,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
workflow: WorkflowGraph,
|
||||
|
|
@ -82,7 +82,6 @@ class PipecatEngine:
|
|||
self._workflow_run_id = workflow_run_id
|
||||
self._initialized = False
|
||||
self._client_disconnected = False
|
||||
self._pending_function_calls = 0
|
||||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
self._user_response_timeout_task: Optional[asyncio.Task] = None
|
||||
|
|
@ -102,29 +101,9 @@ class PipecatEngine:
|
|||
self._voicemail_detector = None
|
||||
self._voicemail_detection_task: Optional[asyncio.Task] = None
|
||||
|
||||
# This transition is generated by the llm as part of tool call. This can
|
||||
# also be accompanied with some content which can be played using TTS. If the
|
||||
# bot is interrupted, we would cancel this transition (we do cancel this currently when
|
||||
# the next generation starts in handle_generation_started callback handler.)
|
||||
self._pending_generated_transition_after_context_push: Optional[
|
||||
Callable[[], Awaitable[None]]
|
||||
] = None
|
||||
|
||||
# This is the transtion which is typically programmatic transition, and not goes as
|
||||
# tool call to LLM. This is not interrupted by the user and is done on context push
|
||||
self._pending_control_transition_after_context_push: Optional[
|
||||
Callable[[], Awaitable[None]]
|
||||
] = None
|
||||
|
||||
# Flag to determine if the current llm generation has a text completion
|
||||
self._defer_context_push: bool = False
|
||||
|
||||
# Lazy loaded built-in function schemas
|
||||
self._builtin_function_schemas: Optional[list[dict]] = None
|
||||
|
||||
# Flag to control whether to queue context frame
|
||||
self._queue_context_frame: bool = True
|
||||
|
||||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
|
||||
|
|
@ -211,23 +190,15 @@ class PipecatEngine:
|
|||
|
||||
async def _create_transition_func(self, name: str, transition_to_node: str):
|
||||
async def transition_func(function_call_params: FunctionCallParams) -> None:
|
||||
"""Inner function that handles the actual tool invocation."""
|
||||
"""Inner function that handles the node change tool calls"""
|
||||
try:
|
||||
# Track pending function call
|
||||
self._pending_function_calls += 1
|
||||
logger.debug(
|
||||
f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
|
||||
)
|
||||
|
||||
# For edge functions, prevent LLM completion until transition (run_llm=False)
|
||||
# For node functions, allow immediate completion (run_llm=True)
|
||||
async def on_context_updated() -> None:
|
||||
"""
|
||||
Framework will run this function after the function call result has been updated in the context.
|
||||
pipecat framework will run this function after the function call result has been updated in the context.
|
||||
This way, when we do set_node from within this function, and go for LLM completion with updated
|
||||
system prompts, the context is updated with function call result.
|
||||
"""
|
||||
self._pending_function_calls -= 1
|
||||
# Perform variable extraction before transitioning to new node
|
||||
await self._perform_variable_extraction_if_needed(
|
||||
self._current_node
|
||||
|
|
@ -241,41 +212,14 @@ class PipecatEngine:
|
|||
on_context_updated=on_context_updated,
|
||||
)
|
||||
|
||||
async def _invoke_result_callback():
|
||||
"""
|
||||
Functions are executed immediately when they come from LLM as part of text completion.
|
||||
But, if the LLM completion also has some text, we would want to not call the function if the user interrupts the speech.
|
||||
We would also not want the function to be added to context, so that the LLM can call the function again. Hence, we
|
||||
defer the function invocation until we receive on_context_updated callback, i.e the bot has finished speaking
|
||||
the text that was generated.
|
||||
"""
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
|
||||
if self._defer_context_push:
|
||||
"""
|
||||
We set the flag to _defer_context_push when we receive text in the current generation from LLM.
|
||||
This is set in the handle_llm_generated_text callback handler.
|
||||
"""
|
||||
logger.debug(
|
||||
"Deferring transition function result until context push"
|
||||
)
|
||||
# Only one deferred transition should exist at any time.
|
||||
# Overwrite if one is somehow already set (unexpected).
|
||||
self._pending_generated_transition_after_context_push = (
|
||||
_invoke_result_callback
|
||||
)
|
||||
else:
|
||||
"""
|
||||
If there was no text in the current generation, and we only had function call,
|
||||
lets invoke the result callback, so that framework can call on_context_updated and
|
||||
we can do switch node.
|
||||
"""
|
||||
await _invoke_result_callback()
|
||||
# Call results callback from the pipecat framework
|
||||
# so that a new llm generation can be triggred if
|
||||
# required
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in transition function {name}: {str(e)}")
|
||||
self._pending_function_calls = 0
|
||||
error_result = {"status": "error", "error": str(e)}
|
||||
await function_call_params.result_callback(error_result)
|
||||
|
||||
|
|
@ -362,27 +306,6 @@ class PipecatEngine:
|
|||
]
|
||||
)
|
||||
|
||||
async def _setup_static_start_node_transition(self, node: Node) -> None:
|
||||
"""Set up the deferred transition for static start nodes."""
|
||||
if not node.out_edges:
|
||||
return
|
||||
|
||||
next_node_id = node.out_edges[0].target
|
||||
|
||||
if not node.wait_for_user_response:
|
||||
# Normal static start node - transition immediately after context push
|
||||
async def _deferred_static_transition():
|
||||
try:
|
||||
await self.set_node(next_node_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error executing deferred static node transition to {next_node_id}: {exc}"
|
||||
)
|
||||
|
||||
self._pending_control_transition_after_context_push = (
|
||||
_deferred_static_transition
|
||||
)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -441,17 +364,7 @@ class PipecatEngine:
|
|||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
|
||||
# Queue context frame if needed
|
||||
if self._queue_context_frame:
|
||||
await self.task.queue_frame(OpenAILLMContextFrame(self.context))
|
||||
else:
|
||||
logger.debug(
|
||||
f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
|
||||
)
|
||||
|
||||
# Reset _queue_context_frame as default behavior
|
||||
self._queue_context_frame = True
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
|
|
@ -525,12 +438,7 @@ class PipecatEngine:
|
|||
await asyncio.sleep(delay_duration)
|
||||
|
||||
if node.is_static:
|
||||
# Queue TTS for static start node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
|
||||
# Set up deferred transition for static start nodes
|
||||
await self._setup_static_start_node_transition(node)
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Start generation for non-static start node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
|
@ -538,66 +446,24 @@ class PipecatEngine:
|
|||
async def _handle_end_node(self, node: Node) -> None:
|
||||
"""Handle end node execution."""
|
||||
if node.is_static:
|
||||
# Queue TTS for static end node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Start generation for non-static end node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
# If this end node has extraction enabled, perform extraction immediately
|
||||
if node.extraction_enabled and node.extraction_variables:
|
||||
await self._perform_variable_extraction_if_needed(node)
|
||||
|
||||
# TODO: Extract disposition code from extracted variables
|
||||
# Defer send_end_task_frame using _pending_control_transition_after_context_push
|
||||
|
||||
# Decide the end-task reason dynamically depending on call_disposition.
|
||||
async def _deferred_end_task():
|
||||
# call_disposition is the disposition which is generated from
|
||||
# llm call based on the conversation so far.
|
||||
# TODO: Make this more generic based on configuration or llm prompting
|
||||
disposition = self._gathered_context.get("call_disposition")
|
||||
if disposition == "XFER":
|
||||
reason = EndTaskReason.USER_QUALIFIED.value
|
||||
else:
|
||||
reason = EndTaskReason.USER_DISQUALIFIED.value
|
||||
await self.send_end_task_frame(reason)
|
||||
|
||||
self._pending_control_transition_after_context_push = _deferred_end_task
|
||||
await self.send_end_task_frame(EndTaskReason.USER_QUALIFIED.value)
|
||||
|
||||
async def _handle_agent_node(self, node: Node) -> None:
|
||||
"""Handle agent node execution."""
|
||||
if node.is_static:
|
||||
# Queue TTS for static agent node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
|
||||
# Set up deferred transition for static agent nodes
|
||||
await self._setup_agent_node_transition(node)
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Set context and functions for non-static agent node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
async def _setup_agent_node_transition(self, node: Node) -> None:
|
||||
"""Set up the deferred transition for static agent nodes."""
|
||||
if not node.out_edges:
|
||||
return
|
||||
|
||||
next_node_id = node.out_edges[0].target
|
||||
|
||||
async def _deferred_static_transition():
|
||||
try:
|
||||
await self.set_node(next_node_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error executing deferred static node transition to {next_node_id}: {exc}"
|
||||
)
|
||||
|
||||
self._pending_control_transition_after_context_push = (
|
||||
_deferred_static_transition
|
||||
)
|
||||
|
||||
async def send_end_task_frame(
|
||||
self,
|
||||
reason: str,
|
||||
|
|
@ -640,7 +506,7 @@ class PipecatEngine:
|
|||
# Store the mapped disconnect reason
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
|
||||
# TODO: Generalise this, currently tailored to Kapil's use case
|
||||
# TODO: Generalise this
|
||||
self._gathered_context["address"] = ", ".join(
|
||||
[
|
||||
self._call_context_vars.get("address1", ""),
|
||||
|
|
@ -759,55 +625,6 @@ class PipecatEngine:
|
|||
|
||||
return system_message, functions
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pending transition handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def flush_pending_transitions(self, *, source: str = "context_push"):
|
||||
"""Execute and clear any pending transitions.
|
||||
|
||||
Args:
|
||||
source: Indicates the trigger that caused this flush:
|
||||
- "context_push": the assistant context aggregator completed a push.
|
||||
"""
|
||||
|
||||
if source != "context_push":
|
||||
raise ValueError("Invalid flush source – expected 'context_push'")
|
||||
|
||||
len_pending_functions = 0
|
||||
|
||||
if self._pending_generated_transition_after_context_push is not None:
|
||||
len_pending_functions += 1
|
||||
if self._pending_control_transition_after_context_push is not None:
|
||||
len_pending_functions += 1
|
||||
|
||||
# Nothing to do
|
||||
if len_pending_functions == 0:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
|
||||
)
|
||||
|
||||
# Generated transition
|
||||
if self._pending_generated_transition_after_context_push is not None:
|
||||
pending_cb = self._pending_generated_transition_after_context_push
|
||||
self._pending_generated_transition_after_context_push = None
|
||||
try:
|
||||
await pending_cb()
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.error(f"Error executing deferred transition: {exc}")
|
||||
|
||||
# Control transition (context push)
|
||||
if self._pending_control_transition_after_context_push is not None:
|
||||
logger.debug("Executing control transition after context push")
|
||||
static_cb = self._pending_control_transition_after_context_push
|
||||
self._pending_control_transition_after_context_push = None
|
||||
try:
|
||||
await static_cb()
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.error(f"Error executing deferred static node transition: {exc}")
|
||||
|
||||
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
|
||||
"""
|
||||
This callback is called by STTMuteFilter to determine if the STT should be muted.
|
||||
|
|
@ -828,15 +645,6 @@ class PipecatEngine:
|
|||
"""
|
||||
return engine_callbacks.create_max_duration_callback(self)
|
||||
|
||||
def create_llm_generated_text_callback(self):
|
||||
"""
|
||||
This callback is called when some text is generated by the LLM.
|
||||
We use this to defer the result_callback of the node transition functions if
|
||||
there is set_node called along with some text generated. This way, we will
|
||||
have the context sent in the next generation from new node.
|
||||
"""
|
||||
return engine_callbacks.create_llm_generated_text_callback(self)
|
||||
|
||||
def create_generation_started_callback(self):
|
||||
"""
|
||||
This callback is called when a new generation starts.
|
||||
|
|
@ -844,26 +652,12 @@ class PipecatEngine:
|
|||
"""
|
||||
return engine_callbacks.create_generation_started_callback(self)
|
||||
|
||||
def create_user_stopped_speaking_callback(self):
|
||||
"""
|
||||
This callback is called when the user stops speaking.
|
||||
We use this to handle transitions when wait_for_user_response is enabled.
|
||||
"""
|
||||
return engine_callbacks.create_user_stopped_speaking_callback(self)
|
||||
|
||||
def create_user_started_speaking_callback(self):
|
||||
"""
|
||||
This callback is called when the user starts speaking.
|
||||
We use this to handle wait_for_user_greeting functionality.
|
||||
"""
|
||||
return engine_callbacks.create_user_started_speaking_callback(self)
|
||||
|
||||
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
|
||||
"""Create a callback that corrects corrupted aggregation using reference text."""
|
||||
return engine_callbacks.create_aggregation_correction_callback(self)
|
||||
|
||||
def set_context(self, context: OpenAILLMContext) -> None:
|
||||
"""Set the OpenAI LLM context.
|
||||
def set_context(self, context: LLMContext) -> None:
|
||||
"""Set the LLM context.
|
||||
|
||||
This allows setting the context after the engine has been created,
|
||||
which is useful when the context needs to be created after the engine.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue