mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: prevent pipeline freezes when sending endframe (#77)
* fix: dont cancel task if call is already ending * Update pipecat
This commit is contained in:
parent
0a8ce3f644
commit
909c258b6a
7 changed files with 44 additions and 20 deletions
|
|
@ -136,9 +136,7 @@ def apply_workflow_run_filters(
|
|||
)
|
||||
# Use -> operator with literal text key to get call_tags as JSONB
|
||||
call_tags = gathered_context_jsonb.op("->")("call_tags")
|
||||
filter_conditions.append(
|
||||
call_tags.op("@>")(func.cast(tags, JSONB))
|
||||
)
|
||||
filter_conditions.append(call_tags.op("@>")(func.cast(tags, JSONB)))
|
||||
|
||||
elif filter_type == "text" and field == "initial_context.phone":
|
||||
# Filter by phone number (contains search)
|
||||
|
|
|
|||
|
|
@ -58,16 +58,21 @@ def register_transport_event_handlers(
|
|||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, participant):
|
||||
logger.debug("In on_client_disconnected callback handler")
|
||||
await engine.handle_client_disconnected()
|
||||
call_disposed = engine.is_call_disposed()
|
||||
|
||||
logger.debug(
|
||||
f"In on_client_disconnected callback handler. Call disposed: {call_disposed}"
|
||||
)
|
||||
engine.handle_client_disconnected()
|
||||
|
||||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
# Cancel the task since the client is disconnected
|
||||
await task.cancel()
|
||||
# Only cancel the task if the call is not already disposed by the engine
|
||||
if not call_disposed:
|
||||
await task.cancel()
|
||||
|
||||
# Return the buffers so they can be passed to other handlers
|
||||
return in_memory_audio_buffer, in_memory_transcript_buffer
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
# Include TTSSpeakFrame here since for static nodes, we send TTSSpeakFrame
|
||||
# which can act as reference while fixing the aggregated trascript
|
||||
await self._llm_text_frame_callback(frame.text)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, _: StartFrame):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, WebSocket
|
||||
|
|
@ -39,7 +40,7 @@ from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
|||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
|
|
@ -513,9 +514,12 @@ async def _run_pipeline(
|
|||
|
||||
try:
|
||||
# Run the pipeline
|
||||
runner = PipelineRunner()
|
||||
await runner.run(task)
|
||||
logger.info(f"Pipeline runner completed for run {workflow_run_id}")
|
||||
loop = asyncio.get_running_loop()
|
||||
params = PipelineTaskParams(loop=loop)
|
||||
await task.run(params)
|
||||
logger.info(f"Task completed for run {workflow_run_id}")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Received CancelledError in _run_pipeline")
|
||||
finally:
|
||||
ContextProviderRegistry.remove_providers(str(workflow_run_id))
|
||||
logger.debug(f"Cleaned up context providers for workflow run {workflow_run_id}")
|
||||
|
|
|
|||
|
|
@ -70,15 +70,15 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
xml_function_tag_filter = XMLFunctionTagFilter()
|
||||
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
api_key=user_config.tts.api_key,
|
||||
voice=user_config.tts.voice.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
voice_id = user_config.tts.voice.split(" - ")[1]
|
||||
|
|
@ -90,7 +90,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
params=ElevenLabsTTSService.InputParams(
|
||||
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
|
||||
),
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
# Convert HTTP URL to WebSocket URL for TTS
|
||||
|
|
@ -101,7 +101,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
voice=user_config.tts.voice.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ class PipecatEngine:
|
|||
self._workflow_run_id = workflow_run_id
|
||||
self._initialized = False
|
||||
self._client_disconnected = False
|
||||
self._call_disposed = False
|
||||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
self._user_response_timeout_task: Optional[asyncio.Task] = None
|
||||
|
|
@ -182,7 +183,9 @@ class PipecatEngine:
|
|||
async def transition_func(function_call_params: FunctionCallParams) -> None:
|
||||
"""Inner function that handles the node change tool calls"""
|
||||
logger.info(f"LLM Function Call EXECUTED: {name}")
|
||||
logger.info(f"Function: {name} -> transitioning to node: {transition_to_node}")
|
||||
logger.info(
|
||||
f"Function: {name} -> transitioning to node: {transition_to_node}"
|
||||
)
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
|
||||
|
|
@ -472,6 +475,15 @@ class PipecatEngine:
|
|||
Centralized method to send EndTaskFrame with metadata including
|
||||
call_transfer_context and call_context_vars
|
||||
"""
|
||||
if self._call_disposed or self._client_disconnected:
|
||||
# Call is already disposed and client disconnected
|
||||
logger.debug(
|
||||
f"Not sending EndFrame since call is already disposed: Call Disposed: {self._call_disposed} Client Disconnected: {self._client_disconnected}"
|
||||
)
|
||||
return
|
||||
|
||||
self._call_disposed = True
|
||||
|
||||
frame_to_push = CancelFrame() if abort_immediately else EndFrame()
|
||||
|
||||
# Customer disposition code using their mapping
|
||||
|
|
@ -700,10 +712,14 @@ class PipecatEngine:
|
|||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_reference_text += text
|
||||
|
||||
async def handle_client_disconnected(self):
|
||||
def handle_client_disconnected(self):
|
||||
"""Handle client disconnected event."""
|
||||
self._client_disconnected = True
|
||||
|
||||
def is_call_disposed(self):
|
||||
"""Check whether a call has been disposed by the engine"""
|
||||
return self._call_disposed
|
||||
|
||||
async def get_call_disposition(self) -> Optional[str]:
|
||||
"""Get the disconnect reason set by the engine."""
|
||||
if self._call_disposition:
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 6e6cf412ad1d3251c3f4b0c06a85ae9a66b5719e
|
||||
Subproject commit 5f7c03c6a0d10fa5804fabd4cf148ac2805624fe
|
||||
Loading…
Add table
Add a link
Reference in a new issue