mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* feat: add headless widget for deployment * feat: call callbacks at the right time * feat: add onCallConnected & onCallDisconnected callback * feat: add a button with text for floating widget * feat: add headless widget for deployment * feat: call callbacks at the right time * feat: add onCallConnected & onCallDisconnected callback * feat: add a button with text for floating widget * docs: web widget * fix: format issue in pre-pr drift check * fix: fix CD to rely on pipecat dev dependency * chore: update message --------- Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
478 lines
18 KiB
Python
478 lines
18 KiB
Python
import asyncio
|
||
import io
|
||
import json
|
||
import logging
|
||
import os
|
||
import sys
|
||
import time
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
from fastapi import (
|
||
BackgroundTasks,
|
||
FastAPI,
|
||
HTTPException,
|
||
Request,
|
||
WebSocket,
|
||
WebSocketDisconnect,
|
||
WebSocketException,
|
||
status,
|
||
)
|
||
from fastapi.websockets import WebSocketState
|
||
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
|
||
from scipy.io import wavfile
|
||
|
||
# Verbosity is driven by $LOG_LEVEL: the value "debug" (any casing) -- or an
# unset variable, which defaults to "DEBUG" -- selects DEBUG; anything else
# falls back to INFO.
LOG_LEVEL = logging.INFO
if os.environ.get("LOG_LEVEL", "DEBUG").lower() == "debug":
    LOG_LEVEL = logging.DEBUG

