on_client_disconnect only when user disconnects.

This commit is contained in:
a6kme 2025-09-22 09:31:13 +00:00
parent 70c58e1a40
commit 33ff3b9779
11 changed files with 162 additions and 214 deletions

View file

@ -26,6 +26,8 @@ from api.services.telephony.stasis_event_protocol import (
RedisChannels,
RedisKeys,
SocketClosedCommand,
StasisEndEvent,
StasisStartEvent,
TransferCommand,
parse_command,
)
@ -123,13 +125,12 @@ class ARIManager:
# Publish StasisEnd event to worker immediately
if connection and caller_channel_id:
worker_id = self._get_worker_for_channel(caller_channel_id)
event = {
"type": "stasis_end",
"channel_id": caller_channel_id,
"reason": EndTaskReason.USER_HANGUP.value,
}
event = StasisEndEvent(
channel_id=caller_channel_id,
reason=EndTaskReason.USER_HANGUP.value,
)
await self.redis.publish(
RedisChannels.worker_events(worker_id), json.dumps(event)
RedisChannels.worker_events(worker_id), event.to_json()
)
logger.info(f"channelID: {channel_id} Published StasisEnd event")
@ -163,16 +164,17 @@ class ARIManager:
self._active_channels.add(channel_id)
# Create event with all connection details
event = {
"type": "stasis_start",
"channel_id": channel_id,
"caller_channel_id": channel_id,
"em_channel_id": em_channel_id,
"bridge_id": bridge_id,
"local_addr": list(connection.local_addr),
"remote_addr": list(connection.remote_addr),
"call_context_vars": call_context_vars,
}
event = StasisStartEvent(
channel_id=channel_id,
caller_channel_id=channel_id,
em_channel_id=em_channel_id,
bridge_id=bridge_id,
local_addr=list(connection.local_addr),
remote_addr=list(connection.remote_addr)
if connection.remote_addr
else None,
call_context_vars=call_context_vars,
)
# Select worker using round-robin
worker_id = await self._select_worker()
@ -186,7 +188,7 @@ class ARIManager:
channel = RedisChannels.worker_events(worker_id)
# Publish event to specific worker
await self.redis.publish(channel, json.dumps(event))
await self.redis.publish(channel, event.to_json())
logger.info(
f"channelID: {channel_id} Published stasis_start event to worker {worker_id}"
)
@ -416,10 +418,8 @@ class ARIManager:
):
"""Execute commands from workers."""
if isinstance(command, DisconnectCommand):
logger.info(
f"channelID: {channel_id} Worker requested disconnect: {command.reason}"
)
await connection.disconnect(command.reason)
logger.info(f"channelID: {channel_id} Worker requested disconnect")
await connection.disconnect()
elif isinstance(command, TransferCommand):
logger.info(f"channelID: {channel_id} Worker requested transfer")

View file

@ -189,7 +189,7 @@ class ARIManagerConnection(BaseObject):
f"channelID: {self.caller_channel_id} Failed to sync data to ARI_DATA_SYNCING_URI: {e}"
)
async def disconnect(self, reason: str):
async def disconnect(self):
"""Instruct Asterisk to hang-up the call and perform cleanup."""
if self._closed:
return
@ -206,7 +206,7 @@ class ARIManagerConnection(BaseObject):
caller_channel = await self._get_channel(self.caller_channel_id)
if caller_channel:
logger.debug(
f"channelID: {self.caller_channel_id} Hanging up caller channel due to reason: {reason}"
f"channelID: {self.caller_channel_id} Hanging up caller channel"
)
await caller_channel.hangup()
except Exception:

View file

