Merge branch 'feat/add-stt-evals' into update-pipecat-0.99

This commit is contained in:
Abhishek Kumar 2026-01-19 17:24:14 +05:30
commit fcf718d4e9
12 changed files with 1631 additions and 0 deletions

135
evals/stt/README.md Normal file
View file

@ -0,0 +1,135 @@
# STT Evaluation Benchmark
Benchmark for comparing Speech-to-Text providers using **WebSocket streaming** with focus on:
- **Speaker diarization** - identifying who said what
- **Keyterm boosting** - improving recognition of specific terms (Deepgram)
## Providers
| Provider | Diarization | Keyterm Boost | Streaming |
|----------|-------------|---------------|-----------|
| Deepgram | Yes | Yes | WebSocket (v1/v2) |
| Speechmatics | Yes | Additional vocab | WebSocket RT |
## Setup
```bash
# Install dependencies
pip install websockets loguru numpy onnxruntime transformers
# Set API keys
export DEEPGRAM_API_KEY="your-key"
export SPEECHMATICS_API_KEY="your-key"
```
**Note:** Requires `ffmpeg` installed for audio conversion to PCM16.
## Usage
Run from the project root directory:
```bash
# Test both providers with diarization
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
# Test only Deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
# Test with keyterm boosting (Deepgram)
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
# Use different sample rate (default: 8000 Hz)
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --sample-rate 16000
# Show word-level timings
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words
# Save results to JSON
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save
```
## CLI Options
| Option | Description |
|--------|-------------|
| `audio_file` | Path to audio file (relative to evals/stt/ or absolute) |
| `--providers` | Providers to test: `deepgram`, `deepgram-flux`, `speechmatics`, `local-smart-turn` (default: `deepgram` and `speechmatics`) |
| `--diarize` | Enable speaker diarization |
| `--keyterms` | Keywords to boost (Deepgram) / additional vocab (Speechmatics) |
| `--language` | Language code (default: en) |
| `--sample-rate` | Audio sample rate for streaming (default: 8000) |
| `--show-words` | Show individual word timings |
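| `--save` | Save results to JSON in the output directory |
| `--output-dir` | Directory for saved results, relative to `evals/stt/` unless absolute (default: `results`) |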
## Directory Structure
```
evals/stt/
├── audio/ # Audio test files
│ └── multi_speaker.m4a
├── results/ # Saved benchmark results (JSON)
├── providers/ # STT provider implementations
│ ├── base.py # Base classes
│ ├── deepgram_provider.py # WebSocket streaming
│ └── speechmatics_provider.py # WebSocket streaming
├── audio_streamer.py # PCM16 audio file streamer
├── benchmark.py # Main runner script
└── README.md
```
## How It Works
1. **Audio Conversion**: The `AudioStreamer` converts any audio file to raw PCM16 using ffmpeg
2. **WebSocket Connection**: Providers connect to their respective WebSocket APIs
3. **Streaming**: Audio is sent in chunks (configurable sample rate, default 8kHz)
4. **Result Collection**: Transcripts and speaker info are collected from WebSocket responses
5. **Comparison**: Results are parsed into a common format for comparison
## Output Example
```
Audio file: /path/to/audio/multi_speaker.m4a
Providers: ['deepgram', 'speechmatics']
Diarization: True
Sample rate: 8000 Hz
============================================================
Provider: DEEPGRAM
============================================================
Duration: 45.32s
Speakers detected: 2 - ['0', '1']
Transcript:
Hello, welcome to the demo...
--- Speaker Segments ---
[0.0s] Speaker 0: Hello, welcome to the demo.
[2.5s] Speaker 1: Thanks for having me.
...
============================================================
COMPARISON SUMMARY
============================================================
Provider Duration Speakers Words
---------------------------------------------
deepgram 45.32 2 312
speechmatics 45.32 2 308
```
## Adding New Providers
1. Create a new file in `providers/` (e.g., `whisper_provider.py`)
2. Implement the `STTProvider` abstract class with WebSocket streaming
3. Use `AudioStreamer` for PCM16 conversion
4. Add to `providers/__init__.py`
5. Add to `benchmark.py` provider choices
## API Documentation
- Deepgram Streaming: https://developers.deepgram.com/docs/live-streaming-audio
- Deepgram Diarization: https://developers.deepgram.com/docs/diarization
- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm
- Speechmatics RT API: https://docs.speechmatics.com/rt-api-ref
- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization

1
evals/stt/__init__.py Normal file
View file

@ -0,0 +1 @@
# STT Evaluation Benchmark

Binary file not shown.

BIN
evals/stt/audio/vad.m4a Normal file

Binary file not shown.

127
evals/stt/audio_streamer.py Normal file
View file

@ -0,0 +1,127 @@
"""Audio file streamer - converts audio files to PCM16 streams."""
import asyncio
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncIterator
@dataclass
class AudioConfig:
    """Parameters describing the PCM stream produced for the providers."""

    # Samples per second of the generated PCM stream.
    sample_rate: int = 8000
    # Number of interleaved channels.
    channels: int = 1
    # Bytes per sample; 2 bytes corresponds to 16-bit audio.
    sample_width: int = 2
    # Amount of audio, in milliseconds, carried by one chunk.
    chunk_duration_ms: int = 80

    @property
    def chunk_size(self) -> int:
        """Return the size in bytes of one ``chunk_duration_ms`` chunk."""
        bytes_per_frame = self.channels * self.sample_width
        frames_per_chunk = int(self.sample_rate * self.chunk_duration_ms / 1000)
        return frames_per_chunk * bytes_per_frame
class AudioStreamer:
    """Stream audio files as raw PCM16 chunks.

    The input file is decoded with ffmpeg into signed 16-bit little-endian
    PCM at the configured sample rate and channel count, then handed out in
    fixed-duration chunks, optionally paced to simulate live audio.
    """

    def __init__(self, config: AudioConfig | None = None):
        """Create a streamer.

        Args:
            config: Stream parameters; a default ``AudioConfig`` is used
                when omitted.
        """
        self.config = config or AudioConfig()

    def convert_to_pcm16(self, audio_path: Path) -> bytes:
        """Convert audio file to raw PCM16 bytes using ffmpeg.

        This call blocks until ffmpeg has exited.

        Args:
            audio_path: Path to input audio file

        Returns:
            Raw PCM16 audio bytes

        Raises:
            FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
            subprocess.CalledProcessError: If ffmpeg exits with a non-zero
                status.
        """
        cmd = [
            "ffmpeg",
            "-i",
            str(audio_path),
            "-f",
            "s16le",  # signed 16-bit little-endian
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(self.config.sample_rate),
            "-ac",
            str(self.config.channels),
            "-",  # output to stdout
        ]
        result = subprocess.run(
            cmd,
            capture_output=True,
            check=True,
        )
        return result.stdout

    async def stream_file(
        self,
        audio_path: Path,
        realtime: bool = True,
    ) -> AsyncIterator[bytes]:
        """Stream audio file as PCM16 chunks.

        Args:
            audio_path: Path to audio file
            realtime: If True, sleep one chunk duration after each chunk to
                simulate real-time streaming

        Yields:
            PCM16 audio chunks of ``config.chunk_size`` bytes (the last one
            may be shorter)

        Raises:
            ValueError: If the configured chunk size is not positive.
        """
        chunk_size = self.config.chunk_size
        if chunk_size <= 0:
            raise ValueError(
                f"chunk size must be positive, got {chunk_size}; "
                "check sample_rate and chunk_duration_ms"
            )

        # ffmpeg decoding is a blocking subprocess call. Run it in a worker
        # thread so tasks sharing the event loop (for example a WebSocket
        # receiver) keep running while the file is converted.
        pcm_data = await asyncio.to_thread(self.convert_to_pcm16, audio_path)

        delay = self.config.chunk_duration_ms / 1000.0 if realtime else 0

        for offset in range(0, len(pcm_data), chunk_size):
            yield pcm_data[offset : offset + chunk_size]
            if realtime and delay > 0:
                await asyncio.sleep(delay)

    async def stream_file_fast(self, audio_path: Path) -> AsyncIterator[bytes]:
        """Stream audio file as fast as possible (no real-time delay).

        Args:
            audio_path: Path to audio file

        Yields:
            PCM16 audio chunks
        """
        async for chunk in self.stream_file(audio_path, realtime=False):
            yield chunk

    def get_duration(self, audio_path: Path) -> float:
        """Get audio file duration in seconds using ffprobe.

        Args:
            audio_path: Path to audio file

        Returns:
            Duration in seconds

        Raises:
            FileNotFoundError: If the ``ffprobe`` executable cannot be found.
            subprocess.CalledProcessError: If ffprobe exits with a non-zero
                status.
            ValueError: If ffprobe's output is not a number.
        """
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            str(audio_path),
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        return float(result.stdout.strip())

247
evals/stt/benchmark.py Normal file
View file

@ -0,0 +1,247 @@
#!/usr/bin/env python3
"""STT Benchmark Runner.
Compare speech-to-text transcription across providers with focus on:
- Speaker diarization accuracy
- Keyword/keyterm recognition
- Transcription quality
Usage:
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
"""
import argparse
import asyncio
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any
from evals.stt.providers import (
DeepgramProvider,
DeepgramFluxProvider,
SpeechmaticsProvider,
LocalSmartTurnProvider,
STTProvider,
TranscriptionResult,
)
def get_provider(name: str) -> STTProvider:
    """Instantiate the provider registered under ``name``.

    Args:
        name: Provider identifier as accepted on the command line.

    Returns:
        A new provider instance built with its default constructor.

    Raises:
        ValueError: If ``name`` is not a known provider.
    """
    registry = {
        "deepgram": DeepgramProvider,
        "deepgram-flux": DeepgramFluxProvider,
        "speechmatics": SpeechmaticsProvider,
        "local-smart-turn": LocalSmartTurnProvider,
    }
    provider_cls = registry.get(name)
    if provider_cls is None:
        raise ValueError(f"Unknown provider: {name}. Available: {list(registry.keys())}")
    return provider_cls()
async def run_transcription(
    provider: STTProvider,
    audio_path: Path,
    diarize: bool = False,
    keyterms: list[str] | None = None,
    **kwargs: Any,
) -> TranscriptionResult:
    """Print a banner for ``provider`` and transcribe ``audio_path`` with it.

    Args:
        provider: Provider to run.
        audio_path: Audio file handed to the provider.
        diarize: Forwarded to ``provider.transcribe``.
        keyterms: Forwarded to ``provider.transcribe``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever ``provider.transcribe`` returns.

    Raises:
        Exception: Any provider failure is reported on stdout and re-raised.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Provider: {provider.name.upper()}")
    print(f"{banner}")

    try:
        return await provider.transcribe(
            audio_path,
            diarize=diarize,
            keyterms=keyterms,
            **kwargs,
        )
    except Exception as e:
        # Report which provider failed, then let the caller decide what to do.
        print(f"Error with {provider.name}: {e}")
        raise
def print_result(result: TranscriptionResult, show_words: bool = False) -> None:
    """Write a human-readable report for one provider result to stdout.

    Args:
        result: Result to report.
        show_words: When True, also list the first 50 words with their start
            time, speaker (if any) and confidence.
    """
    speaker_ids = result.speakers
    print(f"\nDuration: {result.duration:.2f}s")
    print(f"Speakers detected: {len(speaker_ids)} - {speaker_ids}")
    print(f"\nTranscript:\n{result.transcript}")

    # Per-speaker segments are only shown when speakers were detected.
    if speaker_ids:
        print("\n--- Speaker Segments ---")
        for seg in result.get_speaker_segments():
            label = seg["speaker"] or "?"
            body = seg["text"]
            began = seg["start"]
            print(f"[{began:.1f}s] Speaker {label}: {body}")

    if not show_words:
        return

    preview_limit = 50
    print("\n--- Words ---")
    for entry in result.words[:preview_limit]:
        if entry.speaker:
            suffix = f" (S{entry.speaker})"
        else:
            suffix = ""
        print(f" {entry.start:.2f}s: {entry.word}{suffix} [{entry.confidence:.2f}]")
    hidden = len(result.words) - preview_limit
    if hidden > 0:
        print(f" ... and {hidden} more words")
def save_results(
    results: list[TranscriptionResult],
    output_dir: Path,
    audio_name: str,
) -> Path:
    """Write all results to a timestamped JSON file.

    Args:
        results: Results to serialise via their ``to_dict`` method.
        output_dir: Directory for the file; created if it does not exist.
        audio_name: Name recorded in the file and used as file name prefix.

    Returns:
        Path of the JSON file that was written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Second-resolution local timestamp, used both in the name and payload.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    destination = output_dir / f"{audio_name}_{stamp}.json"

    payload = {
        "timestamp": stamp,
        "audio_file": audio_name,
        "results": [entry.to_dict() for entry in results],
    }
    with open(destination, "w") as handle:
        json.dump(payload, handle, indent=2)

    print(f"\nResults saved to: {destination}")
    return destination
def compare_results(results: list[TranscriptionResult]) -> None:
    """Print a side-by-side summary table for two or more results.

    Nothing is printed when fewer than two results are given. A note is added
    when the providers disagree on the number of speakers.
    """
    if len(results) < 2:
        return

    rule = "=" * 60
    print(f"\n{rule}")
    print("COMPARISON SUMMARY")
    print(f"{rule}")

    print(f"\n{'Provider':<15} {'Duration':<10} {'Speakers':<10} {'Words':<10}")
    print("-" * 45)

    # Collect speaker counts per provider while printing the table rows.
    counts_by_provider = {}
    for entry in results:
        speaker_total = len(entry.speakers)
        word_total = len(entry.words)
        print(f"{entry.provider:<15} {entry.duration:<10.2f} {speaker_total:<10} {word_total:<10}")
        counts_by_provider[entry.provider] = speaker_total

    distinct_counts = set(counts_by_provider.values())
    if len(distinct_counts) > 1:
        print(f"\nNote: Providers detected different speaker counts: {counts_by_provider}")
async def main() -> int:
    """Run the benchmark command line interface.

    Parses the arguments, runs each requested provider in turn on the audio
    file, prints the individual results and a comparison, and optionally
    saves everything to JSON. A failing provider is reported and skipped so
    the remaining providers still run.

    Returns:
        Process exit code: 0 when at least one provider produced a result,
        1 when the audio file does not exist or every provider failed.
    """
    default_providers = ["deepgram", "speechmatics"]

    parser = argparse.ArgumentParser(
        description="STT Benchmark - Compare transcription providers",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --keyterms "Dograh" "API"
""",
    )
    parser.add_argument(
        "audio_file",
        type=str,
        help="Path to audio file (relative to evals/stt/ or absolute)",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=default_providers,
        choices=["deepgram", "deepgram-flux", "speechmatics", "local-smart-turn"],
        # The default is a subset of the choices, so spell it out.
        help=f"Providers to test (default: {' '.join(default_providers)})",
    )
    parser.add_argument(
        "--diarize",
        action="store_true",
        help="Enable speaker diarization",
    )
    parser.add_argument(
        "--keyterms",
        nargs="+",
        help="Keywords to boost (Deepgram only)",
    )
    parser.add_argument(
        "--language",
        default="en",
        help="Language code (default: en)",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=8000,
        help="Audio sample rate for streaming (default: 8000)",
    )
    parser.add_argument(
        "--show-words",
        action="store_true",
        help="Show individual word timings",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save results to JSON file",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="results",
        help="Output directory for results (default: results)",
    )

    args = parser.parse_args()

    # Relative audio paths are resolved against this script's directory,
    # not the current working directory.
    script_dir = Path(__file__).parent
    audio_path = Path(args.audio_file)
    if not audio_path.is_absolute():
        audio_path = script_dir / audio_path

    if not audio_path.exists():
        print(f"Error: Audio file not found: {audio_path}")
        return 1

    print(f"Audio file: {audio_path}")
    print(f"Providers: {args.providers}")
    print(f"Diarization: {args.diarize}")
    print(f"Sample rate: {args.sample_rate} Hz")
    if args.keyterms:
        print(f"Keyterms: {args.keyterms}")

    results: list[TranscriptionResult] = []

    for provider_name in args.providers:
        try:
            provider = get_provider(provider_name)
            result = await run_transcription(
                provider,
                audio_path,
                diarize=args.diarize,
                keyterms=args.keyterms,
                language=args.language,
                sample_rate=args.sample_rate,
            )
            print_result(result, show_words=args.show_words)
            results.append(result)
        except Exception as e:
            # One broken provider must not abort the whole benchmark run.
            print(f"\nFailed to run {provider_name}: {e}")
            continue

    if not results:
        # Every provider failed: signal the failure through the exit code
        # instead of reporting success with nothing to show.
        return 1

    if len(results) > 1:
        compare_results(results)

    if args.save:
        output_dir = script_dir / args.output_dir
        save_results(results, output_dir, audio_path.stem)

    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))

View file

@ -0,0 +1,15 @@
from .base import STTProvider, TranscriptionResult, Word
from .deepgram_provider import DeepgramProvider
from .deepgram_flux_provider import DeepgramFluxProvider
from .speechmatics_provider import SpeechmaticsProvider
from .local_smart_turn_provider import LocalSmartTurnProvider
__all__ = [
"STTProvider",
"TranscriptionResult",
"Word",
"DeepgramProvider",
"DeepgramFluxProvider",
"SpeechmaticsProvider",
"LocalSmartTurnProvider",
]

123
evals/stt/providers/base.py Normal file
View file

@ -0,0 +1,123 @@
"""Base classes for STT providers."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class Word:
"""Represents a transcribed word with metadata."""
word: str
start: float
end: float
confidence: float
speaker: str | None = None
speaker_confidence: float | None = None
def to_dict(self) -> dict[str, Any]:
return {
"word": self.word,
"start": self.start,
"end": self.end,
"confidence": self.confidence,
"speaker": self.speaker,
"speaker_confidence": self.speaker_confidence,
}
@dataclass
class TranscriptionResult:
    """Result from STT transcription."""

    provider: str
    transcript: str
    words: list[Word]
    speakers: list[str]
    duration: float
    raw_response: dict[str, Any] = field(default_factory=dict)
    params: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly dictionary of the result.

        ``raw_response`` is not included in the output.
        """
        return {
            "provider": self.provider,
            "transcript": self.transcript,
            "words": [w.to_dict() for w in self.words],
            "speakers": self.speakers,
            "duration": self.duration,
            "params": self.params,
        }

    def get_speaker_segments(self) -> list[dict[str, Any]]:
        """Get transcript segmented by speaker.

        Consecutive words with the same ``speaker`` value (including
        ``None``) form one segment.

        Returns:
            One dictionary per segment, in word order, with the keys
            ``speaker``, ``text`` (the words joined by single spaces),
            ``start`` (start of the segment's first word) and ``end`` (end
            of the segment's last word). Empty when there are no words.
        """
        # Group consecutive words by speaker first. Taking start/end from the
        # group's own first/last word keeps the timings tied to the words
        # that actually belong to the segment.
        runs: list[list[Word]] = []
        for word in self.words:
            if runs and runs[-1][-1].speaker == word.speaker:
                runs[-1].append(word)
            else:
                runs.append([word])

        return [
            {
                "speaker": run[0].speaker,
                "text": " ".join(w.word for w in run),
                "start": run[0].start,
                "end": run[-1].end,
            }
            for run in runs
        ]
class STTProvider(ABC):
    """Interface implemented by every benchmarked provider."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the provider."""

    @abstractmethod
    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Process an audio file and return the provider's result.

        Args:
            audio_path: Location of the audio file to process
            diarize: Whether speaker diarization is requested
            keyterms: Keywords to boost, for providers that support it
            **kwargs: Parameters specific to the concrete provider

        Returns:
            TranscriptionResult holding the transcript and its metadata
        """

View file

@ -0,0 +1,225 @@
"""Deepgram Flux STT provider with WebSocket streaming.
Flux is Deepgram's conversational AI model with built-in turn detection.
It has a different API than Nova models - no language/punctuate/diarize params.
"""
import asyncio
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlencode
from loguru import logger
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
try:
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError("websockets required: pip install websockets")
class DeepgramFluxProvider(STTProvider):
    """Deepgram Flux Speech-to-Text provider with WebSocket streaming.

    Flux is optimized for conversational AI with built-in turn detection.

    Key differences from Nova:
    - Uses v2 API endpoint
    - Only supports English (flux-general-en)
    - No punctuate, diarize, or language params
    - Has turn detection events (StartOfTurn, EndOfTurn, EagerEndOfTurn)
    - Supports keyterm boosting

    API Docs: https://developers.deepgram.com/docs/
    """

    WS_URL = "wss://api.deepgram.com/v2/listen"

    def __init__(self, api_key: str | None = None):
        """Create the provider.

        Args:
            api_key: Deepgram API key; falls back to the ``DEEPGRAM_API_KEY``
                environment variable.

        Raises:
            ValueError: If no API key is available.
        """
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        return "deepgram-flux"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - Flux doesn't support diarization
        keyterms: list[str] | None = None,
        model: str = "flux-general-en",
        sample_rate: int = 16000,
        eot_threshold: float | None = None,
        eot_timeout_ms: int | None = None,
        eager_eot_threshold: float | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram Flux WebSocket streaming.

        Audio is only sent after the server's "Connected" message. Results
        are collected from "EndOfTurn" events. After the last chunk was sent
        the receiver is given up to 10 seconds to deliver remaining messages.

        Args:
            audio_path: Path to audio file
            diarize: IGNORED - Flux does not support diarization
            keyterms: List of keywords to boost recognition
            model: Flux model (default: flux-general-en)
            sample_rate: Audio sample rate (default: 16000 for Flux)
            eot_threshold: End-of-turn confidence threshold (0-1, default 0.7)
            eot_timeout_ms: Timeout in ms to force end of turn (default 5000)
            eager_eot_threshold: Threshold for eager end-of-turn events
            **kwargs: Accepted for interface compatibility; not sent to Flux

        Returns:
            TranscriptionResult with transcript (no speaker info - Flux doesn't support diarization)

        Raises:
            RuntimeError: If Flux reports an error, or closes the connection
                before confirming it with "Connected".
        """
        if diarize:
            logger.warning("Flux does not support diarization - ignoring diarize=True")

        # Build query params - Flux only supports specific params
        params: dict[str, Any] = {
            "model": model,
            "encoding": "linear16",
            "sample_rate": sample_rate,
        }

        # Flux-specific turn detection params
        if eot_threshold is not None:
            params["eot_threshold"] = eot_threshold
        if eot_timeout_ms is not None:
            params["eot_timeout_ms"] = eot_timeout_ms
        if eager_eot_threshold is not None:
            params["eager_eot_threshold"] = eager_eot_threshold

        # Encode every key and value, so unusual characters in a model name
        # or keyterm cannot corrupt the query string. Keyterms are sent as
        # repeated "keyterm" parameters.
        query: list[tuple[str, Any]] = list(params.items())
        if keyterms:
            query.extend(("keyterm", term) for term in keyterms)
        ws_url = f"{self.WS_URL}?{urlencode(query)}"
        logger.debug(f"Flux WebSocket URL: {ws_url}")

        # Setup audio streamer
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Collect results
        all_transcripts: list[dict[str, Any]] = []
        turn_texts: list[str] = []
        duration = 0.0
        connected = asyncio.Event()

        async with websocket_connect(
            ws_url,
            additional_headers={"Authorization": f"Token {self.api_key}"},
        ) as ws:

            async def send_audio() -> None:
                """Send audio chunks to Deepgram Flux."""
                await connected.wait()
                chunk_no = 0
                async for chunk in streamer.stream_file(audio_path):
                    logger.debug(f"[deepgram-flux] Sent audio chunk {chunk_no}")
                    await ws.send(chunk)
                    chunk_no += 1

            async def receive_messages() -> None:
                """Receive and collect Flux messages."""
                nonlocal duration
                async for message in ws:
                    if not isinstance(message, str):
                        continue
                    data = json.loads(message)
                    msg_type = data.get("type")
                    logger.debug(f"[deepgram-flux] Received {msg_type}: {data}")

                    if msg_type == "Connected":
                        logger.info("[deepgram-flux] Connected")
                        connected.set()
                    elif msg_type == "TurnInfo":
                        event = data.get("event")
                        if event == "EndOfTurn":
                            transcript = data.get("transcript", "")
                            words = data.get("words", [])
                            if transcript:
                                turn_texts.append(transcript)
                            if words:
                                all_transcripts.append(
                                    {
                                        "transcript": transcript,
                                        "words": words,
                                    }
                                )
                                # Get duration from last word
                                duration = max(duration, words[-1].get("end", 0))
                        elif event == "TurnResumed":
                            logger.debug("TurnResumed")
                    elif msg_type == "Error":
                        raise RuntimeError(f"Deepgram Flux error: {data}")

            # Run send and receive concurrently
            send_task = asyncio.create_task(send_audio())
            receive_task = asyncio.create_task(receive_messages())
            try:
                await asyncio.wait(
                    {send_task, receive_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if receive_task.done():
                    # Surface a provider error right away instead of
                    # streaming the rest of the file first.
                    receive_task.result()
                    if not connected.is_set():
                        # The sender would wait for "Connected" forever.
                        raise RuntimeError(
                            "Deepgram Flux closed the connection before sending 'Connected'"
                        )

                await send_task
                logger.debug("[deepgram-flux] Send task done")

                try:
                    await asyncio.wait_for(receive_task, timeout=10.0)
                except asyncio.TimeoutError:
                    pass
            finally:
                # Never leave a task running once this call returns or fails.
                for task in (send_task, receive_task):
                    if not task.done():
                        task.cancel()
                await asyncio.gather(send_task, receive_task, return_exceptions=True)

        return self._parse_results(
            all_transcripts, " ".join(turn_texts).strip(), duration, params, keyterms
        )

    def _parse_results(
        self,
        transcripts: list[dict[str, Any]],
        final_transcript: str,
        duration: float,
        params: dict[str, Any],
        keyterms: list[str] | None,
    ) -> TranscriptionResult:
        """Parse collected Flux results into TranscriptionResult.

        Args:
            transcripts: Collected turns, each with "transcript" and "words"
            final_transcript: Full transcript text
            duration: Duration to report
            params: Query parameters used for the request
            keyterms: Keyterms used for the request, stored with the params

        Returns:
            TranscriptionResult without speaker information
        """
        words = [
            Word(
                word=w.get("word", ""),
                start=w.get("start", 0.0),
                end=w.get("end", 0.0),
                confidence=w.get("confidence", 0.0),
                speaker=None,  # Flux doesn't support diarization
                speaker_confidence=None,
            )
            for turn in transcripts
            for w in turn.get("words", [])
        ]

        stored_params = dict(params)
        if keyterms:
            stored_params["keyterms"] = keyterms

        return TranscriptionResult(
            provider=self.name,
            transcript=final_transcript,
            words=words,
            speakers=[],  # Flux doesn't support diarization
            duration=duration,
            raw_response={"transcripts": transcripts},
            params=stored_params,
        )

View file

@ -0,0 +1,225 @@
"""Deepgram STT provider with WebSocket streaming."""
import asyncio
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlencode
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
from loguru import logger
try:
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError("websockets required: pip install websockets")
class DeepgramProvider(STTProvider):
    """Deepgram Nova Speech-to-Text provider with WebSocket streaming.

    API Docs: https://developers.deepgram.com/docs/

    Supports:
    - Speaker diarization via `diarize=true`
    - Keyterm boosting via `keyterm` parameter
    - Real-time streaming via WebSocket
    - Multiple languages
    - Punctuation

    For Flux models, use DeepgramFluxProvider instead.
    """

    WS_URL = "wss://api.deepgram.com/v1/listen"

    def __init__(self, api_key: str | None = None):
        """Create the provider.

        Args:
            api_key: Deepgram API key; falls back to the ``DEEPGRAM_API_KEY``
                environment variable.

        Raises:
            ValueError: If no API key is available.
        """
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        return "deepgram"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        model: str = "nova-3-general",
        language: str = "en",
        sample_rate: int = 8000,
        punctuate: bool = True,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram Nova WebSocket streaming.

        Interim results are requested, but only messages flagged
        ``is_final`` contribute words and transcript text, so revised
        hypotheses for the same audio are not counted twice. After the last
        chunk a "CloseStream" message is sent and the receiver is given up
        to 5 seconds to deliver remaining messages.

        Args:
            audio_path: Path to audio file
            diarize: Enable speaker diarization
            keyterms: List of keywords to boost recognition
            model: Deepgram Nova model (nova-3, nova-2, etc.)
            language: Language code
            sample_rate: Audio sample rate for streaming
            punctuate: Add punctuation
            **kwargs: Additional Deepgram parameters, added to the query string

        Returns:
            TranscriptionResult with transcript and speaker info

        Raises:
            Exception: If streaming fails before any final result was
                received. A failure after partial results is logged as a
                warning and the partial results are returned.
        """
        # Build query params
        params: dict[str, Any] = {
            "model": model,
            "language": language,
            "punctuate": str(punctuate).lower(),
            "encoding": "linear16",
            "sample_rate": sample_rate,
            "channels": 1,
            "interim_results": "true",
            "smart_format": "true",
            "profanity_filter": "true",
            "vad_events": "true",
        }

        if diarize:
            params["diarize"] = "true"

        # Encode every key and value so keyterms and caller-supplied extras
        # cannot corrupt the query string. Keyterms are sent as repeated
        # "keyterm" parameters, followed by the extra kwargs.
        query: list[tuple[str, Any]] = list(params.items())
        if keyterms:
            query.extend(("keyterm", term) for term in keyterms)
        query.extend(kwargs.items())
        ws_url = f"{self.WS_URL}?{urlencode(query)}"
        logger.debug(f"Deepgram WebSocket URL: {ws_url}")

        # Setup audio streamer
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Collect results
        all_words: list[dict[str, Any]] = []
        transcript_parts: list[str] = []
        duration = 0.0

        try:
            async with websocket_connect(
                ws_url,
                additional_headers={"Authorization": f"Token {self.api_key}"},
            ) as ws:

                async def send_audio() -> None:
                    """Send audio chunks to Deepgram."""
                    chunk_no = 0
                    async for chunk in streamer.stream_file(audio_path):
                        logger.debug(f"[deepgram] Sent audio chunk {chunk_no}")
                        await ws.send(chunk)
                        chunk_no += 1

                    # Send close message
                    logger.debug(f"[deepgram] Sending CloseStream after {chunk_no} chunks")
                    await ws.send(json.dumps({"type": "CloseStream"}))

                async def receive_transcripts() -> None:
                    """Receive and collect transcription results."""
                    nonlocal duration
                    async for message in ws:
                        if not isinstance(message, str):
                            continue
                        data = json.loads(message)
                        msg_type = data.get("type")
                        logger.debug(f"[deepgram] Received {msg_type}: {data}")

                        if msg_type == "Results":
                            # Nova-style response
                            alternatives = data.get("channel", {}).get("alternatives", [])
                            if alternatives:
                                alt = alternatives[0]
                                # Interim messages repeat audio that a later
                                # final message covers again; keep finals only.
                                if data.get("is_final"):
                                    all_words.extend(alt.get("words", []))
                                    text = alt.get("transcript", "")
                                    if text:
                                        transcript_parts.append(text)
                                duration = max(
                                    duration, data.get("duration", 0) + data.get("start", 0)
                                )
                        elif msg_type == "Metadata":
                            # Get duration from metadata
                            duration = data.get("duration", duration)
                        elif msg_type == "Error":
                            raise Exception(f"Deepgram error: {data}")

                # Run send and receive concurrently
                send_task = asyncio.create_task(send_audio())
                receive_task = asyncio.create_task(receive_transcripts())
                try:
                    await asyncio.wait(
                        {send_task, receive_task},
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    if receive_task.done():
                        # Surface a provider error right away instead of
                        # streaming the rest of the file first.
                        receive_task.result()

                    # Wait for send to complete, then wait a bit for final results
                    await send_task
                    try:
                        await asyncio.wait_for(receive_task, timeout=5.0)
                    except asyncio.TimeoutError:
                        pass  # Receiver still open after the grace period
                finally:
                    # Never leave a task running once this call returns or fails.
                    for task in (send_task, receive_task):
                        if not task.done():
                            task.cancel()
                    await asyncio.gather(send_task, receive_task, return_exceptions=True)
        except Exception as exc:
            if not all_words and not transcript_parts:
                # Nothing usable was received (bad key, refused connection,
                # provider error, ...): report the failure instead of
                # returning an empty result that looks like a success.
                raise
            logger.warning(
                f"[deepgram] Streaming ended with an error, returning partial results: {exc!r}"
            )

        return self._parse_results(
            all_words, " ".join(transcript_parts).strip(), duration, params, keyterms
        )

    def _parse_results(
        self,
        raw_words: list[dict[str, Any]],
        transcript: str,
        duration: float,
        params: dict[str, Any],
        keyterms: list[str] | None,
    ) -> TranscriptionResult:
        """Parse collected results into TranscriptionResult.

        Args:
            raw_words: Word dictionaries as received from Deepgram
            transcript: Full transcript text
            duration: Duration to report
            params: Query parameters used for the request
            keyterms: Keyterms used for the request, stored with the params

        Returns:
            TranscriptionResult whose ``speakers`` lists the sorted speaker
            labels found on the words
        """
        words = []
        speakers_set: set[str] = set()

        for w in raw_words:
            # Speaker labels are normalised to strings; words without a
            # "speaker" key carry no speaker.
            speaker = str(w["speaker"]) if "speaker" in w else None
            if speaker:
                speakers_set.add(speaker)

            words.append(
                Word(
                    word=w.get("word", ""),
                    start=w.get("start", 0.0),
                    end=w.get("end", 0.0),
                    confidence=w.get("confidence", 0.0),
                    speaker=speaker,
                    speaker_confidence=w.get("speaker_confidence"),
                )
            )

        stored_params = dict(params)
        if keyterms:
            stored_params["keyterms"] = keyterms

        return TranscriptionResult(
            provider=self.name,
            transcript=transcript,
            words=words,
            speakers=sorted(speakers_set),
            duration=duration,
            raw_response={"words": raw_words},
            params=stored_params,
        )

View file

@ -0,0 +1,285 @@
"""Local Smart Turn provider for benchmarking end-of-turn detection.
Uses the pipecat smart-turn-v3 ONNX model for local ML-based turn detection.
This is NOT an STT provider - it only detects when a speaker has finished talking.
"""
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from loguru import logger
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
try:
import onnxruntime as ort
from transformers import WhisperFeatureExtractor
except ImportError:
raise ImportError(
"onnxruntime and transformers required: pip install onnxruntime transformers"
)
@dataclass
class TurnEvent:
    """A single turn-detection observation."""

    # Time in audio when turn was detected
    timestamp: float
    # Model confidence
    probability: float
    # 1=complete, 0=incomplete
    prediction: int
    # Inference time, in milliseconds per the field name
    inference_time_ms: float
class LocalSmartTurnProvider(STTProvider):
"""Local Smart Turn provider for end-of-turn detection benchmarking.
Uses the smart-turn-v3 ONNX model to detect when speakers finish talking.
This is useful for comparing turn detection accuracy against cloud services
like Deepgram Flux's built-in turn detection.
NOTE: This provider does NOT produce transcripts - only turn detection events.
"""
# Smart turn model requires 16kHz audio
REQUIRED_SAMPLE_RATE = 16000
# Model analyzes 8 seconds of audio
WINDOW_SECONDS = 8
def __init__(
self,
model_path: str | None = None,
cpu_count: int = 1,
):
"""Initialize the local smart turn provider.
Args:
model_path: Path to ONNX model file. If None, uses bundled model.
cpu_count: Number of CPUs for inference (default: 1)
"""
self.model_path = model_path
self.cpu_count = cpu_count
self._session = None
self._feature_extractor = None
def _load_model(self):
"""Lazy load the ONNX model."""
if self._session is not None:
return
model_path = self.model_path
if not model_path:
# Try to load bundled model from pipecat
model_name = "smart-turn-v3.1-cpu.onnx"
package_path = "pipecat.audio.turn.smart_turn.data"
try:
import importlib_resources as impresources
model_path = str(impresources.files(package_path).joinpath(model_name))
except Exception:
from importlib import resources as impresources
try:
with impresources.path(package_path, model_name) as f:
model_path = str(f)
except Exception:
model_path = str(impresources.files(package_path).joinpath(model_name))
logger.info(f"[local-smart-turn] Loading model from {model_path}")
# Configure ONNX runtime
so = ort.SessionOptions()
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
so.inter_op_num_threads = 1
so.intra_op_num_threads = self.cpu_count
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
self._session = ort.InferenceSession(model_path, sess_options=so)
logger.info("[local-smart-turn] Model loaded")
@property
def name(self) -> str:
return "local-smart-turn"
def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
"""Predict end-of-turn using the ONNX model.
Args:
audio_array: Audio samples as float32 numpy array (16kHz)
Returns:
Dict with prediction (0/1) and probability
"""
# Truncate to last 8 seconds or pad to 8 seconds
max_samples = self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE
if len(audio_array) > max_samples:
audio_array = audio_array[-max_samples:]
elif len(audio_array) < max_samples:
padding = max_samples - len(audio_array)
audio_array = np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
# Process using Whisper's feature extractor
inputs = self._feature_extractor(
audio_array,
sampling_rate=self.REQUIRED_SAMPLE_RATE,
return_tensors="np",
padding="max_length",
max_length=self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE,
truncation=True,
do_normalize=True,
)
# Extract features for ONNX
input_features = inputs.input_features.squeeze(0).astype(np.float32)
input_features = np.expand_dims(input_features, axis=0)
# Run inference
start_time = time.perf_counter()
outputs = self._session.run(None, {"input_features": input_features})
inference_time = (time.perf_counter() - start_time) * 1000
# Extract probability (model returns sigmoid probabilities)
probability = outputs[0][0].item()
prediction = 1 if probability > 0.5 else 0
return {
"prediction": prediction,
"probability": probability,
"inference_time_ms": inference_time,
}
async def transcribe(
    self,
    audio_path: Path,
    diarize: bool = False,  # Ignored - not applicable
    keyterms: list[str] | None = None,  # Ignored - not applicable
    sample_rate: int = 16000,  # Must be 16kHz for smart turn
    analysis_interval_ms: int = 500,  # How often to check for turn completion
    **kwargs: Any,
) -> TranscriptionResult:
    """Analyze audio for turn detection events.

    NOTE: This does NOT produce transcripts. It detects when speakers
    finish talking using ML-based turn detection.

    The whole file is converted to PCM16 up front, then the model is run
    on a sliding window that ends at every ``analysis_interval_ms``
    boundary, including the boundary that coincides with the end of the
    audio.

    Args:
        audio_path: Path to audio file
        diarize: Ignored (not applicable for turn detection)
        keyterms: Ignored (not applicable for turn detection)
        sample_rate: Must be 16000 Hz for smart turn model; any other
            value is overridden with a warning
        analysis_interval_ms: How often to run turn detection (ms)
        **kwargs: Additional parameters (ignored)

    Returns:
        TranscriptionResult with turn detection events in raw_response.
        Each positive prediction is also exposed as a pseudo ``Word`` so
        word-oriented consumers can display it.

    Raises:
        ValueError: If ``analysis_interval_ms`` is shorter than one audio
            sample (zero or negative included).
    """
    if sample_rate != self.REQUIRED_SAMPLE_RATE:
        logger.warning(
            f"[local-smart-turn] Sample rate must be {self.REQUIRED_SAMPLE_RATE}Hz, "
            f"overriding {sample_rate}Hz"
        )
        sample_rate = self.REQUIRED_SAMPLE_RATE

    # Validate before the (expensive) model load: a zero step would make
    # range() raise an opaque error and a negative one would silently
    # produce no analyses at all.
    samples_per_interval = int(sample_rate * analysis_interval_ms / 1000)
    if samples_per_interval <= 0:
        raise ValueError(
            "analysis_interval_ms must span at least one audio sample, "
            f"got {analysis_interval_ms}"
        )

    # Load model if not already loaded
    self._load_model()

    # Setup audio streamer at 16kHz
    audio_config = AudioConfig(sample_rate=sample_rate)
    streamer = AudioStreamer(audio_config)

    # Get audio duration
    duration = streamer.get_duration(audio_path)
    logger.info(f"[local-smart-turn] Processing {audio_path} ({duration:.2f}s)")

    # Collect all audio first (smart turn needs to analyze segments)
    pcm_data = streamer.convert_to_pcm16(audio_path)

    # Convert int16 PCM to float32 in [-1.0, 1.0) for the model
    audio_int16 = np.frombuffer(pcm_data, dtype=np.int16)
    audio_float32 = audio_int16.astype(np.float32) / 32768.0

    # Analyze at intervals
    turn_events: list[TurnEvent] = []
    window_samples = self.WINDOW_SECONDS * sample_rate
    total_samples = len(audio_float32)
    # The stop is total_samples + 1 so that an analysis point falling
    # exactly on the end of the audio is not dropped.
    analysis_points = range(samples_per_interval, total_samples + 1, samples_per_interval)
    for chunk_no, end_sample in enumerate(analysis_points):
        # Get window of audio ending at current position
        start_sample = max(0, end_sample - window_samples)
        audio_window = audio_float32[start_sample:end_sample]
        current_time = end_sample / sample_rate
        logger.debug(f"[local-smart-turn] Analyzing chunk {chunk_no} at {current_time:.2f}s")

        result = self._predict_endpoint(audio_window)
        turn_events.append(
            TurnEvent(
                timestamp=current_time,
                probability=result["probability"],
                prediction=result["prediction"],
                inference_time_ms=result["inference_time_ms"],
            )
        )
        if result["prediction"] == 1:
            # Single quotes inside the f-string keep this valid on
            # Python versions before 3.12.
            logger.info(
                f"[local-smart-turn] Turn complete at {current_time:.2f}s "
                f"(prob={result['probability']:.3f})"
                f"(inf time ms={result['inference_time_ms']})"
            )

    # Convert positive turn events to word-like format for compatibility
    words = [
        Word(
            word=f"[END_OF_TURN prob={event.probability:.2f}]",
            start=event.timestamp - 0.1,
            end=event.timestamp,
            confidence=event.probability,
            speaker=None,
            speaker_confidence=None,
        )
        for event in turn_events
        if event.prediction == 1
    ]

    # One pseudo-word exists per completed turn
    completed_turns = len(words)

    params = {
        "sample_rate": sample_rate,
        "analysis_interval_ms": analysis_interval_ms,
        "window_seconds": self.WINDOW_SECONDS,
    }

    return TranscriptionResult(
        provider=self.name,
        transcript=f"[Turn detection only - {completed_turns} turns detected]",
        words=words,
        speakers=[],  # Not applicable
        duration=duration,
        raw_response={
            "turn_events": [
                {
                    "timestamp": e.timestamp,
                    "probability": e.probability,
                    "prediction": e.prediction,
                    "inference_time_ms": e.inference_time_ms,
                }
                for e in turn_events
            ],
            "completed_turns": completed_turns,
            "total_analyses": len(turn_events),
            "avg_inference_time_ms": (
                sum(e.inference_time_ms for e in turn_events) / len(turn_events)
                if turn_events
                else 0
            ),
        },
        params=params,
    )

View file

@ -0,0 +1,248 @@
"""Speechmatics STT provider with WebSocket streaming."""
import asyncio
import json
import os
from pathlib import Path
from typing import Any
from loguru import logger
from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
try:
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError("websockets required: pip install websockets")
class SpeechmaticsProvider(STTProvider):
    """Speechmatics Speech-to-Text provider with WebSocket streaming.

    API Docs: https://docs.speechmatics.com/

    Supports:
    - Speaker diarization via `diarization: "speaker"` config
    - Speaker sensitivity tuning
    - Real-time streaming via WebSocket
    """

    # Seconds to wait for EndOfTranscript once all audio has been sent.
    END_OF_TRANSCRIPT_TIMEOUT_S: float = 30.0

    def __init__(self, api_key: str | None = None, region: str = "eu2"):
        """Store credentials and build the region-specific endpoint URL.

        Args:
            api_key: Speechmatics API key; falls back to the
                ``SPEECHMATICS_API_KEY`` environment variable.
            region: Subdomain prefix of the real-time endpoint.

        Raises:
            ValueError: If no API key is passed and none is set in the
                environment.
        """
        self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key."
            )
        # Set region-specific endpoint
        self.ws_url = f"wss://{region}.rt.speechmatics.com/v2"

    @property
    def name(self) -> str:
        """Return the identifier reported as the ``provider`` of results."""
        return "speechmatics"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        language: str = "en",
        operating_point: str = "enhanced",
        sample_rate: int = 8000,
        speaker_sensitivity: float | None = None,
        max_speakers: int | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Speechmatics WebSocket streaming.

        Audio is only sent after the server acknowledges the session with
        ``RecognitionStarted``. Sending and receiving run as two concurrent
        tasks; if the receiver stops first (server error or closed socket)
        the sender is cancelled instead of being awaited forever.

        Args:
            audio_path: Path to audio file
            diarize: Enable speaker diarization
            keyterms: Additional vocabulary (limited support)
            language: Language code
            operating_point: "standard" or "enhanced"
            sample_rate: Audio sample rate for streaming
            speaker_sensitivity: 0.0-1.0, higher = more speakers detected
                (only applied when ``diarize`` is set)
            max_speakers: Maximum number of speakers to detect
                (only applied when ``diarize`` is set)
            **kwargs: Additional config parameters (currently unused)

        Returns:
            TranscriptionResult with transcript and speaker info. If the
            server does not send ``EndOfTranscript`` in time, whatever was
            received so far is returned and a warning is logged.

        Raises:
            RuntimeError: If the server sends an ``Error`` message.
        """
        # Build transcription config for StartRecognition message
        transcription_config: dict[str, Any] = {
            "language": language,
            "operating_point": operating_point,
            "enable_partials": False,
        }
        if diarize:
            transcription_config["diarization"] = "speaker"
            diarization_config: dict[str, Any] = {}
            if speaker_sensitivity is not None:
                diarization_config["speaker_sensitivity"] = speaker_sensitivity
            if max_speakers is not None:
                diarization_config["max_speakers"] = max_speakers
            # Only attach the sub-config when at least one knob was set
            if diarization_config:
                transcription_config["speaker_diarization_config"] = diarization_config

        # Add additional vocabulary if provided
        if keyterms:
            transcription_config["additional_vocab"] = [{"content": term} for term in keyterms]

        # Audio format config
        audio_format = {
            "type": "raw",
            "encoding": "pcm_s16le",
            "sample_rate": sample_rate,
        }

        # Store params for result
        params = {
            "diarize": diarize,
            "language": language,
            "operating_point": operating_point,
            "sample_rate": sample_rate,
            "speaker_sensitivity": speaker_sensitivity,
            "max_speakers": max_speakers,
        }

        # Setup audio streamer
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Collect results
        all_results: list[dict[str, Any]] = []
        recognition_started = asyncio.Event()
        transcription_complete = asyncio.Event()

        async with websocket_connect(
            self.ws_url,
            additional_headers={"Authorization": f"Bearer {self.api_key}"},
        ) as ws:
            # Send StartRecognition message
            start_msg = {
                "message": "StartRecognition",
                "transcription_config": transcription_config,
                "audio_format": audio_format,
            }
            await ws.send(json.dumps(start_msg))

            async def send_audio() -> None:
                """Send audio chunks after recognition starts."""
                await recognition_started.wait()
                chunk_no = 0
                async for chunk in streamer.stream_file(audio_path):
                    await ws.send(chunk)
                    logger.debug(f"[speechmatics] Sent audio chunk {chunk_no}")
                    chunk_no += 1
                # Signal end of audio; chunk_no now equals the number of
                # chunks that were sent.
                logger.debug(f"[speechmatics] Sending EndOfStream after {chunk_no} chunks")
                await ws.send(json.dumps({"message": "EndOfStream", "last_seq_no": chunk_no}))

            async def receive_messages() -> None:
                """Receive and process messages until EndOfTranscript."""
                async for message in ws:
                    if not isinstance(message, str):
                        continue
                    data = json.loads(message)
                    msg_type = data.get("message")
                    logger.debug(f"[speechmatics] Received {msg_type}: {data}")
                    if msg_type == "RecognitionStarted":
                        logger.info("[speechmatics] Connected")
                        recognition_started.set()
                    elif msg_type == "AddTranscript":
                        # Final transcript segment
                        all_results.extend(data.get("results", []))
                    elif msg_type == "EndOfTranscript":
                        transcription_complete.set()
                        return
                    elif msg_type == "Error":
                        raise RuntimeError(f"Speechmatics error: {data}")
                    elif msg_type == "Warning":
                        logger.warning(f"[speechmatics] Warning: {data.get('reason')}")

            # Run send and receive concurrently
            send_task = asyncio.create_task(send_audio())
            receive_task = asyncio.create_task(receive_messages())
            try:
                # Wake up on whichever side stops first. Awaiting only the
                # sender would hang forever when the server answers with
                # Error (or closes) before RecognitionStarted, because the
                # sender is blocked on that event.
                await asyncio.wait(
                    {send_task, receive_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if send_task.done():
                    # Surface send failures, then give the server time to
                    # flush the remaining transcripts.
                    send_task.result()
                    _, pending = await asyncio.wait(
                        {receive_task},
                        timeout=self.END_OF_TRANSCRIPT_TIMEOUT_S,
                    )
                    if pending:
                        logger.warning(
                            "[speechmatics] Timed out waiting for EndOfTranscript; "
                            "returning partial results"
                        )
                if receive_task.done():
                    # Re-raises a server Error captured by the receiver
                    receive_task.result()
                    if not transcription_complete.is_set():
                        logger.warning(
                            "[speechmatics] Connection closed before EndOfTranscript; "
                            "results may be incomplete"
                        )
            finally:
                # Never leave a task running or its exception unretrieved
                for task in (send_task, receive_task):
                    task.cancel()
                await asyncio.gather(send_task, receive_task, return_exceptions=True)

        return self._parse_results(all_results, params)

    def _parse_results(
        self,
        results: list[dict[str, Any]],
        params: dict[str, Any],
    ) -> TranscriptionResult:
        """Parse Speechmatics results.

        Only the first alternative of each item is used. ``word`` items
        become ``Word`` entries; ``punctuation`` items are appended to the
        preceding word in the transcript text. The reported duration is the
        largest ``end_time`` seen.

        Args:
            results: Items collected from the ``results`` lists of the
                ``AddTranscript`` messages.
            params: Request parameters to echo back on the result.

        Returns:
            TranscriptionResult with words, sorted speaker labels and the
            space-joined transcript.
        """
        words = []
        speakers_set: set[str] = set()
        transcript_parts = []
        duration = 0.0

        for item in results:
            alternatives = item.get("alternatives", [])
            if not alternatives:
                continue
            alt = alternatives[0]
            content = alt.get("content", "")
            speaker = alt.get("speaker")
            if speaker:
                speakers_set.add(speaker)

            end_time = item.get("end_time", 0.0)
            duration = max(duration, end_time)

            item_type = item.get("type")
            if item_type == "word":
                words.append(
                    Word(
                        word=content,
                        start=item.get("start_time", 0.0),
                        end=end_time,
                        confidence=alt.get("confidence", 0.0),
                        speaker=speaker,
                        speaker_confidence=None,
                    )
                )
                transcript_parts.append(content)
            elif item_type == "punctuation" and transcript_parts:
                # Glue punctuation onto the previous word (no space before)
                transcript_parts[-1] += content

        transcript = " ".join(transcript_parts)
        return TranscriptionResult(
            provider=self.name,
            transcript=transcript,
            words=words,
            speakers=sorted(speakers_set),
            duration=duration,
            raw_response={"results": results},
            params=params,
        )