add smart turn as provider

This commit is contained in:
Abhishek Kumar 2026-01-06 15:26:07 +05:30
parent 6d589b7452
commit e301579c31
9 changed files with 960 additions and 178 deletions

View file

@ -1,27 +1,29 @@
# STT Evaluation Benchmark
Benchmark for comparing Speech-to-Text providers with focus on:
Benchmark for comparing Speech-to-Text providers using **WebSocket streaming** with focus on:
- **Speaker diarization** - identifying who said what
- **Keyterm boosting** - improving recognition of specific terms (Deepgram)
## Providers
| Provider | Diarization | Keyterm Boost | Notes |
|----------|-------------|---------------|-------|
| Deepgram | Yes | Yes | `diarize=true`, `keyterm` param |
| Speechmatics | Yes | No | `diarization: "speaker"` config |
| Provider | Diarization | Keyterm Boost | Streaming |
|----------|-------------|---------------|-----------|
| Deepgram | Yes | Yes | WebSocket (v1/v2) |
| Speechmatics | Yes | Additional vocab | WebSocket RT |
## Setup
```bash
# Install dependencies (httpx is required)
pip install httpx
# Install dependencies
pip install websockets loguru numpy onnxruntime transformers
# Set API keys
export DEEPGRAM_API_KEY="your-key"
export SPEECHMATICS_API_KEY="your-key"
```
**Note:** Requires `ffmpeg` on your `PATH` for audio conversion to PCM16 (`ffprobe`, which ships with ffmpeg, is used to read the audio duration).
## Usage
Run from the project root directory:
@ -33,9 +35,12 @@ python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
# Test only Deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
# Test with keyterm boosting (Deepgram only)
# Test with keyterm boosting (Deepgram)
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
# Use different sample rate (default: 8000 Hz)
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --sample-rate 16000
# Show word-level timings
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words
@ -50,8 +55,9 @@ python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save
| `audio_file` | Path to audio file (relative to evals/stt/ or absolute) |
| `--providers` | Providers to test: `deepgram`, `deepgram-flux`, `speechmatics`, `local-smart-turn` (default: `deepgram` and `speechmatics`) |
| `--diarize` | Enable speaker diarization |
| `--keyterms` | Keywords to boost (Deepgram only) |
| `--keyterms` | Keywords to boost (Deepgram) / additional vocab (Speechmatics) |
| `--language` | Language code (default: en) |
| `--sample-rate` | Audio sample rate for streaming (default: 8000) |
| `--show-words` | Show individual word timings |
| `--save` | Save results to JSON in `results/` |
@ -64,16 +70,33 @@ evals/stt/
├── results/ # Saved benchmark results (JSON)
├── providers/ # STT provider implementations
│ ├── base.py # Base classes
│ ├── deepgram_provider.py
│ └── speechmatics_provider.py
│ ├── deepgram_provider.py # WebSocket streaming
│ └── speechmatics_provider.py # WebSocket streaming
├── audio_streamer.py # PCM16 audio file streamer
├── benchmark.py # Main runner script
└── README.md
```
## How It Works
1. **Audio Conversion**: The `AudioStreamer` converts any audio file to raw PCM16 using ffmpeg
2. **WebSocket Connection**: Providers connect to their respective WebSocket APIs
3. **Streaming**: Audio is sent in chunks (configurable sample rate, default 8kHz)
4. **Result Collection**: Transcripts and speaker info are collected from WebSocket responses
5. **Comparison**: Results are parsed into a common format for comparison
## Output Example
```
Audio file: /path/to/audio/multi_speaker.m4a
Providers: ['deepgram', 'speechmatics']
Diarization: True
Sample rate: 8000 Hz
============================================================
Provider: DEEPGRAM
============================================================
Duration: 45.32s
Speakers detected: 2 - ['0', '1']
@ -84,17 +107,29 @@ Hello, welcome to the demo...
[0.0s] Speaker 0: Hello, welcome to the demo.
[2.5s] Speaker 1: Thanks for having me.
...
============================================================
COMPARISON SUMMARY
============================================================
Provider Duration Speakers Words
---------------------------------------------
deepgram 45.32 2 312
speechmatics 45.32 2 308
```
## Adding New Providers
1. Create a new file in `providers/` (e.g., `whisper_provider.py`)
2. Implement the `STTProvider` abstract class
3. Add to `providers/__init__.py`
4. Add to `benchmark.py` provider choices
2. Implement the `STTProvider` abstract class with WebSocket streaming
3. Use `AudioStreamer` for PCM16 conversion
4. Add to `providers/__init__.py`
5. Add to `benchmark.py` provider choices
## API Documentation
- Deepgram Streaming: https://developers.deepgram.com/docs/live-streaming-audio
- Deepgram Diarization: https://developers.deepgram.com/docs/diarization
- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm
- Speechmatics RT API: https://docs.speechmatics.com/rt-api-ref
- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization

BIN
evals/stt/audio/vad.m4a Normal file

Binary file not shown.

127
evals/stt/audio_streamer.py Normal file
View file

@ -0,0 +1,127 @@
"""Audio file streamer - converts audio files to PCM16 streams."""
import asyncio
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncIterator
@dataclass
class AudioConfig:
    """Audio streaming configuration.

    Describes the raw PCM layout requested from ffmpeg and how much audio
    is sent per chunk.
    """

    sample_rate: int = 8000  # samples per second
    channels: int = 1
    sample_width: int = 2  # 16-bit = 2 bytes
    chunk_duration_ms: int = 80  # Send chunks every 80ms

    @property
    def chunk_size(self) -> int:
        """Return the number of bytes per chunk for ``chunk_duration_ms``.

        The result is never smaller than one sample frame: a chunk size of
        zero would be used as a ``range()`` step by the streamer and raise.
        """
        samples_per_chunk = int(self.sample_rate * self.chunk_duration_ms / 1000)
        # Clamp to one sample so tiny durations / rates still make progress.
        samples_per_chunk = max(1, samples_per_chunk)
        return samples_per_chunk * self.channels * self.sample_width
class AudioStreamer:
    """Streams audio files as PCM16 chunks.

    Converts any audio format to raw PCM16 using ffmpeg and streams
    in real-time chunks to simulate live audio.
    """

    def __init__(self, config: AudioConfig | None = None):
        self.config = config or AudioConfig()

    def convert_to_pcm16(self, audio_path: Path) -> bytes:
        """Convert audio file to raw PCM16 bytes using ffmpeg.

        This call blocks until ffmpeg exits; async callers should run it in
        a worker thread.

        Args:
            audio_path: Path to input audio file

        Returns:
            Raw PCM16 audio bytes

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
            FileNotFoundError: If the ffmpeg executable is not on PATH.
        """
        cmd = [
            "ffmpeg",
            "-nostdin",  # never let ffmpeg read from (or grab) the caller's stdin
            "-i",
            str(audio_path),
            "-f",
            "s16le",  # signed 16-bit little-endian
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(self.config.sample_rate),
            "-ac",
            str(self.config.channels),
            "-",  # output to stdout
        ]
        result = subprocess.run(
            cmd,
            capture_output=True,
            check=True,
        )
        return result.stdout

    async def stream_file(
        self,
        audio_path: Path,
        realtime: bool = True,
    ) -> AsyncIterator[bytes]:
        """Stream audio file as PCM16 chunks.

        Args:
            audio_path: Path to audio file
            realtime: If True, pace chunks to simulate real-time streaming

        Yields:
            PCM16 audio chunks

        Raises:
            ValueError: If the configured chunk size is not positive.
        """
        # Convert the entire file up front. ffmpeg is run in a worker thread so
        # the event loop (and any concurrent WebSocket receiver) is not stalled.
        pcm_data = await asyncio.to_thread(self.convert_to_pcm16, audio_path)

        chunk_size = self.config.chunk_size
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        delay = self.config.chunk_duration_ms / 1000.0 if realtime else 0

        # Pace against the loop clock instead of sleeping a fixed amount per
        # chunk: time the consumer spends between chunks (e.g. in ws.send) is
        # then absorbed rather than accumulating as drift.
        loop = asyncio.get_running_loop()
        next_deadline = loop.time()

        for offset in range(0, len(pcm_data), chunk_size):
            yield pcm_data[offset : offset + chunk_size]
            if delay > 0:
                next_deadline += delay
                await asyncio.sleep(max(0.0, next_deadline - loop.time()))

    async def stream_file_fast(self, audio_path: Path) -> AsyncIterator[bytes]:
        """Stream audio file as fast as possible (no real-time delay).

        Args:
            audio_path: Path to audio file

        Yields:
            PCM16 audio chunks
        """
        async for chunk in self.stream_file(audio_path, realtime=False):
            yield chunk

    def get_duration(self, audio_path: Path) -> float:
        """Get audio file duration in seconds.

        Args:
            audio_path: Path to audio file

        Returns:
            Duration in seconds

        Raises:
            subprocess.CalledProcessError: If ffprobe exits with a non-zero status.
            ValueError: If ffprobe's output cannot be parsed as a number.
        """
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            str(audio_path),
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        return float(result.stdout.strip())

View file

@ -20,14 +20,23 @@ from datetime import datetime
from pathlib import Path
from typing import Any
from evals.stt.providers import DeepgramProvider, SpeechmaticsProvider, STTProvider, TranscriptionResult
from evals.stt.providers import (
DeepgramProvider,
DeepgramFluxProvider,
SpeechmaticsProvider,
LocalSmartTurnProvider,
STTProvider,
TranscriptionResult,
)
def get_provider(name: str) -> STTProvider:
"""Get provider instance by name."""
providers = {
"deepgram": DeepgramProvider,
"deepgram-flux": DeepgramFluxProvider,
"speechmatics": SpeechmaticsProvider,
"local-smart-turn": LocalSmartTurnProvider,
}
if name not in providers:
raise ValueError(f"Unknown provider: {name}. Available: {list(providers.keys())}")
@ -145,7 +154,7 @@ Examples:
"--providers",
nargs="+",
default=["deepgram", "speechmatics"],
choices=["deepgram", "speechmatics"],
choices=["deepgram", "deepgram-flux", "speechmatics", "local-smart-turn"],
help="Providers to test (default: all)",
)
parser.add_argument(
@ -163,6 +172,12 @@ Examples:
default="en",
help="Language code (default: en)",
)
parser.add_argument(
"--sample-rate",
type=int,
default=8000,
help="Audio sample rate for streaming (default: 8000)",
)
parser.add_argument(
"--show-words",
action="store_true",
@ -195,6 +210,7 @@ Examples:
print(f"Audio file: {audio_path}")
print(f"Providers: {args.providers}")
print(f"Diarization: {args.diarize}")
print(f"Sample rate: {args.sample_rate} Hz")
if args.keyterms:
print(f"Keyterms: {args.keyterms}")
@ -209,6 +225,7 @@ Examples:
diarize=args.diarize,
keyterms=args.keyterms,
language=args.language,
sample_rate=args.sample_rate,
)
print_result(result, show_words=args.show_words)
results.append(result)

View file

@ -1,11 +1,15 @@
from .base import STTProvider, TranscriptionResult, Word
from .deepgram_provider import DeepgramProvider
from .deepgram_flux_provider import DeepgramFluxProvider
from .speechmatics_provider import SpeechmaticsProvider
from .local_smart_turn_provider import LocalSmartTurnProvider
__all__ = [
"STTProvider",
"TranscriptionResult",
"Word",
"DeepgramProvider",
"DeepgramFluxProvider",
"SpeechmaticsProvider",
"LocalSmartTurnProvider",
]

View file

@ -0,0 +1,225 @@
"""Deepgram Flux STT provider with WebSocket streaming.
Flux is Deepgram's conversational AI model with built-in turn detection.
It has a different API than Nova models - no language/punctuate/diarize params.
"""
import asyncio
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlencode
from loguru import logger
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
try:
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError("websockets required: pip install websockets")
class DeepgramFluxProvider(STTProvider):
    """Deepgram Flux Speech-to-Text provider with WebSocket streaming.

    Flux is optimized for conversational AI with built-in turn detection.

    Key differences from Nova:
    - Uses v2 API endpoint
    - Only supports English (flux-general-en)
    - No punctuate, diarize, or language params
    - Has turn detection events (StartOfTurn, EndOfTurn, EagerEndOfTurn)
    - Supports keyterm boosting

    API Docs: https://developers.deepgram.com/docs/
    """

    WS_URL = "wss://api.deepgram.com/v2/listen"

    # Seconds to keep listening for trailing turn events after all audio is sent.
    RECEIVE_GRACE_SECONDS = 10.0

    def __init__(self, api_key: str | None = None):
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        return "deepgram-flux"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - Flux doesn't support diarization
        keyterms: list[str] | None = None,
        model: str = "flux-general-en",
        sample_rate: int = 16000,
        eot_threshold: float | None = None,
        eot_timeout_ms: int | None = None,
        eager_eot_threshold: float | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram Flux WebSocket streaming.

        Args:
            audio_path: Path to audio file
            diarize: IGNORED - Flux does not support diarization
            keyterms: List of keywords to boost recognition
            model: Flux model (default: flux-general-en)
            sample_rate: Audio sample rate (default: 16000 for Flux)
            eot_threshold: End-of-turn confidence threshold (0-1, default 0.7)
            eot_timeout_ms: Timeout in ms to force end of turn (default 5000)
            eager_eot_threshold: Threshold for eager end-of-turn events
            **kwargs: Accepted so callers can pass provider-agnostic options
                (such as ``language``); they are not forwarded to Flux.

        Returns:
            TranscriptionResult with transcript (no speaker info - Flux doesn't support diarization)

        Raises:
            RuntimeError: If Deepgram sends an ``Error`` message.
        """
        if diarize:
            logger.warning("Flux does not support diarization - ignoring diarize=True")

        # Build query params - Flux only supports specific params
        params: dict[str, Any] = {
            "model": model,
            "encoding": "linear16",
            "sample_rate": sample_rate,
        }

        # Flux-specific turn detection params (only sent when explicitly set)
        optional_params = {
            "eot_threshold": eot_threshold,
            "eot_timeout_ms": eot_timeout_ms,
            "eager_eot_threshold": eager_eot_threshold,
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})

        ws_url = self._build_ws_url(params, keyterms)
        logger.debug(f"Flux WebSocket URL: {ws_url}")

        # Setup audio streamer
        streamer = AudioStreamer(AudioConfig(sample_rate=sample_rate))

        # Collect results
        all_transcripts: list[dict[str, Any]] = []
        turn_texts: list[str] = []
        duration = 0.0
        connected = asyncio.Event()

        async with websocket_connect(
            ws_url,
            additional_headers={"Authorization": f"Token {self.api_key}"},
        ) as ws:

            async def send_audio() -> None:
                """Send audio chunks to Deepgram Flux once it reports Connected."""
                await connected.wait()
                chunk_no = 0
                async for chunk in streamer.stream_file(audio_path):
                    await ws.send(chunk)
                    logger.debug(f"[deepgram-flux] Sent audio chunk {chunk_no}")
                    chunk_no += 1

            async def receive_messages() -> None:
                """Receive and collect Flux messages."""
                nonlocal duration
                async for message in ws:
                    if not isinstance(message, str):
                        continue
                    data = json.loads(message)
                    msg_type = data.get("type")
                    logger.debug(f"[deepgram-flux] Received {msg_type}: {data}")

                    if msg_type == "Connected":
                        logger.info("[deepgram-flux] Connected")
                        connected.set()
                    elif msg_type == "TurnInfo":
                        event = data.get("event")
                        transcript = data.get("transcript", "")
                        words = data.get("words", [])
                        if event == "EndOfTurn":
                            if transcript:
                                turn_texts.append(transcript)
                            if words:
                                all_transcripts.append(
                                    {"transcript": transcript, "words": words}
                                )
                                # Track duration from the last word of the turn
                                duration = max(duration, words[-1].get("end", 0))
                        elif event == "TurnResumed":
                            logger.debug("TurnResumed")
                    elif msg_type == "Error":
                        raise RuntimeError(f"Deepgram Flux error: {data}")

            # Run send and receive concurrently
            send_task = asyncio.create_task(send_audio())
            receive_task = asyncio.create_task(receive_messages())
            try:
                # Wait for whichever side finishes first. If the receiver ends
                # before "Connected" arrives (error or early close), the sender
                # would otherwise block on `connected` forever.
                await asyncio.wait(
                    {send_task, receive_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if send_task.done():
                    send_task.result()  # re-raise a failed send
                    logger.debug("[deepgram-flux] Send task done")
                    try:
                        await asyncio.wait_for(
                            receive_task, timeout=self.RECEIVE_GRACE_SECONDS
                        )
                    except asyncio.TimeoutError:
                        logger.debug(
                            "[deepgram-flux] No more messages within grace period"
                        )
                else:
                    receive_task.result()  # re-raise an Error message or connection failure
                    logger.warning(
                        "[deepgram-flux] Connection closed before all audio was sent"
                    )
            finally:
                for task in (send_task, receive_task):
                    if not task.done():
                        task.cancel()
                # Retrieve every outcome so no task exception goes unobserved.
                await asyncio.gather(send_task, receive_task, return_exceptions=True)

        return self._parse_results(
            all_transcripts, " ".join(turn_texts).strip(), duration, params, keyterms
        )

    def _build_ws_url(
        self, params: dict[str, Any], keyterms: list[str] | None
    ) -> str:
        """Return the WebSocket URL with all query parameters percent-encoded.

        Keyterms are appended as repeated ``keyterm`` parameters.
        """
        query = list(params.items())
        query.extend(("keyterm", term) for term in keyterms or [])
        return f"{self.WS_URL}?{urlencode(query)}"

    def _parse_results(
        self,
        transcripts: list[dict[str, Any]],
        final_transcript: str,
        duration: float,
        params: dict[str, Any],
        keyterms: list[str] | None,
    ) -> TranscriptionResult:
        """Parse collected Flux results into TranscriptionResult."""
        words = [
            Word(
                word=w.get("word", ""),
                start=w.get("start", 0.0),
                end=w.get("end", 0.0),
                confidence=w.get("confidence", 0.0),
                speaker=None,  # Flux doesn't support diarization
                speaker_confidence=None,
            )
            for turn in transcripts
            for w in turn.get("words", [])
        ]

        stored_params = dict(params)
        if keyterms:
            stored_params["keyterms"] = keyterms

        return TranscriptionResult(
            provider=self.name,
            transcript=final_transcript,
            words=words,
            speakers=[],  # Flux doesn't support diarization
            duration=duration,
            raw_response={"transcripts": transcripts},
            params=stored_params,
        )

View file

@ -1,25 +1,38 @@
"""Deepgram STT provider."""
"""Deepgram STT provider with WebSocket streaming."""
import asyncio
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlencode
import httpx
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
from loguru import logger
try:
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError("websockets required: pip install websockets")
class DeepgramProvider(STTProvider):
"""Deepgram Speech-to-Text provider.
"""Deepgram Nova Speech-to-Text provider with WebSocket streaming.
API Docs: https://developers.deepgram.com/docs/
Supports:
- Speaker diarization via `diarize=true`
- Keyterm boosting via `keyterm` parameter (Nova-3 and Flux models)
- Keyterm boosting via `keyterm` parameter
- Real-time streaming via WebSocket
- Multiple languages
- Punctuation
For Flux models, use DeepgramFluxProvider instead.
"""
API_URL = "https://api.deepgram.com/v1/listen"
WS_URL = "wss://api.deepgram.com/v1/listen"
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
@ -37,113 +50,151 @@ class DeepgramProvider(STTProvider):
audio_path: Path,
diarize: bool = False,
keyterms: list[str] | None = None,
model: str = "nova-3",
model: str = "nova-3-general",
language: str = "en",
sample_rate: int = 8000,
punctuate: bool = True,
**kwargs: Any,
) -> TranscriptionResult:
"""Transcribe audio using Deepgram API.
"""Transcribe audio using Deepgram Nova WebSocket streaming.
Args:
audio_path: Path to audio file
diarize: Enable speaker diarization
keyterms: List of keywords to boost recognition
model: Deepgram model (nova-3, nova-2, etc.)
model: Deepgram Nova model (nova-3, nova-2, etc.)
language: Language code
sample_rate: Audio sample rate for streaming
punctuate: Add punctuation
**kwargs: Additional Deepgram parameters
Returns:
TranscriptionResult with transcript and speaker info
"""
# Build query params
params: dict[str, Any] = {
"model": model,
"language": language,
"punctuate": str(punctuate).lower(),
"encoding": "linear16",
"sample_rate": sample_rate,
"channels": 1,
"interim_results": "true",
"smart_format": "true",
"profanity_filter": "true",
"vad_events": "true"
}
if diarize:
params["diarize"] = "true"
# Add keyterms (Deepgram uses repeated keyterm params)
# Build URL with params
url_parts = [f"{k}={v}" for k, v in params.items()]
# Add keyterms (repeated params)
if keyterms:
params["keyterm"] = keyterms
for term in keyterms:
url_parts.append(urlencode({"keyterm": term}))
# Add any extra kwargs
params.update(kwargs)
# Add extra kwargs
for k, v in kwargs.items():
url_parts.append(f"{k}={v}")
# Read audio file
audio_data = audio_path.read_bytes()
ws_url = f"{self.WS_URL}?{'&'.join(url_parts)}"
logger.debug(f"Deepgram WebSocket URL: {ws_url}")
# Determine content type
suffix = audio_path.suffix.lower()
content_types = {
".wav": "audio/wav",
".mp3": "audio/mpeg",
".m4a": "audio/mp4",
".flac": "audio/flac",
".ogg": "audio/ogg",
".webm": "audio/webm",
}
content_type = content_types.get(suffix, "audio/wav")
# Setup audio streamer
audio_config = AudioConfig(sample_rate=sample_rate)
streamer = AudioStreamer(audio_config)
headers = {
"Authorization": f"Token {self.api_key}",
"Content-Type": content_type,
}
# Collect results
all_words: list[dict[str, Any]] = []
final_transcript = ""
duration = 0.0
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
self.API_URL,
params=params,
headers=headers,
content=audio_data,
)
response.raise_for_status()
data = response.json()
try:
async with websocket_connect(
ws_url,
additional_headers={"Authorization": f"Token {self.api_key}"},
) as ws:
# Create tasks for sending and receiving
send_complete = asyncio.Event()
return self._parse_response(data, params)
async def send_audio():
"""Send audio chunks to Deepgram."""
chunk_no = 0
async for chunk in streamer.stream_file(audio_path):
logger.debug(f"[deepgram] Sent audio chunk {chunk_no}")
await ws.send(chunk)
chunk_no += 1
# Send close message
logger.debug(f"[deepgram] Sending CloseStream after {chunk_no} chunks")
await ws.send(json.dumps({"type": "CloseStream"}))
send_complete.set()
def _parse_response(
self, data: dict[str, Any], params: dict[str, Any]
async def receive_transcripts():
"""Receive and collect transcription results."""
nonlocal all_words, final_transcript, duration
async for message in ws:
if isinstance(message, str):
data = json.loads(message)
msg_type = data.get("type")
logger.debug(f"[deepgram] Received {msg_type}: {data}")
if msg_type == "Results":
# Nova-style response
channel = data.get("channel", {})
alternatives = channel.get("alternatives", [])
if alternatives:
alt = alternatives[0]
words = alt.get("words", [])
all_words.extend(words)
# Check if final
if data.get("is_final"):
final_transcript += alt.get("transcript", "") + " "
duration = max(
duration, data.get("duration", 0) + data.get("start", 0)
)
elif msg_type == "Metadata":
# Get duration from metadata
duration = data.get("duration", duration)
elif msg_type == "Error":
raise Exception(f"Deepgram error: {data}")
# Run send and receive concurrently
send_task = asyncio.create_task(send_audio())
receive_task = asyncio.create_task(receive_transcripts())
# Wait for send to complete, then wait a bit for final results
await send_task
try:
await asyncio.wait_for(receive_task, timeout=5.0)
except asyncio.TimeoutError:
pass # Normal - websocket closes after final results
except Exception as e:
logger.debug(e)
return self._parse_results(
all_words, final_transcript.strip(), duration, params, keyterms
)
def _parse_results(
self,
raw_words: list[dict[str, Any]],
transcript: str,
duration: float,
params: dict[str, Any],
keyterms: list[str] | None,
) -> TranscriptionResult:
"""Parse Deepgram API response."""
results = data.get("results", {})
channels = results.get("channels", [])
if not channels:
return TranscriptionResult(
provider=self.name,
transcript="",
words=[],
speakers=[],
duration=0.0,
raw_response=data,
params=params,
)
# Get first channel, first alternative
channel = channels[0]
alternatives = channel.get("alternatives", [])
if not alternatives:
return TranscriptionResult(
provider=self.name,
transcript="",
words=[],
speakers=[],
duration=0.0,
raw_response=data,
params=params,
)
alt = alternatives[0]
transcript = alt.get("transcript", "")
# Parse words with speaker info
"""Parse collected results into TranscriptionResult."""
words = []
speakers_set: set[str] = set()
for w in alt.get("words", []):
for w in raw_words:
speaker = str(w.get("speaker", "")) if "speaker" in w else None
if speaker:
speakers_set.add(speaker)
@ -159,9 +210,9 @@ class DeepgramProvider(STTProvider):
)
)
# Get duration from metadata
metadata = results.get("metadata", {})
duration = metadata.get("duration", 0.0)
stored_params = dict(params)
if keyterms:
stored_params["keyterms"] = keyterms
return TranscriptionResult(
provider=self.name,
@ -169,6 +220,6 @@ class DeepgramProvider(STTProvider):
words=words,
speakers=sorted(speakers_set),
duration=duration,
raw_response=data,
params=params,
raw_response={"words": raw_words},
params=stored_params,
)

View file

@ -0,0 +1,285 @@
"""Local Smart Turn provider for benchmarking end-of-turn detection.
Uses the pipecat smart-turn-v3 ONNX model for local ML-based turn detection.
This is NOT an STT provider - it only detects when a speaker has finished talking.
"""
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from loguru import logger
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
try:
import onnxruntime as ort
from transformers import WhisperFeatureExtractor
except ImportError:
raise ImportError(
"onnxruntime and transformers required: pip install onnxruntime transformers"
)
@dataclass
class TurnEvent:
    """One turn-detection analysis result at a given point in the audio."""

    # Position in the audio, in seconds, at which the analysis was made.
    timestamp: float
    # Model confidence that the turn is complete.
    probability: float
    # Thresholded decision: 1=complete, 0=incomplete.
    prediction: int
    # Time the model inference took, in milliseconds.
    inference_time_ms: float
class LocalSmartTurnProvider(STTProvider):
    """Local Smart Turn provider for end-of-turn detection benchmarking.

    Uses the smart-turn-v3 ONNX model to detect when speakers finish talking.
    This is useful for comparing turn detection accuracy against cloud services
    like Deepgram Flux's built-in turn detection.

    NOTE: This provider does NOT produce transcripts - only turn detection events.
    """

    # Smart turn model requires 16kHz audio
    REQUIRED_SAMPLE_RATE = 16000
    # Model analyzes 8 seconds of audio
    WINDOW_SECONDS = 8
    # Location of the model bundled with pipecat (used when no model_path is given)
    BUNDLED_MODEL_PACKAGE = "pipecat.audio.turn.smart_turn.data"
    BUNDLED_MODEL_NAME = "smart-turn-v3.1-cpu.onnx"

    def __init__(
        self,
        model_path: str | None = None,
        cpu_count: int = 1,
    ):
        """Initialize the local smart turn provider.

        Args:
            model_path: Path to ONNX model file. If None, uses bundled model.
            cpu_count: Number of CPUs for inference (default: 1)
        """
        self.model_path = model_path
        self.cpu_count = cpu_count
        self._session = None
        self._feature_extractor = None

    def _load_model(self) -> None:
        """Lazily create the ONNX session and feature extractor (idempotent)."""
        if self._session is not None:
            return

        model_path = self.model_path
        if not model_path:
            # Resolve the model bundled with pipecat via the stdlib resources API.
            from importlib import resources

            model_path = str(
                resources.files(self.BUNDLED_MODEL_PACKAGE).joinpath(
                    self.BUNDLED_MODEL_NAME
                )
            )

        logger.info(f"[local-smart-turn] Loading model from {model_path}")

        # Configure ONNX runtime
        so = ort.SessionOptions()
        so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        so.inter_op_num_threads = 1
        so.intra_op_num_threads = self.cpu_count
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
        self._session = ort.InferenceSession(model_path, sess_options=so)
        logger.info("[local-smart-turn] Model loaded")

    @property
    def name(self) -> str:
        return "local-smart-turn"

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Predict end-of-turn using the ONNX model.

        Args:
            audio_array: Audio samples as float32 numpy array (16kHz)

        Returns:
            Dict with ``prediction`` (0/1), ``probability`` and
            ``inference_time_ms``
        """
        # Make direct calls safe even if transcribe() has not loaded the model.
        self._load_model()

        # Truncate to the last 8 seconds, or left-pad with silence to 8 seconds
        max_samples = self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE
        if len(audio_array) > max_samples:
            audio_array = audio_array[-max_samples:]
        elif len(audio_array) < max_samples:
            padding = max_samples - len(audio_array)
            audio_array = np.pad(
                audio_array, (padding, 0), mode="constant", constant_values=0
            )

        # Process using Whisper's feature extractor
        inputs = self._feature_extractor(
            audio_array,
            sampling_rate=self.REQUIRED_SAMPLE_RATE,
            return_tensors="np",
            padding="max_length",
            max_length=max_samples,
            truncation=True,
            do_normalize=True,
        )

        # Extract features for ONNX (ensure a leading batch dimension of 1)
        input_features = inputs.input_features.squeeze(0).astype(np.float32)
        input_features = np.expand_dims(input_features, axis=0)

        # Run inference
        start_time = time.perf_counter()
        outputs = self._session.run(None, {"input_features": input_features})
        inference_time = (time.perf_counter() - start_time) * 1000

        # Extract probability (model returns sigmoid probabilities)
        probability = outputs[0][0].item()
        prediction = 1 if probability > 0.5 else 0

        return {
            "prediction": prediction,
            "probability": probability,
            "inference_time_ms": inference_time,
        }

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - not applicable
        keyterms: list[str] | None = None,  # Ignored - not applicable
        sample_rate: int = 16000,  # Must be 16kHz for smart turn
        analysis_interval_ms: int = 500,  # How often to check for turn completion
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Analyze audio for turn detection events.

        NOTE: This does NOT produce transcripts. It detects when speakers
        finish talking using ML-based turn detection.

        Args:
            audio_path: Path to audio file
            diarize: Ignored (not applicable for turn detection)
            keyterms: Ignored (not applicable for turn detection)
            sample_rate: Must be 16000 Hz for smart turn model; any other
                value is overridden with a warning
            analysis_interval_ms: How often to run turn detection (ms)
            **kwargs: Additional parameters (ignored)

        Returns:
            TranscriptionResult with turn detection events in raw_response

        Raises:
            ValueError: If ``analysis_interval_ms`` is not positive.
        """
        if analysis_interval_ms <= 0:
            raise ValueError(
                f"analysis_interval_ms must be positive, got {analysis_interval_ms}"
            )

        if sample_rate != self.REQUIRED_SAMPLE_RATE:
            logger.warning(
                f"[local-smart-turn] Sample rate must be {self.REQUIRED_SAMPLE_RATE}Hz, "
                f"overriding {sample_rate}Hz"
            )
            sample_rate = self.REQUIRED_SAMPLE_RATE

        # Load model if not already loaded
        self._load_model()

        # Setup audio streamer at 16kHz
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Get audio duration
        duration = streamer.get_duration(audio_path)
        logger.info(f"[local-smart-turn] Processing {audio_path} ({duration:.2f}s)")

        # Collect all audio first (smart turn needs to analyze segments)
        pcm_data = streamer.convert_to_pcm16(audio_path)

        # Convert to float32 in [-1, 1) for the model
        audio_int16 = np.frombuffer(pcm_data, dtype=np.int16)
        audio_float32 = audio_int16.astype(np.float32) / 32768.0

        total_samples = len(audio_float32)
        samples_per_interval = max(1, int(sample_rate * analysis_interval_ms / 1000))
        window_samples = self.WINDOW_SECONDS * sample_rate

        # Analysis positions: every interval, plus the very end of the audio so
        # the final (possibly partial) interval is not silently skipped.
        end_positions = list(
            range(samples_per_interval, total_samples + 1, samples_per_interval)
        )
        if total_samples and (not end_positions or end_positions[-1] != total_samples):
            end_positions.append(total_samples)

        turn_events: list[TurnEvent] = []
        for chunk_no, end_sample in enumerate(end_positions):
            # Get window of audio ending at current position
            start_sample = max(0, end_sample - window_samples)
            audio_window = audio_float32[start_sample:end_sample]

            current_time = end_sample / sample_rate
            logger.debug(
                f"[local-smart-turn] Analyzing chunk {chunk_no} at {current_time:.2f}s"
            )

            result = self._predict_endpoint(audio_window)

            turn_events.append(
                TurnEvent(
                    timestamp=current_time,
                    probability=result["probability"],
                    prediction=result["prediction"],
                    inference_time_ms=result["inference_time_ms"],
                )
            )

            if result["prediction"] == 1:
                # Inner quotes differ from the outer ones so this parses on
                # Python versions before 3.12 as well.
                logger.info(
                    f"[local-smart-turn] Turn complete at {current_time:.2f}s "
                    f"(prob={result['probability']:.3f})"
                    f"(inf time ms={result['inference_time_ms']})"
                )

        # Convert positive turn events to word-like markers for compatibility
        words = [
            Word(
                word=f"[END_OF_TURN prob={event.probability:.2f}]",
                start=max(0.0, event.timestamp - 0.1),  # never before audio start
                end=event.timestamp,
                confidence=event.probability,
                speaker=None,
                speaker_confidence=None,
            )
            for event in turn_events
            if event.prediction == 1
        ]

        # Count positive analyses (one marker per positive analysis)
        completed_turns = len(words)

        params = {
            "sample_rate": sample_rate,
            "analysis_interval_ms": analysis_interval_ms,
            "window_seconds": self.WINDOW_SECONDS,
        }

        return TranscriptionResult(
            provider=self.name,
            transcript=f"[Turn detection only - {completed_turns} turns detected]",
            words=words,
            speakers=[],  # Not applicable
            duration=duration,
            raw_response={
                "turn_events": [
                    {
                        "timestamp": e.timestamp,
                        "probability": e.probability,
                        "prediction": e.prediction,
                        "inference_time_ms": e.inference_time_ms,
                    }
                    for e in turn_events
                ],
                "completed_turns": completed_turns,
                "total_analyses": len(turn_events),
                "avg_inference_time_ms": (
                    sum(e.inference_time_ms for e in turn_events) / len(turn_events)
                    if turn_events
                    else 0
                ),
            },
            params=params,
        )

View file

@ -1,38 +1,41 @@
"""Speechmatics STT provider."""
"""Speechmatics STT provider with WebSocket streaming."""
import asyncio
import json
import os
from pathlib import Path
from typing import Any
import httpx
from loguru import logger
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
try:
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError("websockets required: pip install websockets")
class SpeechmaticsProvider(STTProvider):
"""Speechmatics Speech-to-Text provider.
"""Speechmatics Speech-to-Text provider with WebSocket streaming.
API Docs: https://docs.speechmatics.com/
Supports:
- Speaker diarization via `diarization: "speaker"` config
- Speaker sensitivity tuning
- Real-time streaming via WebSocket
"""
# EU and US endpoints available
API_URL = "https://asr.api.speechmatics.com/v2/jobs"
def __init__(self, api_key: str | None = None, region: str = "eu2"):
    """Store credentials and derive the real-time WebSocket endpoint.

    Args:
        api_key: Speechmatics API key. When omitted, the
            ``SPEECHMATICS_API_KEY`` environment variable is used.
        region: Host prefix interpolated into
            ``wss://{region}.rt.speechmatics.com/v2``.

    Raises:
        ValueError: If no key is passed and the environment variable is
            unset or empty.
    """
    self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY")
    if not self.api_key:
        raise ValueError(
            "Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key."
        )

    # The streaming path connects to self.ws_url only; the batch REST job
    # URL (api_url) belonged to the removed polling implementation.
    self.ws_url = f"wss://{region}.rt.speechmatics.com/v2"
@property
def name(self) -> str:
@ -45,27 +48,32 @@ class SpeechmaticsProvider(STTProvider):
keyterms: list[str] | None = None,
language: str = "en",
operating_point: str = "enhanced",
sample_rate: int = 8000,
speaker_sensitivity: float | None = None,
max_speakers: int | None = None,
**kwargs: Any,
) -> TranscriptionResult:
"""Transcribe audio using Speechmatics API.
"""Transcribe audio using Speechmatics WebSocket streaming.
Args:
audio_path: Path to audio file
diarize: Enable speaker diarization
keyterms: Not directly supported by Speechmatics (ignored)
keyterms: Additional vocabulary (limited support)
language: Language code
operating_point: "standard" or "enhanced"
sample_rate: Audio sample rate for streaming
speaker_sensitivity: 0.0-1.0, higher = more speakers detected
max_speakers: Maximum number of speakers to detect
**kwargs: Additional config parameters
Returns:
TranscriptionResult with transcript and speaker info
"""
# Build transcription config
# Build transcription config for StartRecognition message
transcription_config: dict[str, Any] = {
"language": language,
"operating_point": operating_point,
"enable_partials": False,
}
if diarize:
@ -74,13 +82,20 @@ class SpeechmaticsProvider(STTProvider):
transcription_config["speaker_diarization_config"] = {
"speaker_sensitivity": speaker_sensitivity
}
if max_speakers is not None:
if "speaker_diarization_config" not in transcription_config:
transcription_config["speaker_diarization_config"] = {}
transcription_config["speaker_diarization_config"]["max_speakers"] = max_speakers
# Add any extra config
transcription_config.update(kwargs)
# Add additional vocabulary if provided
if keyterms:
transcription_config["additional_vocab"] = [{"content": term} for term in keyterms]
config = {
"type": "transcription",
"transcription_config": transcription_config,
# Audio format config
audio_format = {
"type": "raw",
"encoding": "pcm_s16le",
"sample_rate": sample_rate,
}
# Store params for result
@ -88,79 +103,104 @@ class SpeechmaticsProvider(STTProvider):
"diarize": diarize,
"language": language,
"operating_point": operating_point,
"sample_rate": sample_rate,
"speaker_sensitivity": speaker_sensitivity,
"max_speakers": max_speakers,
}
headers = {
"Authorization": f"Bearer {self.api_key}",
}
# Setup audio streamer
audio_config = AudioConfig(sample_rate=sample_rate)
streamer = AudioStreamer(audio_config)
# Create job with multipart form
async with httpx.AsyncClient(timeout=300.0) as client:
# Submit job
with open(audio_path, "rb") as f:
files = {
"data_file": (audio_path.name, f, "audio/mpeg"),
"config": (None, str(config).replace("'", '"'), "application/json"),
}
response = await client.post(
self.api_url,
headers=headers,
files=files,
)
response.raise_for_status()
job_data = response.json()
# Collect results
all_results: list[dict[str, Any]] = []
recognition_started = asyncio.Event()
transcription_complete = asyncio.Event()
job_id = job_data.get("id")
if not job_id:
raise ValueError(f"No job ID in response: {job_data}")
async with websocket_connect(
self.ws_url,
additional_headers={"Authorization": f"Bearer {self.api_key}"},
) as ws:
# Send StartRecognition message
start_msg = {
"message": "StartRecognition",
"transcription_config": transcription_config,
"audio_format": audio_format,
}
await ws.send(json.dumps(start_msg))
# Poll for completion
result_data = await self._wait_for_job(client, job_id, headers)
async def send_audio():
    """Stream the audio file over the socket, then signal end of stream."""
    # Hold back until the receiver has seen RecognitionStarted.
    await recognition_started.wait()

    sent = 0
    async for audio_chunk in streamer.stream_file(audio_path):
        logger.debug(f"[speechmatics] Sent audio chunk {sent}")
        await ws.send(audio_chunk)
        sent += 1

    # Tell the server no more audio follows; last_seq_no carries the number
    # of chunks that were sent.
    logger.debug(f"[speechmatics] Sending EndOfStream after {sent} chunks")
    end_of_stream = {"message": "EndOfStream", "last_seq_no": sent}
    await ws.send(json.dumps(end_of_stream))
job_url = f"{self.api_url}/{job_id}"
transcript_url = f"{job_url}/transcript?format=json-v2"
async def receive_messages():
    """Consume server messages until the transcript ends or the socket closes.

    ``RecognitionStarted`` releases the audio sender, and the ``results`` of
    every ``AddTranscript`` message are appended to ``all_results``.
    ``transcription_complete`` is set whenever this coroutine exits, so a
    waiter is never left blocked on a receiver that has already stopped.

    Raises:
        RuntimeError: If the server sends an ``Error`` message.
    """
    try:
        async for message in ws:
            # Only text frames are decoded as JSON; anything else is skipped.
            if not isinstance(message, str):
                continue

            data = json.loads(message)
            msg_type = data.get("message")
            logger.debug(f"[speechmatics] Received {msg_type}: {data}")

            if msg_type == "RecognitionStarted":
                logger.info("[speechmatics] Connected")
                recognition_started.set()
            elif msg_type == "AddTranscript":
                # Final transcript segment; the list is mutated in place, so
                # no ``nonlocal`` declaration is required.
                all_results.extend(data.get("results", []))
            elif msg_type == "EndOfTranscript":
                return
            elif msg_type == "Error":
                raise RuntimeError(f"Speechmatics error: {data}")
            elif msg_type == "Warning":
                logger.warning(f"[speechmatics] Warning: {data.get('reason')}")
    finally:
        # Covers EndOfTranscript, server errors, a closed socket, and
        # cancellation alike.
        transcription_complete.set()
# Run send and receive concurrently
send_task = asyncio.create_task(send_audio())
receive_task = asyncio.create_task(receive_messages())
# Wait for completion
await send_task
try:
await asyncio.wait_for(transcription_complete.wait(), timeout=30.0)
except asyncio.TimeoutError:
pass
receive_task.cancel()
try:
await receive_task
except asyncio.CancelledError:
pass
return self._parse_results(all_results, params)
def _parse_results(
self,
results: list[dict[str, Any]],
params: dict[str, Any],
) -> TranscriptionResult:
"""Parse Speechmatics API response."""
results = data.get("results", [])
"""Parse Speechmatics results."""
words = []
speakers_set: set[str] = set()
transcript_parts = []
duration = 0.0
for item in results:
item_type = item.get("type")
@ -176,27 +216,25 @@ class SpeechmaticsProvider(STTProvider):
if speaker:
speakers_set.add(speaker)
end_time = item.get("end_time", 0.0)
duration = max(duration, end_time)
if item_type == "word":
words.append(
Word(
word=content,
start=item.get("start_time", 0.0),
end=item.get("end_time", 0.0),
end=end_time,
confidence=alt.get("confidence", 0.0),
speaker=speaker,
speaker_confidence=None, # Not provided by Speechmatics
speaker_confidence=None,
)
)
transcript_parts.append(content)
elif item_type == "punctuation":
# Append punctuation to last word in transcript
if transcript_parts:
transcript_parts[-1] += content
# Get metadata
metadata = data.get("metadata", {})
duration = metadata.get("duration", 0.0)
transcript = " ".join(transcript_parts)
return TranscriptionResult(
@ -205,6 +243,6 @@ class SpeechmaticsProvider(STTProvider):
words=words,
speakers=sorted(speakers_set),
duration=duration,
raw_response=data,
raw_response={"results": results},
params=params,
)