@ -16,7 +16,6 @@ import struct
from typing import TYPE_CHECKING, AsyncIterator, Optional
from loguru import logger
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
@ -113,9 +112,6 @@ class StasisRTPClient:
self._closing = False
self._recv_sock_ready = asyncio.Event() # Signal when recv socket is ready
self._leave_counter = 0 # Track input/output transport usage
self._fallback_disconnect_timer: Optional[asyncio.Task] = (
None # Safety timer for disconnect
)
# ── wire event handlers to the connection ────────────────
@self._connection.event_handler("connected")
@ -126,27 +122,10 @@ class StasisRTPClient:
)
@self._connection.event_handler("disconnected")
async def _on_disconnected(_: Any, reason: str):
# Cancel the safety timer if it exists. We start the safety timer when
# sending disconnect or transfer from the engine, i.e when the disconnect()
# method of the StasisRTPClient is called during wind down of the pipeline.
# We start the timer so that if we don't get the remote hangup in a given
# duration, we will call client disconnected handler.
if (
self._fallback_disconnect_timer
and not self._fallback_disconnect_timer.done()
):
self._fallback_disconnect_timer.cancel()
self._fallback_disconnect_timer = None
if not self._closing:
# Mark the client as closing, so that when the pipeline is
# cancelled or getting closed, we don't try start the fallback
# disconnect timer and return safely from disconnect
self._closing = True
async def _on_disconnected(_: Any):
logger.debug("In _on_disconnected of StasisRTPClient")
await self._callbacks.on_client_disconnected(
self._connection.caller_channel_id, reason
self._connection.caller_channel_id
)
# ─── public helpers ──────────────────────────────────────────
@ -161,20 +140,16 @@ class StasisRTPClient:
return
await self._connection.connect()
async def disconnect(
self,
reason: str = EndTaskReason.UNKNOWN.value,
call_transfer_context: dict = {}, # Keep parameter for backward compatibility
):
async def disconnect(self):
"""Disconnect from the RTP socket."""
# Decrement leave counter when disconnect is called
logger.debug(f"StasisRTPClient.disconnect leave_counter: {self._leave_counter}")
self._leave_counter -= 1
if self._leave_counter > 0:
# Early return - InputTransport called first, OutputTransport will call later
# Only proceed when counter reaches 0 (OutputTransport's call)
return
# Only proceed when counter reaches 0 (OutputTransport's call)
# Close sockets
logger.debug("Going to close sockets")
await self._close_sockets()
@ -186,33 +161,12 @@ class StasisRTPClient:
return
self._closing = True
# Create a safety timer that will call on_client_disconnected if we don't
# get StasisEnd from the dialer within 5 seconds. StasisEnd is needed to
# trigger on_client_disconnected handler in the event_handlers
async def _fallback_disconnect_timeout():
await asyncio.sleep(5.0)
logger.warning(
"Disconnect event not received within 5 seconds, calling on_client_disconnected as fallback"
)
await self._callbacks.on_client_disconnected(
self._connection.caller_channel_id
)
self._fallback_disconnect_timer = asyncio.create_task(
_fallback_disconnect_timeout()
)
# Only call disconnect if not a transfer (transfer already handled in PipecatEngine)
# NOTE: Transfer now happens immediately in PipecatEngine.send_end_task_frame()
if reason != EndTaskReason.USER_QUALIFIED.value:
try:
await self._connection.disconnect(reason)
except Exception as exc:
logger.error(f"Failed to disconnect RTP connection: {exc}")
else:
logger.debug(
"Skipping disconnect call for USER_QUALIFIED - transfer already initiated by engine"
)
# If we have initiated transfer before, we would ignore _connection.disconnect()
# in the connection. (since is_closing would be set by transfer)
try:
await self._connection.disconnect()
except Exception as exc:
logger.error(f"Failed to disconnect RTP connection: {exc}")
# ─── socket management ──────────────────────────────────────

View file