# Dedicated stdout handler with a timestamped format for the service logger.
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(
    logging.Formatter(fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)

logger = logging.getLogger("smart_turn")
logger.setLevel(LOG_LEVEL)
logger.addHandler(handler)
|
||
|
||
|
||
# ----------------------------------------------------------------------------
|
||
# Configuration
|
||
# ----------------------------------------------------------------------------
|
||
# Model location handed to LocalSmartTurnAnalyzerV2 at startup; override it
# with the LOCAL_SMART_TURN_MODEL_PATH environment variable.
MODEL_PATH = os.environ.get(
    "LOCAL_SMART_TURN_MODEL_PATH",
    "pipecat-ai/smart-turn-v2",
)
|
||
|
||
# ----------------------------------------------------------------------------
|
||
# Analyzer Pool
|
||
# ----------------------------------------------------------------------------
|
||
|
||
|
||
class _AnalyzerWrapper:
    """Bundle a LocalSmartTurnAnalyzerV2 with an ``asyncio.Lock``.

    The endpoints acquire ``lock`` around every inference call so that the
    single shared analyzer is used by one request at a time.
    """

    def __init__(self, analyzer: LocalSmartTurnAnalyzerV2):
        # One lock per wrapped analyzer; callers enter it with ``async with``.
        self.lock = asyncio.Lock()
        self.analyzer = analyzer


# Shared wrapper instance. It stays ``None`` until the lifespan hook assigns it.
_analyzer_wrapper: _AnalyzerWrapper | None = None
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the shared analyzer once before serving.

    On startup the module-level ``_analyzer_wrapper`` is populated (unless it
    has already been set). Nothing is torn down on shutdown at present.
    """
    global _analyzer_wrapper

    # --- startup ---------------------------------------------------------
    already_initialised = _analyzer_wrapper is not None
    if not already_initialised:
        logger.debug("Initializing LocalSmartTurnAnalyzer")
        _analyzer_wrapper = _AnalyzerWrapper(
            LocalSmartTurnAnalyzerV2(smart_turn_model_path=MODEL_PATH)
        )
        logger.debug("LocalSmartTurnAnalyzer initialized")

    # --- application serves requests while suspended here ----------------
    yield

    # --- shutdown ---------------------------------------------------------
    # No cleanup is required today; add it below this line if that changes.
|
||
|
||
|
||
# ASGI application object. Analyzer start-up is delegated to ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title="Smart Turn API",
    description="A FastAPI application exposing LocalSmartTurnAnalyzer via HTTP",
)
|
||
|
||
# ----------------------------------------------------------------------------
|
||
# API Endpoints
|
||
# ----------------------------------------------------------------------------
|
||
|
||
|
||
async def save_wav_file(
    audio_array: np.ndarray,
    prediction: int,
    probability: float,
    service_id: str | None = None,
    sample_rate: int = 16000,
) -> None:
    """Write the analysed audio to a WAV file without blocking the event loop.

    The blocking ``wavfile.write`` call runs in a worker thread via
    ``asyncio.to_thread``. Because this is a coroutine function it can be
    scheduled with ``asyncio.create_task`` (WebSocket endpoint) as well as
    with FastAPI ``BackgroundTasks`` (HTTP endpoint).

    Saving is best-effort: any failure is logged and swallowed, it never
    propagates to the caller.

    Args:
        audio_array: Audio samples as a NumPy array. The int16 conversion
            below assumes float samples in the ``[-1, 1]`` range.
        prediction: The prediction result (0 or 1); embedded in the filename.
        probability: The probability of the prediction; embedded in the
            filename.
        service_id: Optional service identifier. It originates from a request
            header, so it is reduced to filename-safe characters before it is
            used as part of the output path. Log messages keep the raw value.
        sample_rate: The sample rate of the audio (default: 16000 Hz).
    """

    def _blocking_save() -> None:
        try:
            # Millisecond-resolution timestamp (microseconds truncated to 3 digits).
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]

            # The identifier is client-controlled: replace anything that is not
            # alphanumeric, "-", "_" or "." so path separators (and therefore
            # "../" traversal) can never reach the filesystem path.
            service_prefix = ""
            if service_id:
                safe_service_id = "".join(
                    ch if ch.isalnum() or ch in "-_." else "_" for ch in service_id
                )
                service_prefix = f"{safe_service_id}_"

            root_dir = (
                Path(__file__).resolve().parents[3]
            )  # dograh/api/services/smart_turn/app.py
            filename = (
                root_dir
                / "smart_turn_pipeline"
                / f"{service_prefix}{timestamp}_{prediction}_{probability}.wav"
            )

            # Convert float32 [-1, 1] back to int16 PCM for the WAV file.
            audio_int16 = np.clip(audio_array * 32767, -32768, 32767).astype(np.int16)

            wavfile.write(filename, sample_rate, audio_int16)

            length_seconds = len(audio_array) / sample_rate
            log_message = f"Saved audio to {filename} (length: {length_seconds:.2f}s, prediction: {prediction}"
            if service_id:
                log_message += f", service_id: {service_id}"
            log_message += ")"

            logger.info(log_message)

        except Exception as exc:  # pragma: no cover – best-effort logging only
            log_message = f"Failed to save WAV file: {exc}"
            if service_id:
                log_message += f" (service_id: {service_id})"
            logger.error(log_message)

    # Offload the blocking I/O to a thread to avoid blocking the event loop.
    await asyncio.to_thread(_blocking_save)
|
||
|
||
|
||
@app.post("/raw", status_code=status.HTTP_200_OK)
async def handle_raw(request: Request, background_tasks: BackgroundTasks):
    """Run Smart Turn inference on a NumPy payload posted over HTTP.

    The request body must be a NumPy array serialized with ``np.save``. The
    response is a JSON prediction compatible with ``HttpSmartTurnAnalyzer``.

    Request headers:
        X-API-Key: Must equal ``SMART_TURN_HTTP_SERVICE_KEY`` when that
            environment variable is set; ignored otherwise.
        X-Service-Context: Optional identifier echoed back as ``service_id``
            and added to log lines.
        X-Sample-Rate: Optional integer sample rate (default 16000). It is
            only passed on to the background WAV writer.

    Returns:
        The analyzer's result dict with ``metrics.inference_time`` and
        ``metrics.total_time`` set, plus ``service_id`` when one was supplied.

    Raises:
        HTTPException: 401 for a missing/incorrect API key, 400 for an empty
            body, a non-integer ``X-Sample-Rate`` or an undecodable payload,
            and 503 when the analyzer has not been initialised yet.
    """

    # ------------------------------------------------------------------
    # Secret key validation (only enforced when a secret is configured)
    # ------------------------------------------------------------------
    expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
    if expected_secret:
        provided_secret = request.headers.get("X-API-Key")
        if provided_secret != expected_secret:
            logger.warning(
                "Unauthorized access attempt to /raw endpoint with invalid or missing secret key"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Unauthorized",
            )

    # Start total-time measurement as early as possible.
    request_start_time = time.perf_counter()

    # Log that we received a request (before doing any heavy work).
    logger.debug("Received /raw request")

    body = await request.body()
    if not body:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Empty request body"
        )

    # Extract service context and sample rate from headers.
    service_id = request.headers.get("X-Service-Context")
    # Suffix appended to log/error messages when the caller identified itself.
    service_suffix = f" (service_id: {service_id})" if service_id else ""

    sample_rate_str = request.headers.get("X-Sample-Rate")
    try:
        sample_rate = int(sample_rate_str) if sample_rate_str else 16000
    except ValueError as exc:
        # A malformed header is a client error, not an internal failure.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid X-Sample-Rate header: {sample_rate_str!r}{service_suffix}",
        ) from exc

    # Deserialize NumPy array. Pickle is explicitly disabled (matching the
    # WebSocket endpoint) because the payload comes from the network.
    try:
        audio_array = np.load(io.BytesIO(body), allow_pickle=False)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid NumPy payload: {exc}{service_suffix}",
        ) from exc

    wrapper = _analyzer_wrapper
    if wrapper is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Analyzer not initialized",
        )

    # Run inference guarded by the wrapper lock so the model isn't used concurrently.
    logger.debug("Going to acquire lock for model inference" + service_suffix)

    async with wrapper.lock:
        logger.debug("Acquired lock for model inference" + service_suffix)

        # Measure inference-only latency (excludes time spent waiting for the lock).
        inference_start_time = time.perf_counter()
        result = await wrapper.analyzer._predict_endpoint(audio_array)
        inference_time = time.perf_counter() - inference_start_time

    # Total processing time: from just after auth to response preparation.
    total_time = time.perf_counter() - request_start_time

    logger.debug(
        f"Inference done result: {result['prediction']} "
        f"probability: {result['probability']} time taken: {inference_time:.2f}s total: {total_time:.2f}s"
        + service_suffix
    )

    # Ensure the metrics section exists so client code can parse it
    # consistently, then overwrite / set the timing metrics explicitly.
    metrics = result.get("metrics", {})
    metrics["inference_time"] = inference_time
    metrics["total_time"] = total_time
    result["metrics"] = metrics

    logger.debug(f"Result for service_id: {service_id} is: {result}")

    # Add service_id to result for potential client use.
    if service_id:
        result["service_id"] = service_id

    # Persist audio in background so it doesn't block the response.
    background_tasks.add_task(
        save_wav_file,
        audio_array,
        result.get("prediction", 0),
        result.get("probability", 0),
        service_id,
        sample_rate,
    )
    return result
|
||
|
||
|
||
@app.get("/")
async def root():
    """Health check: reply with a fixed message showing the service is up."""
    health_payload = {"message": "Smart Turn API is running"}
    return health_payload
|
||
|
||
|
||
# ----------------------------------------------------------------------------
|
||
# WebSocket endpoint
|
||
# ----------------------------------------------------------------------------
|
||
|
||
|
||
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Handle streaming Smart Turn requests over WebSocket.

    Each incoming binary message must be a NumPy-serialized array (as produced
    by ``np.save``). A JSON-formatted prediction (same shape as the ``/raw``
    HTTP endpoint) is sent back as a text message. Non-binary frames are
    ignored. Problems with an individual payload are reported to the client
    as ``{"error": "..."}`` and the connection stays open.

    Handshake headers:
        X-API-Key: Must equal ``SMART_TURN_HTTP_SERVICE_KEY`` when that
            environment variable is set; otherwise the socket is closed with
            code 4401.
        X-Service-Context: Optional identifier echoed back as ``service_id``.
        X-Sample-Rate: Optional integer sample rate (default 16000), only used
            when saving audio. A non-integer value closes the socket with
            code 1008.

    The connection is closed after 120 seconds without any incoming message.
    Pending background WAV-saving tasks are cancelled when the handler exits.
    """

    # Extract optional secret key from headers (during handshake).
    expected_secret = os.getenv("SMART_TURN_HTTP_SERVICE_KEY")
    if expected_secret:
        provided_secret = ws.headers.get("X-API-Key")
        if provided_secret != expected_secret:
            await ws.close(code=4401, reason="Unauthorized")
            return

    # Accept the websocket connection and log it.
    await ws.accept()

    service_id = ws.headers.get("X-Service-Context")
    sample_rate_str = ws.headers.get("X-Sample-Rate")
    try:
        sample_rate = int(sample_rate_str) if sample_rate_str else 16000
    except ValueError:
        # Close cleanly instead of letting the handler crash on a bad header.
        logger.warning(
            f"Invalid X-Sample-Rate header {sample_rate_str!r} for service_id: {service_id}"
        )
        await ws.close(code=1008, reason="Invalid X-Sample-Rate header")
        return
    logger.debug(
        f"WebSocket connection accepted from service_id: {service_id}, sample_rate: {sample_rate}"
    )

    # ------------------------------------------------------------------
    # Tunables – consider moving to env vars for ops control
    # ------------------------------------------------------------------
    connection_timeout = 120.0  # Seconds of inactivity before timing out
    MAX_BINARY_SIZE = int(
        os.getenv("SMART_TURN_MAX_PAYLOAD", 10 * 1024 * 1024)  # 10MB max message size
    )

    # Track background tasks so we can cancel them on disconnect.
    background_tasks = set()

    async def _send_error(message: str) -> None:
        """Best-effort delivery of ``{"error": message}`` to the client."""
        if ws.application_state != WebSocketState.CONNECTED:
            return
        try:
            # json.dumps escapes quotes/newlines in the message, so the frame
            # is always valid JSON regardless of the exception text.
            await ws.send_text(json.dumps({"error": message}))
        except Exception as send_exc:
            logger.error(f"Failed to send error message: {send_exc}")

    try:
        logger.debug("Entering WebSocket message loop")
        while True:
            data = None  # Binary payload of the current frame, if any
            try:
                logger.debug("Waiting for WebSocket message…")

                try:
                    # wait_for cancels the pending receive itself on timeout.
                    msg = await asyncio.wait_for(
                        ws.receive(), timeout=connection_timeout
                    )
                except asyncio.TimeoutError:
                    logger.warning(
                        f"WebSocket connection timeout for service_id: {service_id}"
                    )
                    try:
                        await ws.close(code=1001, reason="Connection timeout")
                    except Exception as e:
                        logger.debug(f"Error closing WebSocket after timeout: {e}")
                    break
                except WebSocketDisconnect as e:
                    logger.debug(f"WebSocket client disconnected: {e}")
                    break

                # Validate message structure.
                if not isinstance(msg, dict):
                    logger.error(f"Unexpected message type: {type(msg)}")
                    break

                # Handle disconnect message explicitly.
                if msg.get("type") == "websocket.disconnect":
                    logger.debug("Client sent disconnect frame")
                    break

                # Binary frame.
                if msg.get("bytes") is not None:
                    data = msg["bytes"]
                    logger.debug(
                        "Received WebSocket audio payload (%d bytes)", len(data)
                    )

            except WebSocketDisconnect as e:
                logger.debug(f"WebSocket client disconnected: {e}")
                break
            except Exception as e:
                logger.error(f"Error in WebSocket loop: {e}")
                break

            # Anything that is not a binary frame is ignored.
            if data is None:
                continue

            request_start_time = time.perf_counter()

            # --------------------------------------------------------------
            # Basic validation & secure deserialisation
            # --------------------------------------------------------------
            if len(data) > MAX_BINARY_SIZE:
                logger.warning("Received payload exceeding maximum allowed size")
                await _send_error("Payload too large")
                continue

            # Deserialize NumPy array (pickle disabled for security).
            try:
                audio_array = np.load(io.BytesIO(data), allow_pickle=False)
            except Exception as exc:
                error_msg = f"Invalid NumPy payload: {exc}"
                if service_id:
                    error_msg += f" (service_id: {service_id})"
                await _send_error(error_msg)
                continue

            wrapper = _analyzer_wrapper
            if wrapper is None:
                logger.error("Analyzer not initialized; closing connection")
                if ws.application_state == WebSocketState.CONNECTED:
                    await ws.close(code=1011, reason="Analyzer not ready")
                break

            # Serialise access to the shared model.
            async with wrapper.lock:
                inference_start_time = time.perf_counter()
                result = await wrapper.analyzer._predict_endpoint(audio_array)
                inference_time = time.perf_counter() - inference_start_time

            # Timing metrics.
            total_time = time.perf_counter() - request_start_time
            metrics = result.get("metrics", {})
            metrics["inference_time"] = inference_time
            metrics["total_time"] = total_time
            result["metrics"] = metrics

            logger.debug(f"Result for service_id: {service_id} is: {result}")

            if service_id:
                result["service_id"] = service_id

            # Send result with proper error handling.
            try:
                if ws.application_state == WebSocketState.CONNECTED:
                    await ws.send_text(json.dumps(result))
                else:
                    logger.warning(
                        f"Cannot send result - WebSocket not connected for service_id: {service_id}"
                    )
                    break
            except WebSocketDisconnect:
                logger.debug(
                    f"Client disconnected while sending result for service_id: {service_id}"
                )
                break
            except Exception as e:
                logger.error(f"Failed to send result: {e}")
                break

            # Save audio in the background so that it doesn't block streaming.
            task = asyncio.create_task(
                save_wav_file(
                    audio_array,
                    result.get("prediction", 0),
                    result.get("probability", 0),
                    service_id,
                    sample_rate,
                )
            )
            # Track task and remove when done.
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

    except WebSocketException as exc:
        logger.error(f"WebSocket error: {exc}")
    finally:
        # Snapshot first: done-callbacks discard entries from the live set.
        pending_tasks = list(background_tasks)

        # Cancel any remaining background tasks.
        for task in pending_tasks:
            if not task.done():
                task.cancel()
        # Wait for all background tasks to complete or be cancelled.
        if pending_tasks:
            await asyncio.gather(*pending_tasks, return_exceptions=True)

        # Attempt a graceful close if it's not already closed.
        if ws.application_state == WebSocketState.CONNECTED:
            try:
                await ws.close()
            except Exception as exc:
                # Socket is probably already closed; log and ignore.
                logger.debug(f"WebSocket already closed: {exc}")
|