dograh/api/tests/test_leave_counter.py
Abhishek Kumar 4f2a629340 Initial Commit 🚀 🚀
2025-09-09 14:37:32 +05:30

248 lines
8.5 KiB
Python

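"""
Test cases for the _leave_counter mechanism in transport clients.

This test suite verifies that _leave_counter prevents premature disconnection
when the input and output transports share the same client: each setup() call
increments the counter, each disconnect() call decrements it, and the
underlying connection is closed only once the counter reaches zero.
"""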
import asyncio
from unittest.mock import AsyncMock, Mock
import pytest
from pipecat.frames.frames import EndFrame, StartFrame
from pipecat.transports.network.fastapi_websocket import (
FastAPIWebsocketCallbacks,
FastAPIWebsocketClient,
FastAPIWebsocketParams,
FastAPIWebsocketTransport,
)
from pipecat.transports.network.small_webrtc import SmallWebRTCClient
from api.services.telephony.stasis_rtp_client import StasisRTPClient
class TestLeaveCounterFastAPIWebsocket:
    """Exercise the _leave_counter bookkeeping of FastAPIWebsocketClient."""

    @pytest.mark.asyncio
    async def test_leave_counter_prevents_early_disconnect(self):
        """The websocket is closed only after the second of two disconnects."""
        from starlette.websockets import WebSocketState

        # Fake websocket that reports itself as connected and records close().
        websocket = Mock(
            close=AsyncMock(),
            client_state=WebSocketState.CONNECTED,
        )

        handlers = FastAPIWebsocketCallbacks(
            on_client_connected=AsyncMock(),
            on_client_disconnected=AsyncMock(),
            on_session_timeout=AsyncMock(),
        )
        client = FastAPIWebsocketClient(
            websocket, is_binary=False, callbacks=handlers
        )
        frame = StartFrame()

        # The input transport sets up first, then the output transport; each
        # setup() call is expected to bump the counter by one.
        for expected_count in (1, 2):
            await client.setup(frame)
            assert client._leave_counter == expected_count

        # One side leaving must not tear down the shared websocket.
        await client.disconnect()
        assert client._leave_counter == 1
        websocket.close.assert_not_called()
        handlers.on_client_disconnected.assert_not_called()

        # The last side leaving performs the real disconnect.
        await client.disconnect()
        assert client._leave_counter == 0
        websocket.close.assert_called_once()
        handlers.on_client_disconnected.assert_called_once()
class TestLeaveCounterStasisRTP:
    """Exercise the _leave_counter bookkeeping of StasisRTPClient."""

    @pytest.mark.asyncio
    async def test_leave_counter_prevents_early_disconnect(self):
        """The connection is disconnected only after the second of two disconnects."""
        from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks

        def passthrough_event_handler(event_name):
            # Stand-in for the connection's decorator factory: registering a
            # handler simply hands the function back unchanged.
            return lambda func: func

        # Fake connection that claims to be connected and records disconnect().
        connection = Mock(
            disconnect=AsyncMock(),
            notify_sockets_closed=AsyncMock(),
        )
        connection.is_connected.return_value = True
        connection.event_handler = passthrough_event_handler

        handlers = StasisRTPCallbacks(
            on_client_connected=AsyncMock(),
            on_client_disconnected=AsyncMock(),
            on_client_closed=AsyncMock(),
        )
        client = StasisRTPClient(connection, handlers)
        frame = StartFrame()

        # The input transport sets up first, then the output transport; each
        # setup() call is expected to bump the counter by one.
        for expected_count in (1, 2):
            await client.setup(frame)
            assert client._leave_counter == expected_count

        # One side leaving must not tear down the shared connection.
        await client.disconnect()
        assert client._leave_counter == 1
        connection.disconnect.assert_not_called()

        # The last side leaving performs the real disconnect.
        await client.disconnect()
        assert client._leave_counter == 0
        connection.disconnect.assert_called_once()
class TestLeaveCounterSmallWebRTC:
    """Exercise the _leave_counter bookkeeping of SmallWebRTCClient."""

    @pytest.mark.asyncio
    async def test_leave_counter_prevents_early_disconnect(self):
        """The connection is disconnected only after the second of two disconnects."""
        from pipecat.transports.base_transport import TransportParams
        from pipecat.transports.network.small_webrtc import SmallWebRTCCallbacks

        def passthrough_event_handler(event_name):
            # Stand-in for the connection's decorator factory: registering a
            # handler simply hands the function back unchanged.
            return lambda func: func

        # Fake connection that claims to be connected and records disconnect().
        connection = Mock(
            disconnect=AsyncMock(),
            notify_sockets_closed=AsyncMock(),
        )
        connection.is_connected.return_value = True
        connection.event_handler = passthrough_event_handler

        handlers = SmallWebRTCCallbacks(
            on_app_message=AsyncMock(),
            on_client_connected=AsyncMock(),
            on_client_disconnected=AsyncMock(),
        )
        client = SmallWebRTCClient(connection, handlers)

        frame = StartFrame()
        transport_params = TransportParams(
            audio_in_channels=1,
            audio_in_sample_rate=16000,
            audio_out_sample_rate=16000,
        )

        # The input transport sets up first, then the output transport; each
        # setup() call is expected to bump the counter by one.
        for expected_count in (1, 2):
            await client.setup(transport_params, frame)
            assert client._leave_counter == expected_count

        # One side leaving must not tear down the shared connection.
        await client.disconnect()
        assert client._leave_counter == 1
        connection.disconnect.assert_not_called()

        # The last side leaving performs the real disconnect.
        await client.disconnect()
        assert client._leave_counter == 0
        connection.disconnect.assert_called_once()
@pytest.mark.skip(reason="Complex integration test - requires additional mocking")
@pytest.mark.asyncio
async def test_transport_lifecycle_with_leave_counter():
    """Test complete transport lifecycle with proper leave counter handling.

    Builds a FastAPIWebsocketTransport around a mocked websocket, sets up and
    starts both its input and output transports, and then stops them one at a
    time. The websocket must stay open after the first stop and be closed
    exactly once after the second.

    Currently skipped; see the skip reason on the decorator.
    """
    from starlette.websockets import WebSocketState

    # Mock websocket that reports itself as connected.
    mock_websocket = Mock()
    mock_websocket.close = AsyncMock()
    mock_websocket.client_state = WebSocketState.CONNECTED
    # NOTE(review): iter_bytes is mocked with a plain synchronous iterator. If
    # the transport consumes it with `async for`, this mock will not work and
    # may be part of the "additional mocking" the skip reason mentions - confirm.
    mock_websocket.iter_bytes = Mock(return_value=iter([]))
    mock_websocket.send_bytes = AsyncMock()

    # Create the transport and grab both of its directions.
    params = FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True)
    transport = FastAPIWebsocketTransport(mock_websocket, params)
    input_transport = transport.input()
    output_transport = transport.output()

    # Set up the transport with the components frame processors require.
    from pipecat.clocks.system_clock import SystemClock
    from pipecat.processors.frame_processor import FrameProcessorSetup
    from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams

    clock = SystemClock()
    task_manager = TaskManager()

    # We are inside a coroutine, so a loop is guaranteed to be running;
    # get_running_loop() is the preferred accessor here and, unlike
    # get_event_loop(), never consults the event loop policy.
    loop = asyncio.get_running_loop()
    task_manager_params = TaskManagerParams(loop=loop)
    task_manager.setup(task_manager_params)
    setup = FrameProcessorSetup(clock=clock, task_manager=task_manager)

    # Set up both input and output transports.
    await input_transport.setup(setup)
    await output_transport.setup(setup)

    # Start both transports.
    start_frame = StartFrame()
    await input_transport.start(start_frame)
    await output_transport.start(start_frame)

    # Both directions have registered with the shared client.
    assert transport._client._leave_counter == 2

    # Stopping only the input side must leave the websocket open.
    end_frame = EndFrame()
    await input_transport.stop(end_frame)
    mock_websocket.close.assert_not_called()

    # Stopping the output side as well closes the websocket exactly once.
    await output_transport.stop(end_frame)
    mock_websocket.close.assert_called_once()