From 6d78537297a7a476e71dc74ab834c7c08a26ca54 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Wed, 27 May 2026 09:49:36 +0000 Subject: [PATCH] chore: remove unused smart_turn service Fixes #323, #324, #325. --- api/services/smart_turn/__init__.py | 3 - api/services/smart_turn/app.py | 478 ------------------ .../smart_turn/websocket_smart_turn.py | 314 ------------ 3 files changed, 795 deletions(-) delete mode 100644 api/services/smart_turn/__init__.py delete mode 100644 api/services/smart_turn/app.py delete mode 100644 api/services/smart_turn/websocket_smart_turn.py diff --git a/api/services/smart_turn/__init__.py b/api/services/smart_turn/__init__.py deleted file mode 100644 index 75d582a..0000000 --- a/api/services/smart_turn/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .websocket_smart_turn import WebSocketSmartTurnAnalyzer - -__all__ = ["WebSocketSmartTurnAnalyzer"] diff --git a/api/services/smart_turn/app.py b/api/services/smart_turn/app.py deleted file mode 100644 index 6bbd0ab..0000000 --- a/api/services/smart_turn/app.py +++ /dev/null @@ -1,478 +0,0 @@ -import asyncio -import io -import json -import logging -import os -import sys -import time -from contextlib import asynccontextmanager -from datetime import datetime -from pathlib import Path - -import numpy as np -from fastapi import ( - BackgroundTasks, - FastAPI, - HTTPException, - Request, - WebSocket, - WebSocketDisconnect, - WebSocketException, - status, -) -from fastapi.websockets import WebSocketState -from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2 -from scipy.io import wavfile - -LOG_LEVEL = ( - logging.DEBUG - if os.environ.get("LOG_LEVEL", "DEBUG").lower() == "debug" - else logging.INFO -) - -logger = logging.getLogger("smart_turn") -logger.setLevel(LOG_LEVEL) -handler = logging.StreamHandler(sys.stdout) -handler.setFormatter( - logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -) -logger.addHandler(handler) - - -# ---------------------------------------------------------------------------- -# Configuration -# ---------------------------------------------------------------------------- -MODEL_PATH = os.getenv("LOCAL_SMART_TURN_MODEL_PATH", "pipecat-ai/smart-turn-v2") - -# ---------------------------------------------------------------------------- -# Analyzer Pool -# ---------------------------------------------------------------------------- - - -class _AnalyzerWrapper: - """Wraps a LocalSmartTurnAnalyzer with a lock so only one request can use it at a time.""" - - def __init__(self, analyzer: LocalSmartTurnAnalyzerV2): - self.analyzer = analyzer - self.lock = asyncio.Lock() - - -_analyzer_wrapper: _AnalyzerWrapper | None = None # Will be initialised in the lifespan - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Manage the application lifespan - startup and shutdown logic.""" - # Startup logic - global _analyzer_wrapper - - if _analyzer_wrapper is None: - logger.debug("Initializing LocalSmartTurnAnalyzer") - analyzer = LocalSmartTurnAnalyzerV2(smart_turn_model_path=MODEL_PATH) - _analyzer_wrapper = _AnalyzerWrapper(analyzer) - logger.debug("LocalSmartTurnAnalyzer initialized") - - yield # Application runs here - - # Shutdown logic (if needed in the future) - # Any cleanup code would go here - - -app = FastAPI( - title="Smart Turn API", - description="A FastAPI application exposing LocalSmartTurnAnalyzer via HTTP", - lifespan=lifespan, -) - -# ---------------------------------------------------------------------------- -# API Endpoints -# ---------------------------------------------------------------------------- - - -async def save_wav_file( - audio_array: np.ndarray, - prediction: int, - probability: float, - service_id: str | None = None, - sample_rate: int = 16000, -) -> None: - """Save audio data as a WAV file in the background. - - Runs the blocking ``wavfile.write`` call in a thread so that the event loop - is not blocked. This function is now ``async`` so it can be scheduled with - ``asyncio.create_task`` from the WebSocket endpoint, while still being - compatible with ``BackgroundTasks`` (which will ``await`` coroutine - functions). - - Args: - audio_array: The audio data as a numpy array - prediction: The prediction result (0 or 1) - probability: The probability of the prediction - service_id: Optional service identifier - sample_rate: The sample rate of the audio (default: 16000 Hz) - """ - - def _blocking_save() -> None: - try: - # Generate filename with current timestamp and prediction - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Include ms - - # Include service_id in filename if available - service_prefix = f"{service_id}_" if service_id else "" - - root_dir = ( - Path(__file__).resolve().parents[3] - ) # dograh/api/services/smart_turn/app.py - filename = ( - root_dir - / f"smart_turn_pipeline/{service_prefix}{timestamp}_{prediction}_{probability}.wav" - ) - - # Convert float32 [-1, 1] back to int16 PCM for WAV file - audio_int16 = np.clip(audio_array * 32767, -32768, 32767).astype(np.int16) - - # Use provided sample rate - wavfile.write(filename, sample_rate, audio_int16) - - length_seconds = len(audio_array) / sample_rate - log_message = f"Saved audio to {filename} (length: {length_seconds:.2f}s, prediction: {prediction}" - if service_id: - log_message += f", service_id: {service_id}" - log_message += ")" - - logger.info(log_message) - - except Exception as exc: # pragma: no cover – best-effort logging only - log_message = f"Failed to save WAV file: {exc}" - if service_id: - log_message += f" (service_id: {service_id})" - logger.error(log_message) - - # Offload the blocking I/O to a thread to avoid blocking the event loop - await asyncio.to_thread(_blocking_save) - - -@app.post("/raw", status_code=status.HTTP_200_OK) -async def handle_raw(request: Request, background_tasks: BackgroundTasks): - """ - Accept a NumPy-serialized float32 array (written via ``np.save``) in the body and - return a JSON prediction compatible with ``HttpSmartTurnAnalyzer``. - """ - - # ------------------------------------------------------------------ - # Secret key validation - # ------------------------------------------------------------------ - expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY") - if expected_secret: # If a secret is configured, enforce validation - provided_secret = request.headers.get("X-API-Key") - if provided_secret != expected_secret: - logger.warning( - "Unauthorized access attempt to /raw endpoint with invalid or missing secret key" - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Unauthorized", - ) - - # ------------------------------------------------------------------ - # Start total-time measurement as early as possible - # ------------------------------------------------------------------ - request_start_time = time.perf_counter() - - # ------------------------------------------------------------------ - # Log that we received a request (before doing any heavy work) - # ------------------------------------------------------------------ - logger.debug("Received /raw request") - - body = await request.body() - if not body: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Empty request body" - ) - - # Extract service context and sample rate from headers - service_id = request.headers.get("X-Service-Context") - sample_rate_str = request.headers.get("X-Sample-Rate") - sample_rate = int(sample_rate_str) if sample_rate_str else 16000 - - # Deserialize NumPy array - try: - audio_array = np.load(io.BytesIO(body)) - except Exception as exc: - error_msg = f"Invalid NumPy payload: {exc}" - if service_id: - error_msg += f" (service_id: {service_id})" - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=error_msg, - ) - - wrapper = _analyzer_wrapper - if wrapper is None: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Analyzer not initialized", - ) - - # Run inference guarded by the wrapper lock so the model isn't used concurrently - log_msg = "Going to acquire lock for model inference" - if service_id: - log_msg += f" (service_id: {service_id})" - logger.debug(log_msg) - - async with wrapper.lock: - log_msg = "Acquired lock for model inference" - if service_id: - log_msg += f" (service_id: {service_id})" - logger.debug(log_msg) - - # Measure inference-only latency - inference_start_time = time.perf_counter() - result = await wrapper.analyzer._predict_endpoint(audio_array) - inference_time = time.perf_counter() - inference_start_time - - # Calculate total processing time (from request receipt to response preparation) - total_time = time.perf_counter() - request_start_time - - log_msg = ( - f"Inference done result: {result['prediction']} " - f"probability: {result['probability']} time taken: {inference_time:.2f}s total: {total_time:.2f}s" - ) - if service_id: - log_msg += f" (service_id: {service_id})" - logger.debug(log_msg) - - # Ensure metrics section exists so client code can parse it consistently - metrics = result.get("metrics", {}) - # Overwrite / set the timing metrics explicitly - metrics["inference_time"] = inference_time - metrics["total_time"] = total_time - result["metrics"] = metrics - - logger.debug(f"Result for service_id: {service_id} is: {result}") - - # Add service_id to result for potential client use - if service_id: - result["service_id"] = service_id - - # Persist audio in background so it doesn't block the response. - background_tasks.add_task( - save_wav_file, - audio_array, - result.get("prediction", 0), - result.get("probability", 0), - service_id, - sample_rate, - ) - return result - - -@app.get("/") -async def root(): - """Health-check endpoint.""" - return {"message": "Smart Turn API is running"} - - -# ---------------------------------------------------------------------------- -# WebSocket endpoint -# ---------------------------------------------------------------------------- - - -@app.websocket("/ws") -async def websocket_endpoint(ws: WebSocket): - """Handle streaming Smart Turn requests over WebSocket. - - Each incoming binary message must be a NumPy-serialized float32 array (as - produced by ``np.save``). A JSON-formatted prediction (identical to the - ``/raw`` HTTP endpoint) is sent back as a text message. - """ - - # Extract optional secret key from headers (during handshake) - expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY") - if expected_secret: - provided_secret = ws.headers.get("X-API-Key") - if provided_secret != expected_secret: - await ws.close(code=4401, reason="Unauthorized") - return - - # Accept the websocket connection and log it - await ws.accept() - - service_id = ws.headers.get("X-Service-Context") - sample_rate_str = ws.headers.get("X-Sample-Rate") - sample_rate = int(sample_rate_str) if sample_rate_str else 16000 - logger.debug( - f"WebSocket connection accepted from service_id: {service_id}, sample_rate: {sample_rate}" - ) - - # ------------------------------------------------------------------ - # Tunables – consider moving to env vars for ops control - # ------------------------------------------------------------------ - connection_timeout = 120.0 # Seconds of inactivity before timing out - MAX_BINARY_SIZE = int( - os.getenv("SMART_TURN_MAX_PAYLOAD", 10 * 1024 * 1024) # 10MB max message size - ) - - # Track background tasks so we can cancel them on disconnect - background_tasks = set() # Track background tasks for cleanup - - try: - logger.debug("Entering WebSocket message loop") - while True: - data = None # Initialize data for each iteration - try: - logger.debug("Waiting for WebSocket message…") - - # Create receive task to handle timeout properly - receive_task = asyncio.create_task(ws.receive()) - try: - msg = await asyncio.wait_for( - receive_task, timeout=connection_timeout - ) - except asyncio.TimeoutError: - # Cancel the receive task to prevent it from running in background - receive_task.cancel() - try: - await receive_task - except asyncio.CancelledError: - pass - - logger.warning( - f"WebSocket connection timeout for service_id: {service_id}" - ) - try: - await ws.close(code=1001, reason="Connection timeout") - except Exception as e: - logger.debug(f"Error closing WebSocket after timeout: {e}") - break - except WebSocketDisconnect as e: - logger.debug(f"WebSocket client disconnected: {e}") - break - - # Validate message structure - if not isinstance(msg, dict): - logger.error(f"Unexpected message type: {type(msg)}") - break - - # Handle disconnect message explicitly - if msg.get("type") == "websocket.disconnect": - logger.debug("Client sent disconnect frame") - break - - data = None - # Binary frame - if "bytes" in msg and msg["bytes"] is not None: - data = msg["bytes"] - logger.debug( - "Received WebSocket audio payload (%d bytes)", len(data) - ) - - except WebSocketDisconnect as e: - logger.debug(f"WebSocket client disconnected: {e}") - break - except Exception as e: - logger.error(f"Error in WebSocket loop: {e}") - break - - if data is None: - continue - - request_start_time = time.perf_counter() - - # -------------------------------------------------------------- - # Basic validation & secure deserialisation - # -------------------------------------------------------------- - if len(data) > MAX_BINARY_SIZE: - logger.warning("Received payload exceeding maximum allowed size") - await ws.send_text('{"error": "Payload too large"}') - continue - - # Deserialize NumPy array (pickle disabled for security) - try: - audio_array = np.load(io.BytesIO(data), allow_pickle=False) - except Exception as exc: - error_msg = f"Invalid NumPy payload: {exc}" - if service_id: - error_msg += f" (service_id: {service_id})" - # Send error response with proper error handling - if ws.application_state == WebSocketState.CONNECTED: - try: - await ws.send_text(f'{{"error": "{error_msg}"}}') - except Exception as e: - logger.error(f"Failed to send error message: {e}") - continue - - wrapper = _analyzer_wrapper - if wrapper is None: - logger.error("Analyzer not initialized; closing connection") - if ws.application_state == WebSocketState.CONNECTED: - await ws.close(code=1011, reason="Analyzer not ready") - break - - async with wrapper.lock: - inference_start_time = time.perf_counter() - result = await wrapper.analyzer._predict_endpoint(audio_array) - inference_time = time.perf_counter() - inference_start_time - - # Timing metrics - total_time = time.perf_counter() - request_start_time - metrics = result.get("metrics", {}) - metrics["inference_time"] = inference_time - metrics["total_time"] = total_time - result["metrics"] = metrics - - logger.debug(f"Result for service_id: {service_id} is: {result}") - - if service_id: - result["service_id"] = service_id - - # Send result with proper error handling - try: - if ws.application_state == WebSocketState.CONNECTED: - await ws.send_text(json.dumps(result)) - else: - logger.warning( - f"Cannot send result - WebSocket not connected for service_id: {service_id}" - ) - break - except WebSocketDisconnect: - logger.debug( - f"Client disconnected while sending result for service_id: {service_id}" - ) - break - except Exception as e: - logger.error(f"Failed to send result: {e}") - break - - # Save audio in the background so that it doesn't block streaming - task = asyncio.create_task( - save_wav_file( - audio_array, - result.get("prediction", 0), - result.get("probability", 0), - service_id, - sample_rate, - ) - ) - # Track task and remove when done - background_tasks.add(task) - task.add_done_callback(background_tasks.discard) - - except WebSocketException as exc: - logger.error(f"WebSocket error: {exc}") - finally: - # Cancel any remaining background tasks - for task in background_tasks: - if not task.done(): - task.cancel() - # Wait for all background tasks to complete or be cancelled - if background_tasks: - await asyncio.gather(*background_tasks, return_exceptions=True) - - # Attempt a graceful close if it's not already closed - if ws.application_state == WebSocketState.CONNECTED: - try: - await ws.close() - except Exception as exc: - # Socket is probably already closed; log and ignore - logger.debug(f"WebSocket already closed: {exc}") diff --git a/api/services/smart_turn/websocket_smart_turn.py b/api/services/smart_turn/websocket_smart_turn.py deleted file mode 100644 index 82a7e6f..0000000 --- a/api/services/smart_turn/websocket_smart_turn.py +++ /dev/null @@ -1,314 +0,0 @@ -"""Smart-Turn analyzer that talks to a FastAPI WebSocket endpoint. - -This analyzer keeps a persistent WebSocket connection alive so that the TCP/TLS -handshake and HTTP upgrade happen only once per call session. Each speech -segment is sent as a single binary message containing the NumPy-serialized -float32 array, and a JSON reply is expected in return. - -Rewritten to use the websockets library for simplified connection management. -""" - -from __future__ import annotations - -import asyncio -import io -import json -import random -import time -from typing import Any, Dict, Optional - -import numpy as np -import websockets -from loguru import logger -from pipecat.audio.turn.smart_turn.base_smart_turn import ( - BaseSmartTurn, - SmartTurnTimeoutException, -) - - -class WebSocketSmartTurnAnalyzer(BaseSmartTurn): - """End-of-turn analyzer that sends audio via a persistent WebSocket.""" - - def __init__( - self, - *, - url: str, - headers: Optional[Dict[str, str]] = None, - service_context: Optional[Any] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self._url = url.rstrip("/") - self._headers = headers or {} - self._service_context = service_context - - # WebSocket connection - self._ws: Optional[websockets.WebSocketClientProtocol] = None - self._ws_lock = asyncio.Lock() - - # Connection management - self._connection_task: Optional[asyncio.Task] = None - self._reconnect_delay = 1.0 - self._max_reconnect_delay = 30.0 - self._closing = False - self._connection_closed_event = asyncio.Event() - - # Connection health monitoring - self._last_successful_request = 0.0 - self._connection_attempts = 0 - - # Start connection manager in background - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - self._connection_task = loop.create_task(self._connection_manager()) - except RuntimeError: - logger.debug( - "No running loop at object creation time. Connection will be opened lazily on first use." - ) - - def _serialize_array(self, audio_array: np.ndarray) -> bytes: - """Serialize numpy array to bytes.""" - buffer = io.BytesIO() - np.save(buffer, audio_array) - return buffer.getvalue() - - async def _connection_manager(self) -> None: - """Manages WebSocket connection lifecycle with automatic reconnection.""" - while not self._closing: - try: - # Establish connection - await self._establish_connection() - - # Reset reconnect delay on successful connection - self._reconnect_delay = 1.0 - self._connection_attempts = 0 - - # Wait for connection close event - self._connection_closed_event.clear() - await self._connection_closed_event.wait() - - logger.debug("WebSocket connection closed") - - except Exception as e: - logger.error(f"Connection manager error: {e}") - - finally: - # Clean up connection - if self._ws: - try: - await self._ws.close() - except: - pass - self._ws = None - - if not self._closing: - # Exponential backoff for reconnection - self._connection_attempts += 1 - delay = min( - self._reconnect_delay - * (2 ** min(self._connection_attempts - 1, 5)), - self._max_reconnect_delay, - ) - # Add jitter to avoid thundering herd - delay += random.uniform(0, 0.5) - logger.info( - f"Reconnecting in {delay:.1f} seconds (attempt {self._connection_attempts})" - ) - await asyncio.sleep(delay) - - async def _establish_connection(self) -> None: - """Establish a new WebSocket connection with retry logic.""" - logger.debug("Establishing new WebSocket connection to Smart-Turn service...") - - # Prepare headers - additional_headers = dict(self._headers) - if self._service_context is not None: - additional_headers["X-Service-Context"] = str(self._service_context) - - # _init_sample_rate is being set in the constructor, which we should - # use in case self._sample_rate is not set yet. The actual _sample_rate - # is being set in the set_sample_rate() method - # but in case of WebSocketSmartTurnAnalyzer, we establish the websocket connection - # during __init__() and won't see the set_sample_rate until later. So, lets - # user the _init_sample_rate instead - _sample_rate = self._sample_rate or self._init_sample_rate - - if _sample_rate > 0: - additional_headers["X-Sample-Rate"] = str(_sample_rate) - - max_attempts = 3 - for attempt in range(max_attempts): - try: - # Add jitter to prevent thundering herd - if attempt > 0: - jitter = 0.1 * attempt - await asyncio.sleep(jitter) - - # Connect with websockets library - self._ws = await websockets.connect( - self._url, - additional_headers=additional_headers, - ping_interval=5.0, # let websockets send pings every 5s - ping_timeout=3.0, # fail fast if no pong in 3s - close_timeout=10, - max_size=10 * 1024 * 1024, # 10MB max message size - ) - - logger.info("WebSocket connection established successfully") - return - - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning( - f"Failed to establish WebSocket (attempt {attempt + 1}/{max_attempts}): {exc}" - ) - if attempt == max_attempts - 1: - raise - await asyncio.sleep(0.5 * (attempt + 1)) - - async def _ensure_ws(self) -> websockets.WebSocketClientProtocol: - """Return a connected WebSocket, waiting for connection if necessary.""" - async with self._ws_lock: - # If connection manager isn't running, start it - if not self._connection_task or self._connection_task.done(): - self._connection_task = asyncio.create_task(self._connection_manager()) - - # Wait for connection with timeout - start_time = time.time() - max_wait_time = 10.0 - - while not self._closing: - if self._ws: - return self._ws - - elapsed = time.time() - start_time - if elapsed > max_wait_time: - raise Exception( - f"Timeout waiting for WebSocket connection after {max_wait_time}s" - ) - - await asyncio.sleep(0.1) - - if self._closing: - raise Exception("Analyzer is closing") - - raise Exception("Failed to establish WebSocket connection") - - async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]: - """Send audio and await JSON response via WebSocket.""" - data_bytes = self._serialize_array(audio_array) - - try: - # Ensure we have a connection - ws = await self._ensure_ws() - - # Send data - try: - await ws.send(data_bytes) - except Exception as e: - logger.error(f"Failed to send data: {e}") - self._connection_closed_event.set() - return { - "prediction": 0, - "probability": 0.0, - "metrics": {"inference_time": 0.0, "total_time": 0.0}, - } - - # Wait for response - start_time = time.time() - while True: - remaining_timeout = self._params.stop_secs - (time.time() - start_time) - if remaining_timeout <= 0: - raise SmartTurnTimeoutException( - f"Request exceeded {self._params.stop_secs} seconds." - ) - - try: - # Receive message with timeout - message = await asyncio.wait_for( - ws.recv(), timeout=min(remaining_timeout, 0.5) - ) - - # Handle text messages (JSON responses) - if isinstance(message, str): - try: - result = json.loads(message) - - # Skip ping/pong messages - if result.get("type") in ["ping", "pong"]: - continue - - # Validate prediction response - if "prediction" not in result: - if "type" in result: - continue - else: - logger.error( - "Invalid response format from Smart-Turn service" - ) - return { - "prediction": 0, - "probability": 0.0, - "metrics": { - "inference_time": 0.0, - "total_time": 0.0, - }, - } - - self._last_successful_request = time.time() - return result - - except json.JSONDecodeError as exc: - logger.error( - f"Smart turn service returned invalid JSON: {exc}" - ) - raise - else: - logger.error(f"Unexpected message type: {type(message)}") - - except asyncio.TimeoutError: - continue - except websockets.exceptions.ConnectionClosed: - logger.warning("WebSocket connection closed during prediction") - self._connection_closed_event.set() - return { - "prediction": 0, - "probability": 0.0, - "metrics": {"inference_time": 0.0, "total_time": 0.0}, - } - - except SmartTurnTimeoutException: - raise - except Exception as exc: - logger.error(f"Smart turn prediction failed over WebSocket: {exc}") - self._connection_closed_event.set() - return { - "prediction": 0, - "probability": 0.0, - "metrics": {"inference_time": 0.0, "total_time": 0.0}, - } - - async def close(self): - """Asynchronously close the WebSocket.""" - self._closing = True - self._connection_closed_event.set() - - async with self._ws_lock: - # Cancel tasks - if self._connection_task and not self._connection_task.done(): - self._connection_task.cancel() - try: - await self._connection_task - except asyncio.CancelledError: - pass - - # Close WebSocket - if self._ws: - try: - await self._ws.close() - except: - pass - finally: - self._ws = None