mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
Merge pull request #7 from dograh-hq/enh/standardise-transport
on_client_disconnect only when user disconnects.
This commit is contained in:
commit
4b4a7ba19a
11 changed files with 162 additions and 214 deletions
|
|
@ -1,4 +1,4 @@
|
|||
pipecat-ai[cartesia,deepgram,openai,elevenlabs,groq,google,azure,soundfile,silero,webrtc] @ git+https://github.com/dograh-hq/pipecat.git@9dbd5eb
|
||||
pipecat-ai[cartesia,deepgram,openai,elevenlabs,groq,google,azure,soundfile,silero,webrtc] @ git+https://github.com/dograh-hq/pipecat.git@c327208
|
||||
langfuse==3.4.0
|
||||
fastapi==0.116.2
|
||||
asyncpg==0.30.0
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.audio_transcript_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryTranscriptBuffer,
|
||||
|
|
@ -16,20 +15,20 @@ from api.services.workflow.disposition_mapper import (
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
|
||||
|
||||
|
||||
def register_transport_event_handlers(
|
||||
task: PipelineTask,
|
||||
transport,
|
||||
workflow_run_id,
|
||||
audio_buffer,
|
||||
task: PipelineTask,
|
||||
engine: PipecatEngine,
|
||||
usage_metrics_aggregator: PipelineMetricsAggregator,
|
||||
audio_synchronizer=None,
|
||||
audio_config=None,
|
||||
audio_buffer: AudioBuffer,
|
||||
audio_synchronizer: AudioSynchronizer,
|
||||
audio_config=AudioConfig,
|
||||
):
|
||||
"""Register event handlers for transport events"""
|
||||
|
||||
|
|
@ -58,52 +57,55 @@ def register_transport_event_handlers(
|
|||
await engine.initialize()
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(
|
||||
transport: BaseTransport,
|
||||
participant,
|
||||
transport_disconnect_reason: Optional[str] = None,
|
||||
async def on_client_disconnected(transport, participant):
|
||||
logger.debug("In on_client_disconnected callback handler")
|
||||
await engine.handle_client_disconnected()
|
||||
|
||||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
# Cancel the task since the client is disconnected
|
||||
await task.cancel()
|
||||
|
||||
# Return the buffers so they can be passed to other handlers
|
||||
return in_memory_audio_buffer, in_memory_transcript_buffer
|
||||
|
||||
|
||||
def register_task_event_handler(
|
||||
workflow_run_id: int,
|
||||
engine: PipecatEngine,
|
||||
task: PipelineTask,
|
||||
transport,
|
||||
audio_buffer: AudioBuffer,
|
||||
audio_synchronizer: AudioSynchronizer,
|
||||
in_memory_audio_buffer: InMemoryAudioBuffer,
|
||||
in_memory_transcript_buffer: InMemoryTranscriptBuffer,
|
||||
pipeline_metrics_aggregator: PipelineMetricsAggregator,
|
||||
):
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(
|
||||
task: PipelineTask,
|
||||
frame: Frame,
|
||||
):
|
||||
logger.debug(
|
||||
f"In on_client_disconnected callback handler, disconnect_reason: {transport_disconnect_reason}"
|
||||
)
|
||||
logger.debug(f"In on_pipeline_finished callback handler")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
|
||||
# First priority: Check if engine has a disconnect reason (local disconnect)
|
||||
engine_call_disposition = engine.get_call_disposition()
|
||||
gathered_context = engine.get_gathered_context()
|
||||
# Stop recordings
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
call_disposition = await engine.get_call_disposition()
|
||||
logger.debug(f"call disposition in on_pipeline_finished: {call_disposition}")
|
||||
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
|
||||
# also consider existing gathered context in workflow_run
|
||||
gathered_context = {**gathered_context, **workflow_run.gathered_context}
|
||||
|
||||
if engine_call_disposition:
|
||||
# Engine has set a disconnect reason - this takes priority
|
||||
call_disposition = engine_call_disposition
|
||||
logger.debug(f"Engine disposition detected, code: {call_disposition}")
|
||||
elif transport_disconnect_reason:
|
||||
# TODO: Make this more generic using some DSL or equivalent. This is currently
|
||||
# configured to work for Kapil's bot
|
||||
call_duration = usage_metrics_aggregator.get_call_duration()
|
||||
if transport_disconnect_reason == EndTaskReason.USER_HANGUP.value:
|
||||
if call_duration < 10:
|
||||
call_disposition = "HU"
|
||||
else:
|
||||
call_disposition = "NIBP"
|
||||
else:
|
||||
# Transport provided a disconnect reason (remote hangup)
|
||||
call_disposition = transport_disconnect_reason
|
||||
logger.debug(
|
||||
f"Remote disconnect detected, reason: {call_disposition} duration: {call_duration}"
|
||||
)
|
||||
else:
|
||||
# No reason provided - assume user hangup
|
||||
call_disposition = EndTaskReason.UNKNOWN.value
|
||||
logger.debug("No disposition found from either engine or transport")
|
||||
|
||||
# Cancel task only when no engine disconnect reason (remote disconnect)
|
||||
if not engine_call_disposition:
|
||||
await task.cancel()
|
||||
|
||||
organization_id = await get_organization_id_from_workflow_run(workflow_run_id)
|
||||
mapped_call_disposition = await apply_disposition_mapping(
|
||||
call_disposition, organization_id
|
||||
|
|
@ -111,6 +113,7 @@ def register_transport_event_handlers(
|
|||
|
||||
gathered_context.update({"mapped_call_disposition": mapped_call_disposition})
|
||||
|
||||
# Set user_speech call tag
|
||||
if in_memory_transcript_buffer:
|
||||
call_tags = gathered_context.get("call_tags", [])
|
||||
|
||||
|
|
@ -132,10 +135,6 @@ def register_transport_event_handlers(
|
|||
# Clean up engine resources (including voicemail detector)
|
||||
await engine.cleanup()
|
||||
|
||||
await audio_buffer.stop_recording()
|
||||
if audio_synchronizer:
|
||||
await audio_synchronizer.stop_recording()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Close Smart-Turn WebSocket if the transport's analyzer supports it
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -163,7 +162,7 @@ def register_transport_event_handlers(
|
|||
except Exception as exc:
|
||||
logger.warning(f"Failed to close Smart-Turn analyzer gracefully: {exc}")
|
||||
|
||||
usage_info = usage_metrics_aggregator.get_all_usage_metrics_serialized()
|
||||
usage_info = pipeline_metrics_aggregator.get_all_usage_metrics_serialized()
|
||||
|
||||
logger.debug(f"Usage metrics: {usage_info}")
|
||||
|
||||
|
|
@ -209,9 +208,6 @@ def register_transport_event_handlers(
|
|||
FunctionNames.RUN_INTEGRATIONS_POST_WORKFLOW_RUN, workflow_run_id
|
||||
)
|
||||
|
||||
# Return the buffers so they can be passed to other handlers
|
||||
return in_memory_audio_buffer, in_memory_transcript_buffer
|
||||
|
||||
|
||||
def register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id, in_memory_buffer: InMemoryAudioBuffer
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from api.services.pipecat.engine_pre_aggregator_processor import (
|
|||
)
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_task_event_handler,
|
||||
register_transcript_handler,
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
|
@ -361,16 +362,28 @@ async def _run_pipeline(
|
|||
# Register event handlers
|
||||
in_memory_audio_buffer, in_memory_transcript_buffer = (
|
||||
register_transport_event_handlers(
|
||||
task,
|
||||
transport,
|
||||
workflow_run_id,
|
||||
audio_buffer,
|
||||
task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=pipeline_metrics_aggregator,
|
||||
audio_buffer=audio_buffer,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
)
|
||||
|
||||
register_task_event_handler(
|
||||
workflow_run_id,
|
||||
engine,
|
||||
task,
|
||||
transport,
|
||||
audio_buffer,
|
||||
audio_synchronizer,
|
||||
in_memory_audio_buffer,
|
||||
in_memory_transcript_buffer,
|
||||
pipeline_metrics_aggregator,
|
||||
)
|
||||
|
||||
register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id, in_memory_audio_buffer
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ from api.services.telephony.stasis_event_protocol import (
|
|||
RedisChannels,
|
||||
RedisKeys,
|
||||
SocketClosedCommand,
|
||||
StasisEndEvent,
|
||||
StasisStartEvent,
|
||||
TransferCommand,
|
||||
parse_command,
|
||||
)
|
||||
|
|
@ -123,13 +125,12 @@ class ARIManager:
|
|||
# Publish StasisEnd event to worker immediately
|
||||
if connection and caller_channel_id:
|
||||
worker_id = self._get_worker_for_channel(caller_channel_id)
|
||||
event = {
|
||||
"type": "stasis_end",
|
||||
"channel_id": caller_channel_id,
|
||||
"reason": EndTaskReason.USER_HANGUP.value,
|
||||
}
|
||||
event = StasisEndEvent(
|
||||
channel_id=caller_channel_id,
|
||||
reason=EndTaskReason.USER_HANGUP.value,
|
||||
)
|
||||
await self.redis.publish(
|
||||
RedisChannels.worker_events(worker_id), json.dumps(event)
|
||||
RedisChannels.worker_events(worker_id), event.to_json()
|
||||
)
|
||||
logger.info(f"channelID: {channel_id} Published StasisEnd event")
|
||||
|
||||
|
|
@ -163,16 +164,17 @@ class ARIManager:
|
|||
self._active_channels.add(channel_id)
|
||||
|
||||
# Create event with all connection details
|
||||
event = {
|
||||
"type": "stasis_start",
|
||||
"channel_id": channel_id,
|
||||
"caller_channel_id": channel_id,
|
||||
"em_channel_id": em_channel_id,
|
||||
"bridge_id": bridge_id,
|
||||
"local_addr": list(connection.local_addr),
|
||||
"remote_addr": list(connection.remote_addr),
|
||||
"call_context_vars": call_context_vars,
|
||||
}
|
||||
event = StasisStartEvent(
|
||||
channel_id=channel_id,
|
||||
caller_channel_id=channel_id,
|
||||
em_channel_id=em_channel_id,
|
||||
bridge_id=bridge_id,
|
||||
local_addr=list(connection.local_addr),
|
||||
remote_addr=list(connection.remote_addr)
|
||||
if connection.remote_addr
|
||||
else None,
|
||||
call_context_vars=call_context_vars,
|
||||
)
|
||||
|
||||
# Select worker using round-robin
|
||||
worker_id = await self._select_worker()
|
||||
|
|
@ -186,7 +188,7 @@ class ARIManager:
|
|||
channel = RedisChannels.worker_events(worker_id)
|
||||
|
||||
# Publish event to specific worker
|
||||
await self.redis.publish(channel, json.dumps(event))
|
||||
await self.redis.publish(channel, event.to_json())
|
||||
logger.info(
|
||||
f"channelID: {channel_id} Published stasis_start event to worker {worker_id}"
|
||||
)
|
||||
|
|
@ -416,10 +418,8 @@ class ARIManager:
|
|||
):
|
||||
"""Execute commands from workers."""
|
||||
if isinstance(command, DisconnectCommand):
|
||||
logger.info(
|
||||
f"channelID: {channel_id} Worker requested disconnect: {command.reason}"
|
||||
)
|
||||
await connection.disconnect(command.reason)
|
||||
logger.info(f"channelID: {channel_id} Worker requested disconnect")
|
||||
await connection.disconnect()
|
||||
|
||||
elif isinstance(command, TransferCommand):
|
||||
logger.info(f"channelID: {channel_id} Worker requested transfer")
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ class ARIManagerConnection(BaseObject):
|
|||
f"channelID: {self.caller_channel_id} Failed to sync data to ARI_DATA_SYNCING_URI: {e}"
|
||||
)
|
||||
|
||||
async def disconnect(self, reason: str):
|
||||
async def disconnect(self):
|
||||
"""Instruct Asterisk to hang-up the call and perform cleanup."""
|
||||
if self._closed:
|
||||
return
|
||||
|
|
@ -206,7 +206,7 @@ class ARIManagerConnection(BaseObject):
|
|||
caller_channel = await self._get_channel(self.caller_channel_id)
|
||||
if caller_channel:
|
||||
logger.debug(
|
||||
f"channelID: {self.caller_channel_id} Hanging up caller channel due to reason: {reason}"
|
||||
f"channelID: {self.caller_channel_id} Hanging up caller channel"
|
||||
)
|
||||
await caller_channel.hangup()
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import struct
|
|||
from typing import TYPE_CHECKING, AsyncIterator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
||||
|
|
@ -113,9 +112,6 @@ class StasisRTPClient:
|
|||
self._closing = False
|
||||
self._recv_sock_ready = asyncio.Event() # Signal when recv socket is ready
|
||||
self._leave_counter = 0 # Track input/output transport usage
|
||||
self._fallback_disconnect_timer: Optional[asyncio.Task] = (
|
||||
None # Safety timer for disconnect
|
||||
)
|
||||
|
||||
# ── wire event handlers to the connection ────────────────
|
||||
@self._connection.event_handler("connected")
|
||||
|
|
@ -126,27 +122,10 @@ class StasisRTPClient:
|
|||
)
|
||||
|
||||
@self._connection.event_handler("disconnected")
|
||||
async def _on_disconnected(_: Any, reason: str):
|
||||
# Cancel the safety timer if it exists. We start the safety timer when
|
||||
# sending disconnect or transfer from the engine, i.e when the disconnect()
|
||||
# method of the StasisRTPClient is called during wind down of the pipeline.
|
||||
# We start the timer so that if we don't get the remote hangup in a given
|
||||
# duration, we will call client disconnected handler.
|
||||
if (
|
||||
self._fallback_disconnect_timer
|
||||
and not self._fallback_disconnect_timer.done()
|
||||
):
|
||||
self._fallback_disconnect_timer.cancel()
|
||||
self._fallback_disconnect_timer = None
|
||||
|
||||
if not self._closing:
|
||||
# Mark the client as closing, so that when the pipeline is
|
||||
# cancelled or getting closed, we don't try start the fallback
|
||||
# disconnect timer and return safely from disconnect
|
||||
self._closing = True
|
||||
|
||||
async def _on_disconnected(_: Any):
|
||||
logger.debug("In _on_disconnected of StasisRTPClient")
|
||||
await self._callbacks.on_client_disconnected(
|
||||
self._connection.caller_channel_id, reason
|
||||
self._connection.caller_channel_id
|
||||
)
|
||||
|
||||
# ─── public helpers ──────────────────────────────────────────
|
||||
|
|
@ -161,20 +140,16 @@ class StasisRTPClient:
|
|||
return
|
||||
await self._connection.connect()
|
||||
|
||||
async def disconnect(
|
||||
self,
|
||||
reason: str = EndTaskReason.UNKNOWN.value,
|
||||
call_transfer_context: dict = {}, # Keep parameter for backward compatibility
|
||||
):
|
||||
async def disconnect(self):
|
||||
"""Disconnect from the RTP socket."""
|
||||
# Decrement leave counter when disconnect is called
|
||||
logger.debug(f"StasisRTPClient.disconnect leave_counter: {self._leave_counter}")
|
||||
self._leave_counter -= 1
|
||||
if self._leave_counter > 0:
|
||||
# Early return - InputTransport called first, OutputTransport will call later
|
||||
# Only proceed when counter reaches 0 (OutputTransport's call)
|
||||
return
|
||||
|
||||
# Only proceed when counter reaches 0 (OutputTransport's call)
|
||||
# Close sockets
|
||||
logger.debug("Going to close sockets")
|
||||
await self._close_sockets()
|
||||
|
|
@ -186,33 +161,12 @@ class StasisRTPClient:
|
|||
return
|
||||
self._closing = True
|
||||
|
||||
# Create a safety timer that will call on_client_disconnected if we don't
|
||||
# get StasisEnd from the dialer within 5 seconds. StasisEnd is needed to
|
||||
# trigger on_client_disconnected handler in the event_handlers
|
||||
async def _fallback_disconnect_timeout():
|
||||
await asyncio.sleep(5.0)
|
||||
logger.warning(
|
||||
"Disconnect event not received within 5 seconds, calling on_client_disconnected as fallback"
|
||||
)
|
||||
await self._callbacks.on_client_disconnected(
|
||||
self._connection.caller_channel_id
|
||||
)
|
||||
|
||||
self._fallback_disconnect_timer = asyncio.create_task(
|
||||
_fallback_disconnect_timeout()
|
||||
)
|
||||
|
||||
# Only call disconnect if not a transfer (transfer already handled in PipecatEngine)
|
||||
# NOTE: Transfer now happens immediately in PipecatEngine.send_end_task_frame()
|
||||
if reason != EndTaskReason.USER_QUALIFIED.value:
|
||||
try:
|
||||
await self._connection.disconnect(reason)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to disconnect RTP connection: {exc}")
|
||||
else:
|
||||
logger.debug(
|
||||
"Skipping disconnect call for USER_QUALIFIED - transfer already initiated by engine"
|
||||
)
|
||||
# If we have initiated transfer before, we would ignore _connection.disconnect()
|
||||
# in the connection. (since is_closing would be set by transfer)
|
||||
try:
|
||||
await self._connection.disconnect()
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to disconnect RTP connection: {exc}")
|
||||
|
||||
# ─── socket management ──────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Stasis RTP connection for worker processes.
|
||||
"""Stasis RTP connection for worker processes - is used by stasis rtp transport.
|
||||
|
||||
This connection works without direct ARI access and communicates with
|
||||
the ARI Manager via Redis for all control operations.
|
||||
|
|
@ -9,7 +9,6 @@ from typing import Optional, Tuple
|
|||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.telephony.stasis_event_protocol import (
|
||||
DisconnectCommand,
|
||||
|
|
@ -77,6 +76,10 @@ class StasisRTPConnection(BaseObject):
|
|||
# StasisEnd from the transport
|
||||
self._closed_by_stasis_end = False
|
||||
|
||||
# self._closing should be True if we have received disconnect
|
||||
# or transfer request
|
||||
self._closing = False
|
||||
|
||||
self._connect_invoked = False
|
||||
|
||||
# Register event handlers
|
||||
|
|
@ -102,18 +105,20 @@ class StasisRTPConnection(BaseObject):
|
|||
"StasisRTPConnection is not connected - did not call connected handler"
|
||||
)
|
||||
|
||||
async def disconnect(self, reason: str):
|
||||
async def disconnect(self):
|
||||
"""Request disconnection via Redis command to ARI Manager. Usually called
|
||||
when there is a disconnect triggered by workflow"""
|
||||
# If we have already received user hangup via StasisEnd, lets
|
||||
# return
|
||||
if self._closed_by_stasis_end:
|
||||
if self._closed_by_stasis_end or self._closing:
|
||||
return
|
||||
|
||||
logger.info(f"channelID: {self.channel_id} Requesting disconnect: {reason}")
|
||||
self._closing = True
|
||||
|
||||
logger.info(f"channelID: {self.channel_id} Requesting disconnect")
|
||||
|
||||
# Send disconnect command to ARI Manager
|
||||
command = DisconnectCommand(channel_id=self.channel_id, reason=reason)
|
||||
command = DisconnectCommand(channel_id=self.channel_id)
|
||||
channel = RedisChannels.channel_commands(self.channel_id)
|
||||
await self.redis.publish(channel, command.to_json())
|
||||
|
||||
|
|
@ -121,9 +126,11 @@ class StasisRTPConnection(BaseObject):
|
|||
"""Request call transfer via Redis command to ARI Manager."""
|
||||
# If we have already received user hangup via StasisEnd, lets
|
||||
# return
|
||||
if self._closed_by_stasis_end:
|
||||
if self._closed_by_stasis_end or self._closing:
|
||||
return
|
||||
|
||||
self._closing = True
|
||||
|
||||
logger.info(f"channelID: {self.channel_id} Requesting transfer")
|
||||
|
||||
# Send transfer command to ARI Manager
|
||||
|
|
@ -149,11 +156,15 @@ class StasisRTPConnection(BaseObject):
|
|||
|
||||
Returns True once connect() has been called and connection is not closed.
|
||||
"""
|
||||
return self._connect_invoked and not self._closed_by_stasis_end
|
||||
return (
|
||||
self._connect_invoked
|
||||
and not self._closed_by_stasis_end
|
||||
and not self._closing
|
||||
)
|
||||
|
||||
async def handle_remote_disconnect(self, reason: str = EndTaskReason.UNKNOWN.value):
|
||||
async def handle_remote_disconnect(self):
|
||||
"""Handle disconnection initiated by ARI Manager. Is called when the user hangs up."""
|
||||
if self._closed_by_stasis_end:
|
||||
if self._closed_by_stasis_end or self._closing:
|
||||
return
|
||||
|
||||
self._closed_by_stasis_end = True
|
||||
|
|
@ -163,15 +174,13 @@ class StasisRTPConnection(BaseObject):
|
|||
# register the event handler of client when the transports are initiated during pipeline
|
||||
# initialisation. Any caller must check and wait for _connect_invoked before
|
||||
# calling the method
|
||||
await self._call_event_handler("disconnected", reason)
|
||||
await self._call_event_handler("disconnected")
|
||||
else:
|
||||
logger.warning(
|
||||
f"ChannelID: {self.channel_id} Got remote disconnect before connection was invoked"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"channelID: {self.channel_id} StasisRTPConnection disconnected: {reason}"
|
||||
)
|
||||
logger.info(f"channelID: {self.channel_id} StasisRTPConnection disconnected")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of connection."""
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from pipecat.transports.base_output import (
|
|||
TransportClientNotConnectedException,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.services.telephony.stasis_rtp_client import StasisRTPClient
|
||||
|
|
@ -40,9 +39,7 @@ class StasisRTPCallbacks(BaseModel):
|
|||
"""Callbacks for Stasis RTP transport events."""
|
||||
|
||||
on_client_connected: Callable[[str], Awaitable[None]]
|
||||
on_client_disconnected: Callable[
|
||||
[str, Optional[str]], Awaitable[None]
|
||||
] # Added optional disconnect reason
|
||||
on_client_disconnected: Callable[[str], Awaitable[None]]
|
||||
on_client_closed: Callable[[str], Awaitable[None]]
|
||||
|
||||
|
||||
|
|
@ -116,22 +113,14 @@ class StasisRTPInputTransport(BaseInputTransport):
|
|||
"""Stop the input transport."""
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
# Call disconnect on the client when EndFrame is encountered
|
||||
await self._client.disconnect(
|
||||
frame.metadata.get("reason", EndTaskReason.UNKNOWN.value),
|
||||
frame.metadata.get("call_transfer_context", {}),
|
||||
)
|
||||
await self._client.disconnect()
|
||||
logger.debug("Successfully disconnected from StasisRTPClient")
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the input transport."""
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
# Call disconnect on the client when CancelFrame is encountered
|
||||
await self._client.disconnect(
|
||||
frame.metadata.get("reason", EndTaskReason.SYSTEM_CANCELLED.value),
|
||||
frame.metadata.get("call_transfer_context", {}),
|
||||
)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def _receive_audio(self):
|
||||
try:
|
||||
|
|
@ -198,22 +187,12 @@ class StasisRTPOutputTransport(BaseOutputTransport):
|
|||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the output transport."""
|
||||
await super().stop(frame)
|
||||
|
||||
# Call disconnect on the client when EndFrame is encountered
|
||||
# The client will check its _leave_counter and decide whether to close sockets
|
||||
await self._client.disconnect(
|
||||
frame.metadata.get("reason", EndTaskReason.UNKNOWN.value),
|
||||
frame.metadata.get("call_transfer_context", {}),
|
||||
)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the output transport."""
|
||||
await super().cancel(frame)
|
||||
# Call disconnect on the client when CancelFrame is encountered
|
||||
await self._client.disconnect(
|
||||
frame.metadata.get("reason", EndTaskReason.SYSTEM_CANCELLED.value),
|
||||
frame.metadata.get("call_transfer_context", {}),
|
||||
)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def send_message(
|
||||
self, frame: TransportMessageFrame | TransportMessageUrgentFrame
|
||||
|
|
@ -317,8 +296,8 @@ class StasisRTPTransport(BaseTransport):
|
|||
async def _on_client_connected(self, chan_id: str):
|
||||
await self._call_event_handler("on_client_connected", chan_id)
|
||||
|
||||
async def _on_client_disconnected(self, chan_id: str, reason: Optional[str] = None):
|
||||
await self._call_event_handler("on_client_disconnected", chan_id, reason)
|
||||
async def _on_client_disconnected(self, chan_id: str):
|
||||
await self._call_event_handler("on_client_disconnected", chan_id)
|
||||
|
||||
async def _on_client_closed(self, chan_id: str):
|
||||
await self._call_event_handler("on_client_closed", chan_id)
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ class WorkerEventSubscriber:
|
|||
if channel_id in self._active_tasks:
|
||||
del self._active_tasks[channel_id]
|
||||
|
||||
async def _process_cleanup(self, channel_id: str, reason: str):
|
||||
async def _process_cleanup(self, channel_id: str):
|
||||
"""Process call cleanup in the background."""
|
||||
try:
|
||||
if channel_id in self._active_connections:
|
||||
|
|
@ -317,7 +317,7 @@ class WorkerEventSubscriber:
|
|||
if connection.workflow_run_id:
|
||||
set_current_run_id(connection.workflow_run_id)
|
||||
|
||||
await connection.handle_remote_disconnect(reason)
|
||||
await connection.handle_remote_disconnect()
|
||||
del self._active_connections[channel_id]
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during cleanup for {channel_id}: {e}")
|
||||
|
|
@ -330,7 +330,7 @@ class WorkerEventSubscriber:
|
|||
"""Handle call termination."""
|
||||
channel_id = event.channel_id
|
||||
logger.info(
|
||||
f"channelID: {channel_id} Worker {self.worker_id} handling StasisEnd, Reason: {event.reason}"
|
||||
f"channelID: {channel_id} Worker {self.worker_id} handling StasisEnd"
|
||||
)
|
||||
|
||||
# Create a background task to handle the cleanup
|
||||
|
|
@ -344,7 +344,7 @@ class WorkerEventSubscriber:
|
|||
# connection to be invoked from the pipeline before
|
||||
# caling remote disconnect
|
||||
task = asyncio.create_task(
|
||||
self._process_cleanup(channel_id, event.reason),
|
||||
self._process_cleanup(channel_id),
|
||||
name=f"cleanup_handler_{channel_id}",
|
||||
)
|
||||
self._cleanup_tasks[channel_id] = task
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class PipecatEngine:
|
|||
self._audio_buffer = audio_buffer
|
||||
self._workflow_run_id = workflow_run_id
|
||||
self._initialized = False
|
||||
self._client_disconnected = False
|
||||
self._pending_function_calls = 0
|
||||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
|
|
@ -602,7 +603,6 @@ class PipecatEngine:
|
|||
async def send_end_task_frame(
|
||||
self,
|
||||
reason: str,
|
||||
additional_metadata: dict = None,
|
||||
abort_immediately: bool = False,
|
||||
):
|
||||
"""
|
||||
|
|
@ -621,6 +621,11 @@ class PipecatEngine:
|
|||
self._workflow_run_id
|
||||
)
|
||||
|
||||
# If client is disconnected before we get a chance to disconnect from
|
||||
# the bot, lets consider that as final disposition
|
||||
if self._client_disconnected:
|
||||
call_disposition = EndTaskReason.USER_HANGUP.value
|
||||
|
||||
if call_disposition:
|
||||
# If call_disposition exists, map it
|
||||
mapped_disposition = await apply_disposition_mapping(
|
||||
|
|
@ -710,19 +715,6 @@ class PipecatEngine:
|
|||
)
|
||||
)
|
||||
|
||||
metadata = {
|
||||
# Keep original reason in metadata, which would be used to decide
|
||||
# whether to disconnect or to transfer the call in the transport
|
||||
"reason": reason,
|
||||
"call_transfer_context": call_transfer_context,
|
||||
}
|
||||
|
||||
# Add any additional metadata
|
||||
if additional_metadata:
|
||||
metadata.update(additional_metadata)
|
||||
|
||||
frame_to_push.metadata = metadata
|
||||
|
||||
# Store the original reason for later retrieval in event handler
|
||||
self._call_disposition = mapped_disposition
|
||||
|
||||
|
|
@ -872,14 +864,6 @@ class PipecatEngine:
|
|||
"""Create a callback that corrects corrupted aggregation using reference text."""
|
||||
return engine_callbacks.create_aggregation_correction_callback(self)
|
||||
|
||||
def get_call_disposition(self) -> Optional[str]:
|
||||
"""Get the disconnect reason set by the engine."""
|
||||
return self._call_disposition
|
||||
|
||||
def get_gathered_context(self) -> dict:
|
||||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
||||
def set_context(self, context: OpenAILLMContext) -> None:
|
||||
"""Set the OpenAI LLM context.
|
||||
|
||||
|
|
@ -925,6 +909,26 @@ class PipecatEngine:
|
|||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_reference_text += text
|
||||
|
||||
async def handle_client_disconnected(self):
|
||||
"""Handle client disconnected event."""
|
||||
self._client_disconnected = True
|
||||
|
||||
async def get_call_disposition(self) -> Optional[str]:
|
||||
"""Get the disconnect reason set by the engine."""
|
||||
if self._call_disposition:
|
||||
# We would have a _call_disposition variable set if we have initiated
|
||||
# a disconnect from the bot, i.e we have called send_end_task_frame.
|
||||
return self._call_disposition
|
||||
|
||||
if self._client_disconnected:
|
||||
return EndTaskReason.USER_HANGUP.value
|
||||
else:
|
||||
return EndTaskReason.UNKNOWN.value
|
||||
|
||||
async def get_gathered_context(self) -> dict:
|
||||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up engine resources on disconnect."""
|
||||
# Cancel any pending timeout tasks
|
||||
|
|
|
|||
|
|
@ -245,13 +245,6 @@ class VoicemailDetector:
|
|||
# Send end task frame with metadata (including optional S3 path)
|
||||
await self._engine.send_end_task_frame(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
additional_metadata={
|
||||
"voicemail_transcript": transcript,
|
||||
"voicemail_confidence": confidence,
|
||||
"voicemail_reasoning": reasoning,
|
||||
"voicemail_detection_duration": self.detection_duration,
|
||||
"voicemail_audio_s3_path": s3_path,
|
||||
},
|
||||
abort_immediately=True,
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue