mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-07-01 08:59:46 +02:00
fix: prevent pipeline freezes when sending endframe (#77)
* fix: dont cancel task if call is already ending * Update pipecat
This commit is contained in:
parent
0a8ce3f644
commit
909c258b6a
7 changed files with 44 additions and 20 deletions
|
|
@ -58,16 +58,21 @@ def register_transport_event_handlers(
|
|||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, participant):
|
||||
logger.debug("In on_client_disconnected callback handler")
|
||||
await engine.handle_client_disconnected()
|
||||
call_disposed = engine.is_call_disposed()
|
||||
|
||||
logger.debug(
|
||||
f"In on_client_disconnected callback handler. Call disposed: {call_disposed}"
|
||||
)
|
||||
engine.handle_client_disconnected()
|
||||
|
||||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
# Cancel the task since the client is disconnected
|
||||
await task.cancel()
|
||||
# Only cancel the task if the call is not already disposed by the engine
|
||||
if not call_disposed:
|
||||
await task.cancel()
|
||||
|
||||
# Return the buffers so they can be passed to other handlers
|
||||
return in_memory_audio_buffer, in_memory_transcript_buffer
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
# Include TTSSpeakFrame here since for static nodes, we send TTSSpeakFrame
|
||||
# which can act as reference while fixing the aggregated trascript
|
||||
await self._llm_text_frame_callback(frame.text)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, _: StartFrame):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, WebSocket
|
||||
|
|
@ -39,7 +40,7 @@ from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
|||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
|
|
@ -513,9 +514,12 @@ async def _run_pipeline(
|
|||
|
||||
try:
|
||||
# Run the pipeline
|
||||
runner = PipelineRunner()
|
||||
await runner.run(task)
|
||||
logger.info(f"Pipeline runner completed for run {workflow_run_id}")
|
||||
loop = asyncio.get_running_loop()
|
||||
params = PipelineTaskParams(loop=loop)
|
||||
await task.run(params)
|
||||
logger.info(f"Task completed for run {workflow_run_id}")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Received CancelledError in _run_pipeline")
|
||||
finally:
|
||||
ContextProviderRegistry.remove_providers(str(workflow_run_id))
|
||||
logger.debug(f"Cleaned up context providers for workflow run {workflow_run_id}")
|
||||
|
|
|
|||
|
|
@ -70,15 +70,15 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
xml_function_tag_filter = XMLFunctionTagFilter()
|
||||
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
api_key=user_config.tts.api_key,
|
||||
voice=user_config.tts.voice.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
voice_id = user_config.tts.voice.split(" - ")[1]
|
||||
|
|
@ -90,7 +90,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
params=ElevenLabsTTSService.InputParams(
|
||||
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
|
||||
),
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
# Convert HTTP URL to WebSocket URL for TTS
|
||||
|
|
@ -101,7 +101,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
voice=user_config.tts.voice.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue