chore: remove unused smart_turn service

Fixes #323, #324, #325.
This commit is contained in:
Abhishek Kumar 2026-05-27 09:49:36 +00:00
parent 573dd68d76
commit 6d78537297
3 changed files with 0 additions and 795 deletions

View file

@ -1,3 +0,0 @@
from .websocket_smart_turn import WebSocketSmartTurnAnalyzer
__all__ = ["WebSocketSmartTurnAnalyzer"]

View file

@ -1,478 +0,0 @@
import asyncio
import io
import json
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
import numpy as np
from fastapi import (
BackgroundTasks,
FastAPI,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
WebSocketException,
status,
)
from fastapi.websockets import WebSocketState
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
from scipy.io import wavfile
# Only the literal value "debug" (case-insensitive) selects DEBUG; any other
# value of the LOG_LEVEL environment variable falls back to INFO.
if os.environ.get("LOG_LEVEL", "DEBUG").lower() == "debug":
    LOG_LEVEL = logging.DEBUG
else:
    LOG_LEVEL = logging.INFO

# Service logger: formatted records are written to stdout.
logger = logging.getLogger("smart_turn")
logger.setLevel(LOG_LEVEL)

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter(fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)

# ----------------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------------
# Model identifier/path handed to the analyzer; overridable via environment.
MODEL_PATH = os.environ.get(
    "LOCAL_SMART_TURN_MODEL_PATH", "pipecat-ai/smart-turn-v2"
)
# ----------------------------------------------------------------------------
# Analyzer Pool
# ----------------------------------------------------------------------------
class _AnalyzerWrapper:
    """Pair a smart-turn analyzer with an ``asyncio.Lock``.

    The request handlers in this module acquire ``lock`` before calling the
    analyzer, so at most one inference runs on it at a time.
    """

    def __init__(self, analyzer: LocalSmartTurnAnalyzerV2):
        # Guards every use of ``analyzer``.
        self.lock = asyncio.Lock()
        self.analyzer = analyzer


# Module-wide wrapper instance; stays ``None`` until the lifespan hook sets it.
_analyzer_wrapper: _AnalyzerWrapper | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared analyzer wrapper on startup, then run the app.

    The analyzer is built only if the module-level wrapper has not been set
    yet. Nothing is torn down after the ``yield``.
    """
    global _analyzer_wrapper

    already_loaded = _analyzer_wrapper is not None
    if not already_loaded:
        logger.debug("Initializing LocalSmartTurnAnalyzer")
        _analyzer_wrapper = _AnalyzerWrapper(
            LocalSmartTurnAnalyzerV2(smart_turn_model_path=MODEL_PATH)
        )
        logger.debug("LocalSmartTurnAnalyzer initialized")

    # The application serves requests while suspended here.
    yield
    # No shutdown work at present; cleanup code would follow the yield.


app = FastAPI(
    lifespan=lifespan,
    title="Smart Turn API",
    description="A FastAPI application exposing LocalSmartTurnAnalyzer via HTTP",
)
# ----------------------------------------------------------------------------
# API Endpoints
# ----------------------------------------------------------------------------
def _sanitize_filename_component(value: str) -> str:
    """Return *value* with every character except alphanumerics, ``-``, ``_`` and ``.`` replaced by ``_``."""
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in value)


async def save_wav_file(
    audio_array: np.ndarray,
    prediction: int,
    probability: float,
    service_id: str | None = None,
    sample_rate: int = 16000,
) -> None:
    """Save audio data as a WAV file without blocking the event loop.

    The blocking ``wavfile.write`` call runs in a worker thread via
    ``asyncio.to_thread``. Because this is a coroutine function it can be
    scheduled with ``asyncio.create_task`` (WebSocket endpoint) or handed to
    FastAPI ``BackgroundTasks`` (HTTP endpoint).

    Saving is best-effort: any exception is logged and swallowed, so this
    coroutine never raises because of a failed write.

    Args:
        audio_array: The audio samples. They are scaled by 32767 and clipped
            to int16, i.e. presumably float samples in [-1, 1] (confirm with
            the callers).
        prediction: The prediction result (0 or 1); embedded in the filename.
        probability: The probability of the prediction; embedded in the
            filename.
        service_id: Optional service identifier. The request handlers take it
            from a client-supplied header, so it is sanitised before being
            used as part of the filename.
        sample_rate: The sample rate of the audio (default: 16000 Hz).
    """

    def _blocking_save() -> None:
        try:
            # Timestamp with millisecond resolution (microseconds truncated).
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]

            # ``service_id`` is untrusted input: strip path separators and
            # other special characters so it cannot redirect the write
            # outside the target directory (e.g. via "../").
            service_prefix = (
                f"{_sanitize_filename_component(service_id)}_" if service_id else ""
            )

            # Three directories above the one containing this file
            # (dograh/api/services/smart_turn/app.py per the original note).
            root_dir = Path(__file__).resolve().parents[3]
            # NOTE(review): the target directory is not created here; if it
            # does not exist the write fails and is only logged below.
            filename = (
                root_dir
                / "smart_turn_pipeline"
                / f"{service_prefix}{timestamp}_{prediction}_{probability}.wav"
            )

            # Convert to int16 PCM for the WAV file.
            audio_int16 = np.clip(audio_array * 32767, -32768, 32767).astype(np.int16)
            wavfile.write(filename, sample_rate, audio_int16)

            length_seconds = len(audio_array) / sample_rate
            log_message = f"Saved audio to {filename} (length: {length_seconds:.2f}s, prediction: {prediction}"
            if service_id:
                log_message += f", service_id: {service_id}"
            log_message += ")"
            logger.info(log_message)
        except Exception as exc:  # pragma: no cover - best-effort logging only
            log_message = f"Failed to save WAV file: {exc}"
            if service_id:
                log_message += f" (service_id: {service_id})"
            logger.error(log_message)

    # Offload the blocking I/O to a thread to avoid blocking the event loop.
    await asyncio.to_thread(_blocking_save)
@app.post("/raw", status_code=status.HTTP_200_OK)
async def handle_raw(request: Request, background_tasks: BackgroundTasks):
    """Run smart-turn inference on a NumPy payload sent over HTTP.

    The request body must be an array serialized with ``np.save``. The JSON
    response is the analyzer's result dict with ``metrics.inference_time`` and
    ``metrics.total_time`` set, plus ``service_id`` when the client sent an
    ``X-Service-Context`` header.

    Headers:
        X-API-Key: required only when ``SMART_TURN_HTTP_SERVICE_KEY`` is set.
        X-Service-Context: optional identifier echoed in logs and the result.
        X-Sample-Rate: optional positive integer (default 16000); used only
            when persisting the audio.

    Raises:
        HTTPException: 401 for a missing/incorrect key, 400 for an empty body,
            an invalid ``X-Sample-Rate`` or an undecodable payload, 503 when
            the analyzer has not been initialised.
    """
    import hmac  # local import: needed only for the constant-time key check

    # ------------------------------------------------------------------
    # Secret key validation (enforced only when a secret is configured)
    # ------------------------------------------------------------------
    expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
    if expected_secret:
        provided_secret = request.headers.get("X-API-Key")
        # compare_digest avoids leaking the key through comparison timing.
        if provided_secret is None or not hmac.compare_digest(
            provided_secret.encode("utf-8"), expected_secret.encode("utf-8")
        ):
            logger.warning(
                "Unauthorized access attempt to /raw endpoint with invalid or missing secret key"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Unauthorized",
            )

    # Start total-time measurement before any heavy work.
    request_start_time = time.perf_counter()
    logger.debug("Received /raw request")

    body = await request.body()
    if not body:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Empty request body"
        )

    # Extract service context and sample rate from headers.
    service_id = request.headers.get("X-Service-Context")
    service_suffix = f" (service_id: {service_id})" if service_id else ""

    sample_rate_str = request.headers.get("X-Sample-Rate")
    try:
        sample_rate = int(sample_rate_str) if sample_rate_str else 16000
    except ValueError:
        sample_rate = 0
    if sample_rate <= 0:
        # A malformed header is a client error, not a server failure.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid X-Sample-Rate header{service_suffix}",
        )

    # Deserialize the NumPy array; pickle stays disabled because the payload
    # is untrusted (same policy as the WebSocket endpoint).
    try:
        audio_array = np.load(io.BytesIO(body), allow_pickle=False)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid NumPy payload: {exc}{service_suffix}",
        )

    wrapper = _analyzer_wrapper
    if wrapper is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Analyzer not initialized",
        )

    # Run inference under the wrapper lock so the model isn't used concurrently.
    logger.debug(f"Going to acquire lock for model inference{service_suffix}")
    async with wrapper.lock:
        logger.debug(f"Acquired lock for model inference{service_suffix}")
        # Measure inference-only latency.
        inference_start_time = time.perf_counter()
        result = await wrapper.analyzer._predict_endpoint(audio_array)
        inference_time = time.perf_counter() - inference_start_time

    # Total processing time, from request receipt to response preparation.
    total_time = time.perf_counter() - request_start_time
    logger.debug(
        f"Inference done result: {result.get('prediction')} "
        f"probability: {result.get('probability')} time taken: {inference_time:.2f}s total: {total_time:.2f}s"
        f"{service_suffix}"
    )

    # Ensure the metrics section exists and carries the timings measured here.
    metrics = result.get("metrics", {})
    metrics["inference_time"] = inference_time
    metrics["total_time"] = total_time
    result["metrics"] = metrics
    logger.debug(f"Result for service_id: {service_id} is: {result}")

    # Add service_id to the result for potential client use.
    if service_id:
        result["service_id"] = service_id

    # Persist audio in the background so it doesn't block the response.
    background_tasks.add_task(
        save_wav_file,
        audio_array,
        result.get("prediction", 0),
        result.get("probability", 0),
        service_id,
        sample_rate,
    )
    return result
@app.get("/")
async def root():
    """Liveness probe: reply with a fixed status message."""
    payload = {"message": "Smart Turn API is running"}
    return payload
# ----------------------------------------------------------------------------
# WebSocket endpoint
# ----------------------------------------------------------------------------
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Handle streaming Smart Turn requests over WebSocket.

    Each incoming binary message must be a NumPy-serialized array (as produced
    by ``np.save``). A JSON-formatted prediction (same shape as the ``/raw``
    HTTP endpoint) is sent back as a text message. Malformed or oversized
    payloads get a ``{"error": ...}`` text reply and the connection stays open.

    Handshake headers:
        X-API-Key: required only when ``SMART_TURN_HTTP_SERVICE_KEY`` is set;
            a mismatch closes the socket with code 4401.
        X-Service-Context: optional identifier used in logs and results.
        X-Sample-Rate: optional positive integer (default 16000); an invalid
            value closes the socket with code 1008.

    The connection is closed with code 1001 after 120 seconds without a
    message, and with code 1011 if the analyzer is not initialised.
    """
    import hmac  # local import: needed only for the constant-time key check

    # Authenticate before accepting the connection.
    expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
    if expected_secret:
        provided_secret = ws.headers.get("X-API-Key")
        if provided_secret is None or not hmac.compare_digest(
            provided_secret.encode("utf-8"), expected_secret.encode("utf-8")
        ):
            await ws.close(code=4401, reason="Unauthorized")
            return

    service_id = ws.headers.get("X-Service-Context")

    # Validate the sample rate during the handshake so that a malformed header
    # is rejected cleanly instead of crashing the handler after accept().
    sample_rate_str = ws.headers.get("X-Sample-Rate")
    try:
        sample_rate = int(sample_rate_str) if sample_rate_str else 16000
    except ValueError:
        sample_rate = 0
    if sample_rate <= 0:
        logger.warning(
            f"Rejecting WebSocket with invalid X-Sample-Rate for service_id: {service_id}"
        )
        await ws.close(code=1008, reason="Invalid X-Sample-Rate header")
        return

    # Accept the websocket connection and log it.
    await ws.accept()
    logger.debug(
        f"WebSocket connection accepted from service_id: {service_id}, sample_rate: {sample_rate}"
    )

    # ------------------------------------------------------------------
    # Tunables - consider moving to env vars for ops control
    # ------------------------------------------------------------------
    connection_timeout = 120.0  # Seconds of inactivity before timing out
    max_binary_size = int(
        os.getenv("SMART_TURN_MAX_PAYLOAD", 10 * 1024 * 1024)  # 10MB default
    )

    # Pending WAV-save tasks, tracked so they can be cancelled on disconnect.
    background_tasks = set()

    try:
        logger.debug("Entering WebSocket message loop")
        while True:
            data = None  # Binary payload of this iteration, if any
            try:
                logger.debug("Waiting for WebSocket message…")
                try:
                    # wait_for cancels the pending receive itself on timeout.
                    msg = await asyncio.wait_for(
                        ws.receive(), timeout=connection_timeout
                    )
                except asyncio.TimeoutError:
                    logger.warning(
                        f"WebSocket connection timeout for service_id: {service_id}"
                    )
                    try:
                        await ws.close(code=1001, reason="Connection timeout")
                    except Exception as e:
                        logger.debug(f"Error closing WebSocket after timeout: {e}")
                    break
                except WebSocketDisconnect as e:
                    logger.debug(f"WebSocket client disconnected: {e}")
                    break

                # Validate message structure.
                if not isinstance(msg, dict):
                    logger.error(f"Unexpected message type: {type(msg)}")
                    break

                # Handle disconnect message explicitly.
                if msg.get("type") == "websocket.disconnect":
                    logger.debug("Client sent disconnect frame")
                    break

                # Only binary frames carry audio; anything else is skipped.
                if msg.get("bytes") is not None:
                    data = msg["bytes"]
                    logger.debug(
                        "Received WebSocket audio payload (%d bytes)", len(data)
                    )
            except WebSocketDisconnect as e:
                logger.debug(f"WebSocket client disconnected: {e}")
                break
            except Exception as e:
                logger.error(f"Error in WebSocket loop: {e}")
                break

            if data is None:
                continue

            request_start_time = time.perf_counter()

            # --------------------------------------------------------------
            # Basic validation & secure deserialisation
            # --------------------------------------------------------------
            if len(data) > max_binary_size:
                logger.warning("Received payload exceeding maximum allowed size")
                await ws.send_text('{"error": "Payload too large"}')
                continue

            # Deserialize NumPy array (pickle disabled for security).
            try:
                audio_array = np.load(io.BytesIO(data), allow_pickle=False)
            except Exception as exc:
                error_msg = f"Invalid NumPy payload: {exc}"
                if service_id:
                    error_msg += f" (service_id: {service_id})"
                if ws.application_state == WebSocketState.CONNECTED:
                    try:
                        # json.dumps escapes quotes/backslashes in the
                        # exception text, keeping the reply valid JSON.
                        await ws.send_text(json.dumps({"error": error_msg}))
                    except Exception as e:
                        logger.error(f"Failed to send error message: {e}")
                continue

            wrapper = _analyzer_wrapper
            if wrapper is None:
                logger.error("Analyzer not initialized; closing connection")
                if ws.application_state == WebSocketState.CONNECTED:
                    await ws.close(code=1011, reason="Analyzer not ready")
                break

            # One inference at a time on the shared analyzer.
            async with wrapper.lock:
                inference_start_time = time.perf_counter()
                result = await wrapper.analyzer._predict_endpoint(audio_array)
                inference_time = time.perf_counter() - inference_start_time

            # Timing metrics.
            total_time = time.perf_counter() - request_start_time
            metrics = result.get("metrics", {})
            metrics["inference_time"] = inference_time
            metrics["total_time"] = total_time
            result["metrics"] = metrics
            logger.debug(f"Result for service_id: {service_id} is: {result}")
            if service_id:
                result["service_id"] = service_id

            # Send the result; any failure here ends the session.
            try:
                if ws.application_state == WebSocketState.CONNECTED:
                    await ws.send_text(json.dumps(result))
                else:
                    logger.warning(
                        f"Cannot send result - WebSocket not connected for service_id: {service_id}"
                    )
                    break
            except WebSocketDisconnect:
                logger.debug(
                    f"Client disconnected while sending result for service_id: {service_id}"
                )
                break
            except Exception as e:
                logger.error(f"Failed to send result: {e}")
                break

            # Save audio in the background so that it doesn't block streaming.
            task = asyncio.create_task(
                save_wav_file(
                    audio_array,
                    result.get("prediction", 0),
                    result.get("probability", 0),
                    service_id,
                    sample_rate,
                )
            )
            # Track the task and drop it from the set once it finishes.
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)
    except WebSocketException as exc:
        logger.error(f"WebSocket error: {exc}")
    finally:
        # Cancel any remaining background tasks.
        for task in background_tasks:
            if not task.done():
                task.cancel()
        # Wait for all background tasks to complete or be cancelled.
        if background_tasks:
            await asyncio.gather(*background_tasks, return_exceptions=True)
        # Attempt a graceful close if it's not already closed.
        if ws.application_state == WebSocketState.CONNECTED:
            try:
                await ws.close()
            except Exception as exc:
                # Socket is probably already closed; log and ignore.
                logger.debug(f"WebSocket already closed: {exc}")

View file

@ -1,314 +0,0 @@
"""Smart-Turn analyzer that talks to a FastAPI WebSocket endpoint.
This analyzer keeps a persistent WebSocket connection alive so that the TCP/TLS
handshake and HTTP upgrade happen only once per call session. Each speech
segment is sent as a single binary message containing the NumPy-serialized
float32 array, and a JSON reply is expected in return.
Rewritten to use the websockets library for simplified connection management.
"""
from __future__ import annotations
import asyncio
import io
import json
import random
import time
from typing import Any, Dict, Optional
import numpy as np
import websockets
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import (
BaseSmartTurn,
SmartTurnTimeoutException,
)
class WebSocketSmartTurnAnalyzer(BaseSmartTurn):
    """End-of-turn analyzer that sends audio via a persistent WebSocket.

    A background task (``_connection_manager``) owns the connection: it
    connects, waits until ``_connection_closed_event`` is set (by a failed
    request or by ``close()``), tears the socket down and reconnects with
    exponential back-off. ``_predict_endpoint`` sends one binary message per
    audio segment and waits for a JSON text reply.
    """

    def __init__(
        self,
        *,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        service_context: Optional[Any] = None,
        **kwargs,
    ) -> None:
        """Store connection settings and start connecting if a loop is running.

        Args:
            url: WebSocket URL of the Smart-Turn service (trailing ``/`` removed).
            headers: Extra handshake headers.
            service_context: Optional value sent as ``X-Service-Context``.
            **kwargs: Forwarded to ``BaseSmartTurn.__init__``.
        """
        super().__init__(**kwargs)
        self._url = url.rstrip("/")
        self._headers = headers or {}
        self._service_context = service_context

        # WebSocket connection
        self._ws: Optional[websockets.WebSocketClientProtocol] = None
        self._ws_lock = asyncio.Lock()

        # Connection management
        self._connection_task: Optional[asyncio.Task] = None
        self._reconnect_delay = 1.0
        self._max_reconnect_delay = 30.0
        self._closing = False
        self._connection_closed_event = asyncio.Event()

        # Connection health monitoring
        self._last_successful_request = 0.0
        self._connection_attempts = 0

        # Start the connection manager right away when constructed inside a
        # running loop. get_running_loop() raises RuntimeError otherwise,
        # unlike the deprecated get_event_loop(), which may create a loop.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            logger.debug(
                "No running loop at object creation time. Connection will be opened lazily on first use."
            )
        else:
            self._connection_task = loop.create_task(self._connection_manager())

    @staticmethod
    def _fallback_result() -> Dict[str, Any]:
        """Return the neutral result used whenever a request cannot complete."""
        return {
            "prediction": 0,
            "probability": 0.0,
            "metrics": {"inference_time": 0.0, "total_time": 0.0},
        }

    def _serialize_array(self, audio_array: np.ndarray) -> bytes:
        """Serialize a numpy array to ``np.save`` bytes."""
        buffer = io.BytesIO()
        np.save(buffer, audio_array)
        return buffer.getvalue()

    async def _connection_manager(self) -> None:
        """Manage the WebSocket lifecycle with automatic reconnection.

        Runs until ``self._closing`` is set. Cancellation propagates
        immediately after the socket has been cleaned up.
        """
        while not self._closing:
            try:
                await self._establish_connection()

                # Reset back-off state on successful connection.
                self._reconnect_delay = 1.0
                self._connection_attempts = 0

                # Park here until someone reports the connection as dead.
                self._connection_closed_event.clear()
                await self._connection_closed_event.wait()
                logger.debug("WebSocket connection closed")
            except Exception as e:
                logger.error(f"Connection manager error: {e}")
            finally:
                # Clean up the connection. Catch only Exception so that
                # cancellation is never swallowed.
                if self._ws:
                    try:
                        await self._ws.close()
                    except Exception as exc:
                        logger.debug(f"Error closing WebSocket: {exc}")
                    finally:
                        self._ws = None

            if self._closing:
                break

            # Exponential back-off (capped) with jitter to avoid a
            # thundering herd of reconnects.
            self._connection_attempts += 1
            delay = min(
                self._reconnect_delay * (2 ** min(self._connection_attempts - 1, 5)),
                self._max_reconnect_delay,
            )
            delay += random.uniform(0, 0.5)
            logger.info(
                f"Reconnecting in {delay:.1f} seconds (attempt {self._connection_attempts})"
            )
            await asyncio.sleep(delay)

    async def _establish_connection(self) -> None:
        """Open a new WebSocket connection, trying up to three times.

        Raises:
            Exception: whatever the last failed ``websockets.connect`` raised.
        """
        logger.debug("Establishing new WebSocket connection to Smart-Turn service...")

        # Prepare headers.
        additional_headers = dict(self._headers)
        if self._service_context is not None:
            additional_headers["X-Service-Context"] = str(self._service_context)

        # The connection is established from __init__(), before
        # set_sample_rate() has had a chance to run, so fall back to
        # _init_sample_rate when _sample_rate is not set yet.
        # NOTE(review): both attributes presumably come from the base class;
        # the None check guards against neither being set.
        _sample_rate = self._sample_rate or self._init_sample_rate
        if _sample_rate is not None and _sample_rate > 0:
            additional_headers["X-Sample-Rate"] = str(_sample_rate)

        max_attempts = 3
        for attempt in range(max_attempts):
            try:
                # Small, growing pause before each retry.
                if attempt > 0:
                    await asyncio.sleep(0.1 * attempt)

                self._ws = await websockets.connect(
                    self._url,
                    additional_headers=additional_headers,
                    ping_interval=5.0,  # let websockets send pings every 5s
                    ping_timeout=3.0,  # fail fast if no pong in 3s
                    close_timeout=10,
                    max_size=10 * 1024 * 1024,  # 10MB max message size
                )
                logger.info("WebSocket connection established successfully")
                return
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                logger.warning(
                    f"Failed to establish WebSocket (attempt {attempt + 1}/{max_attempts}): {exc}"
                )
                if attempt == max_attempts - 1:
                    raise
                await asyncio.sleep(0.5 * (attempt + 1))

    async def _ensure_ws(self) -> websockets.WebSocketClientProtocol:
        """Return a connected WebSocket, waiting up to 10 seconds for one.

        Raises:
            TimeoutError: no connection became available in time.
            ConnectionError: the analyzer is being closed.
        """
        async with self._ws_lock:
            # If the connection manager isn't running, start it.
            if not self._connection_task or self._connection_task.done():
                self._connection_task = asyncio.create_task(self._connection_manager())

            max_wait_time = 10.0
            # Monotonic clock: immune to wall-clock adjustments.
            deadline = time.monotonic() + max_wait_time
            while not self._closing:
                if self._ws:
                    return self._ws
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        f"Timeout waiting for WebSocket connection after {max_wait_time}s"
                    )
                await asyncio.sleep(0.1)
            raise ConnectionError("Analyzer is closing")

    async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
        """Send audio and await the JSON response via WebSocket.

        Returns the service's result dict, or a neutral fallback
        (prediction 0, probability 0.0) when the request fails. Every failure
        except a timeout also signals the connection manager to reconnect.

        Raises:
            SmartTurnTimeoutException: no reply within ``self._params.stop_secs``.
        """
        data_bytes = self._serialize_array(audio_array)
        try:
            ws = await self._ensure_ws()

            try:
                await ws.send(data_bytes)
            except Exception as e:
                logger.error(f"Failed to send data: {e}")
                self._connection_closed_event.set()
                return self._fallback_result()

            # Wait for the response, polling in slices of at most 0.5s.
            deadline = time.monotonic() + self._params.stop_secs
            while True:
                remaining_timeout = deadline - time.monotonic()
                if remaining_timeout <= 0:
                    raise SmartTurnTimeoutException(
                        f"Request exceeded {self._params.stop_secs} seconds."
                    )
                try:
                    message = await asyncio.wait_for(
                        ws.recv(), timeout=min(remaining_timeout, 0.5)
                    )
                except asyncio.TimeoutError:
                    continue

                # Only text messages carry JSON responses.
                if not isinstance(message, str):
                    logger.error(f"Unexpected message type: {type(message)}")
                    continue

                try:
                    result = json.loads(message)
                except json.JSONDecodeError as exc:
                    logger.error(f"Smart turn service returned invalid JSON: {exc}")
                    raise  # handled by the generic handler below

                if not isinstance(result, dict):
                    logger.error("Invalid response format from Smart-Turn service")
                    return self._fallback_result()

                # Skip ping/pong messages.
                if result.get("type") in ["ping", "pong"]:
                    continue

                if "prediction" not in result:
                    # Other typed control messages are ignored; anything else
                    # is an invalid reply.
                    if "type" in result:
                        continue
                    logger.error("Invalid response format from Smart-Turn service")
                    return self._fallback_result()

                self._last_successful_request = time.time()
                return result
        except websockets.exceptions.ConnectionClosed:
            logger.warning("WebSocket connection closed during prediction")
            self._connection_closed_event.set()
            return self._fallback_result()
        except SmartTurnTimeoutException:
            raise
        except Exception as exc:
            logger.error(f"Smart turn prediction failed over WebSocket: {exc}")
            self._connection_closed_event.set()
            return self._fallback_result()

    async def close(self):
        """Stop the connection manager and close the WebSocket."""
        self._closing = True
        # Wake the connection manager so it can observe ``_closing``.
        self._connection_closed_event.set()
        async with self._ws_lock:
            # Cancel the manager task and wait for it to finish.
            if self._connection_task and not self._connection_task.done():
                self._connection_task.cancel()
                try:
                    await self._connection_task
                except asyncio.CancelledError:
                    pass

            # Close the WebSocket if the manager has not already done so.
            if self._ws:
                try:
                    await self._ws.close()
                except Exception as exc:
                    logger.debug(f"Error closing WebSocket: {exc}")
                finally:
                    self._ws = None