@ -1,4 +1,4 @@
"""Stasis RTP connection for worker processes.
"""Stasis RTP connection for worker processes - is used by stasis rtp transport.
This connection works without direct ARI access and communicates with
the ARI Manager via Redis for all control operations.
@ -9,7 +9,6 @@ from typing import Optional, Tuple
import redis.asyncio as aioredis
from loguru import logger
from pipecat.utils.base_object import BaseObject
from pipecat.utils.enums import EndTaskReason
from api.services.telephony.stasis_event_protocol import (
DisconnectCommand,
@ -77,6 +76,10 @@ class StasisRTPConnection(BaseObject):
# StasisEnd from the transport
self._closed_by_stasis_end = False
# self._closing should be True if we have received disconnect
# or transfer request
self._closing = False
self._connect_invoked = False
# Register event handlers
@ -102,18 +105,20 @@ class StasisRTPConnection(BaseObject):
"StasisRTPConnection is not connected - did not call connected handler"
)
async def disconnect(self, reason: str):
async def disconnect(self):
"""Request disconnection via Redis command to ARI Manager. Usually called
when there is a disconnect triggered by workflow"""
# If we have already received user hangup via StasisEnd, lets
# return
if self._closed_by_stasis_end:
if self._closed_by_stasis_end or self._closing:
return
logger.info(f"channelID: {self.channel_id} Requesting disconnect: {reason}")
self._closing = True
logger.info(f"channelID: {self.channel_id} Requesting disconnect")
# Send disconnect command to ARI Manager
command = DisconnectCommand(channel_id=self.channel_id, reason=reason)
command = DisconnectCommand(channel_id=self.channel_id)
channel = RedisChannels.channel_commands(self.channel_id)
await self.redis.publish(channel, command.to_json())
@ -121,9 +126,11 @@ class StasisRTPConnection(BaseObject):
"""Request call transfer via Redis command to ARI Manager."""
# If we have already received user hangup via StasisEnd, lets
# return
if self._closed_by_stasis_end:
if self._closed_by_stasis_end or self._closing:
return
self._closing = True
logger.info(f"channelID: {self.channel_id} Requesting transfer")
# Send transfer command to ARI Manager
@ -149,11 +156,15 @@ class StasisRTPConnection(BaseObject):
Returns True once connect() has been called and connection is not closed.
"""
return self._connect_invoked and not self._closed_by_stasis_end
return (
self._connect_invoked
and not self._closed_by_stasis_end
and not self._closing
)
async def handle_remote_disconnect(self, reason: str = EndTaskReason.UNKNOWN.value):
async def handle_remote_disconnect(self):
"""Handle disconnection initiated by ARI Manager. Is called when the user hangs up."""
if self._closed_by_stasis_end:
if self._closed_by_stasis_end or self._closing:
return
self._closed_by_stasis_end = True
@ -163,15 +174,13 @@ class StasisRTPConnection(BaseObject):
# register the event handler of client when the transports are initiated during pipeline
# initialisation. Any caller must check and wait for _connect_invoked before
# calling the method
await self._call_event_handler("disconnected", reason)
await self._call_event_handler("disconnected")
else:
logger.warning(
f"ChannelID: {self.channel_id} Got remote disconnect before connection was invoked"
)
logger.info(
f"channelID: {self.channel_id} StasisRTPConnection disconnected: {reason}"
)
logger.info(f"channelID: {self.channel_id} StasisRTPConnection disconnected")
def __repr__(self):
"""String representation of connection."""

View file

@ -23,7 +23,6 @@ from pipecat.transports.base_output import (
TransportClientNotConnectedException,
)
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.enums import EndTaskReason
from pydantic import BaseModel
from api.services.telephony.stasis_rtp_client import StasisRTPClient
@ -40,9 +39,7 @@ class StasisRTPCallbacks(BaseModel):
"""Callbacks for Stasis RTP transport events."""
on_client_connected: Callable[[str], Awaitable[None]]
on_client_disconnected: Callable[
[str, Optional[str]], Awaitable[None]
] # Added optional disconnect reason
on_client_disconnected: Callable[[str], Awaitable[None]]
on_client_closed: Callable[[str], Awaitable[None]]
@ -116,22 +113,14 @@ class StasisRTPInputTransport(BaseInputTransport):
"""Stop the input transport."""
await super().stop(frame)
await self._stop_tasks()
# Call disconnect on the client when EndFrame is encountered
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.UNKNOWN.value),
frame.metadata.get("call_transfer_context", {}),
)
await self._client.disconnect()
logger.debug("Successfully disconnected from StasisRTPClient")
async def cancel(self, frame: CancelFrame):
"""Cancel the input transport."""
await super().cancel(frame)
await self._stop_tasks()
# Call disconnect on the client when CancelFrame is encountered
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.SYSTEM_CANCELLED.value),
frame.metadata.get("call_transfer_context", {}),
)
await self._client.disconnect()
async def _receive_audio(self):
try:
@ -198,22 +187,12 @@ class StasisRTPOutputTransport(BaseOutputTransport):
async def stop(self, frame: EndFrame):
"""Stop the output transport."""
await super().stop(frame)
# Call disconnect on the client when EndFrame is encountered
# The client will check its _leave_counter and decide whether to close sockets
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.UNKNOWN.value),
frame.metadata.get("call_transfer_context", {}),
)
await self._client.disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the output transport."""
await super().cancel(frame)
# Call disconnect on the client when CancelFrame is encountered
await self._client.disconnect(
frame.metadata.get("reason", EndTaskReason.SYSTEM_CANCELLED.value),
frame.metadata.get("call_transfer_context", {}),
)
await self._client.disconnect()
async def send_message(
self, frame: TransportMessageFrame | TransportMessageUrgentFrame
@ -317,8 +296,8 @@ class StasisRTPTransport(BaseTransport):
async def _on_client_connected(self, chan_id: str):
await self._call_event_handler("on_client_connected", chan_id)
async def _on_client_disconnected(self, chan_id: str, reason: Optional[str] = None):
await self._call_event_handler("on_client_disconnected", chan_id, reason)
async def _on_client_disconnected(self, chan_id: str):
await self._call_event_handler("on_client_disconnected", chan_id)
async def _on_client_closed(self, chan_id: str):
await self._call_event_handler("on_client_closed", chan_id)

View file

@ -299,7 +299,7 @@ class WorkerEventSubscriber:
if channel_id in self._active_tasks:
del self._active_tasks[channel_id]
async def _process_cleanup(self, channel_id: str, reason: str):
async def _process_cleanup(self, channel_id: str):
"""Process call cleanup in the background."""
try:
if channel_id in self._active_connections:
@ -317,7 +317,7 @@ class WorkerEventSubscriber:
if connection.workflow_run_id:
set_current_run_id(connection.workflow_run_id)
await connection.handle_remote_disconnect(reason)
await connection.handle_remote_disconnect()
del self._active_connections[channel_id]
except Exception as e:
logger.exception(f"Error during cleanup for {channel_id}: {e}")
@ -330,7 +330,7 @@ class WorkerEventSubscriber:
"""Handle call termination."""
channel_id = event.channel_id
logger.info(
f"channelID: {channel_id} Worker {self.worker_id} handling StasisEnd, Reason: {event.reason}"
f"channelID: {channel_id} Worker {self.worker_id} handling StasisEnd"
)
# Create a background task to handle the cleanup
@ -344,7 +344,7 @@ class WorkerEventSubscriber:
# connection to be invoked from the pipeline before
# caling remote disconnect
task = asyncio.create_task(
self._process_cleanup(channel_id, event.reason),
self._process_cleanup(channel_id),
name=f"cleanup_handler_{channel_id}",
)
self._cleanup_tasks[channel_id] = task