diff --git a/evals/stt/README.md b/evals/stt/README.md new file mode 100644 index 0000000..6e2802b --- /dev/null +++ b/evals/stt/README.md @@ -0,0 +1,135 @@ +# STT Evaluation Benchmark + +Benchmark for comparing Speech-to-Text providers using **WebSocket streaming** with focus on: +- **Speaker diarization** - identifying who said what +- **Keyterm boosting** - improving recognition of specific terms (Deepgram) + +## Providers + +| Provider | Diarization | Keyterm Boost | Streaming | +|----------|-------------|---------------|-----------| +| Deepgram | Yes | Yes | WebSocket (v1/v2) | +| Speechmatics | Yes | Additional vocab | WebSocket RT | + +## Setup + +```bash +# Install dependencies +pip install websockets + +# Set API keys +export DEEPGRAM_API_KEY="your-key" +export SPEECHMATICS_API_KEY="your-key" +``` + +**Note:** Requires `ffmpeg` installed for audio conversion to PCM16. + +## Usage + +Run from the project root directory: + +```bash +# Test both providers with diarization +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize + +# Test only Deepgram +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram + +# Test with keyterm boosting (Deepgram) +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat" + +# Use different sample rate (default: 8000 Hz) +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --sample-rate 16000 + +# Show word-level timings +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words + +# Save results to JSON +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save +``` + +## CLI Options + +| Option | Description | +|--------|-------------| +| `audio_file` | Path to audio file (relative to evals/stt/ or absolute) | +| `--providers` | Providers to test: `deepgram`, `speechmatics` (default: both) | +| `--diarize` | Enable speaker diarization | +| `--keyterms` | Keywords to boost (Deepgram) / additional vocab (Speechmatics) | +| `--language` | Language code (default: en) | +| `--sample-rate` | Audio sample rate for streaming (default: 8000) | +| `--show-words` | Show individual word timings | +| `--save` | Save results to JSON in `results/` | + +## Directory Structure + +``` +evals/stt/ +├── audio/ # Audio test files +│ └── multi_speaker.m4a +├── results/ # Saved benchmark results (JSON) +├── providers/ # STT provider implementations +│ ├── base.py # Base classes +│ ├── deepgram_provider.py # WebSocket streaming +│ └── speechmatics_provider.py # WebSocket streaming +├── audio_streamer.py # PCM16 audio file streamer +├── benchmark.py # Main runner script +└── README.md +``` + +## How It Works + +1. **Audio Conversion**: The `AudioStreamer` converts any audio file to raw PCM16 using ffmpeg +2. **WebSocket Connection**: Providers connect to their respective WebSocket APIs +3. **Streaming**: Audio is sent in chunks (configurable sample rate, default 8kHz) +4. **Result Collection**: Transcripts and speaker info are collected from WebSocket responses +5. **Comparison**: Results are parsed into a common format for comparison + +## Output Example + +``` +Audio file: /path/to/audio/multi_speaker.m4a +Providers: ['deepgram', 'speechmatics'] +Diarization: True +Sample rate: 8000 Hz + +============================================================ +Provider: DEEPGRAM +============================================================ + +Duration: 45.32s +Speakers detected: 2 - ['0', '1'] + +Transcript: +Hello, welcome to the demo... + +--- Speaker Segments --- +[0.0s] Speaker 0: Hello, welcome to the demo. +[2.5s] Speaker 1: Thanks for having me. +... + +============================================================ +COMPARISON SUMMARY +============================================================ + +Provider Duration Speakers Words +--------------------------------------------- +deepgram 45.32 2 312 +speechmatics 45.32 2 308 +``` + +## Adding New Providers + +1. Create a new file in `providers/` (e.g., `whisper_provider.py`) +2. Implement the `STTProvider` abstract class with WebSocket streaming +3. Use `AudioStreamer` for PCM16 conversion +4. Add to `providers/__init__.py` +5. Add to `benchmark.py` provider choices + +## API Documentation + +- Deepgram Streaming: https://developers.deepgram.com/docs/live-streaming-audio +- Deepgram Diarization: https://developers.deepgram.com/docs/diarization +- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm +- Speechmatics RT API: https://docs.speechmatics.com/rt-api-ref +- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization diff --git a/evals/stt/__init__.py b/evals/stt/__init__.py new file mode 100644 index 0000000..8c12ae8 --- /dev/null +++ b/evals/stt/__init__.py @@ -0,0 +1 @@ +# STT Evaluation Benchmark diff --git a/evals/stt/audio/multi_speaker.m4a b/evals/stt/audio/multi_speaker.m4a new file mode 100644 index 0000000..d1ab1f5 Binary files /dev/null and b/evals/stt/audio/multi_speaker.m4a differ diff --git a/evals/stt/audio/vad.m4a b/evals/stt/audio/vad.m4a new file mode 100644 index 0000000..e0488c7 Binary files /dev/null and b/evals/stt/audio/vad.m4a differ diff --git a/evals/stt/audio_streamer.py b/evals/stt/audio_streamer.py new file mode 100644 index 0000000..2eebc2d --- /dev/null +++ b/evals/stt/audio_streamer.py @@ -0,0 +1,127 @@ +"""Audio file streamer - converts audio files to PCM16 streams.""" + +import asyncio +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import AsyncIterator + + +@dataclass +class AudioConfig: + """Audio streaming configuration.""" + + sample_rate: int = 8000 + channels: int = 1 + sample_width: int = 2 # 16-bit = 2 bytes + chunk_duration_ms: int = 80 # Send chunks every 80ms + + @property + def chunk_size(self) -> int: + """Bytes per chunk based on duration.""" + samples_per_chunk = int(self.sample_rate * self.chunk_duration_ms / 1000) + return samples_per_chunk * self.channels * self.sample_width + + +class AudioStreamer: + """Streams audio files as PCM16 chunks. + + Converts any audio format to raw PCM16 using ffmpeg and streams + in real-time chunks to simulate live audio. + """ + + def __init__(self, config: AudioConfig | None = None): + self.config = config or AudioConfig() + + def convert_to_pcm16(self, audio_path: Path) -> bytes: + """Convert audio file to raw PCM16 bytes using ffmpeg. + + Args: + audio_path: Path to input audio file + + Returns: + Raw PCM16 audio bytes + """ + cmd = [ + "ffmpeg", + "-i", + str(audio_path), + "-f", + "s16le", # signed 16-bit little-endian + "-acodec", + "pcm_s16le", + "-ar", + str(self.config.sample_rate), + "-ac", + str(self.config.channels), + "-", # output to stdout + ] + + result = subprocess.run( + cmd, + capture_output=True, + check=True, + ) + return result.stdout + + async def stream_file( + self, + audio_path: Path, + realtime: bool = True, + ) -> AsyncIterator[bytes]: + """Stream audio file as PCM16 chunks. + + Args: + audio_path: Path to audio file + realtime: If True, add delays to simulate real-time streaming + + Yields: + PCM16 audio chunks + """ + # Convert entire file to PCM16 + pcm_data = self.convert_to_pcm16(audio_path) + + chunk_size = self.config.chunk_size + delay = self.config.chunk_duration_ms / 1000.0 if realtime else 0 + + # Stream in chunks + for i in range(0, len(pcm_data), chunk_size): + chunk = pcm_data[i : i + chunk_size] + if chunk: + yield chunk + if realtime and delay > 0: + await asyncio.sleep(delay) + + async def stream_file_fast(self, audio_path: Path) -> AsyncIterator[bytes]: + """Stream audio file as fast as possible (no real-time delay). + + Args: + audio_path: Path to audio file + + Yields: + PCM16 audio chunks + """ + async for chunk in self.stream_file(audio_path, realtime=False): + yield chunk + + def get_duration(self, audio_path: Path) -> float: + """Get audio file duration in seconds. + + Args: + audio_path: Path to audio file + + Returns: + Duration in seconds + """ + cmd = [ + "ffprobe", + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(audio_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + return float(result.stdout.strip()) diff --git a/evals/stt/benchmark.py b/evals/stt/benchmark.py new file mode 100644 index 0000000..7740cee --- /dev/null +++ b/evals/stt/benchmark.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +"""STT Benchmark Runner. + +Compare speech-to-text transcription across providers with focus on: +- Speaker diarization accuracy +- Keyword/keyterm recognition +- Transcription quality + +Usage: + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat" +""" + +import argparse +import asyncio +import json +import sys +from datetime import datetime +from pathlib import Path +from typing import Any + +from evals.stt.providers import ( + DeepgramProvider, + DeepgramFluxProvider, + SpeechmaticsProvider, + LocalSmartTurnProvider, + STTProvider, + TranscriptionResult, +) + + +def get_provider(name: str) -> STTProvider: + """Get provider instance by name.""" + providers = { + "deepgram": DeepgramProvider, + "deepgram-flux": DeepgramFluxProvider, + "speechmatics": SpeechmaticsProvider, + "local-smart-turn": LocalSmartTurnProvider, + } + if name not in providers: + raise ValueError(f"Unknown provider: {name}. Available: {list(providers.keys())}") + return providers[name]() + + +async def run_transcription( + provider: STTProvider, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + **kwargs: Any, +) -> TranscriptionResult: + """Run transcription with a provider.""" + print(f"\n{'='*60}") + print(f"Provider: {provider.name.upper()}") + print(f"{'='*60}") + + try: + result = await provider.transcribe( + audio_path, + diarize=diarize, + keyterms=keyterms, + **kwargs, + ) + return result + except Exception as e: + print(f"Error with {provider.name}: {e}") + raise + + +def print_result(result: TranscriptionResult, show_words: bool = False) -> None: + """Print transcription result.""" + print(f"\nDuration: {result.duration:.2f}s") + print(f"Speakers detected: {len(result.speakers)} - {result.speakers}") + print(f"\nTranscript:\n{result.transcript}") + + if result.speakers: + print(f"\n--- Speaker Segments ---") + for segment in result.get_speaker_segments(): + speaker = segment["speaker"] or "?" + text = segment["text"] + start = segment["start"] + print(f"[{start:.1f}s] Speaker {speaker}: {text}") + + if show_words: + print(f"\n--- Words ---") + for word in result.words[:50]: # First 50 words + speaker_info = f" (S{word.speaker})" if word.speaker else "" + print(f" {word.start:.2f}s: {word.word}{speaker_info} [{word.confidence:.2f}]") + if len(result.words) > 50: + print(f" ... and {len(result.words) - 50} more words") + + +def save_results( + results: list[TranscriptionResult], + output_dir: Path, + audio_name: str, +) -> Path: + """Save results to JSON file.""" + output_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = output_dir / f"{audio_name}_{timestamp}.json" + + output_data = { + "timestamp": timestamp, + "audio_file": audio_name, + "results": [r.to_dict() for r in results], + } + + with open(output_file, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"\nResults saved to: {output_file}") + return output_file + + +def compare_results(results: list[TranscriptionResult]) -> None: + """Compare results across providers.""" + if len(results) < 2: + return + + print(f"\n{'='*60}") + print("COMPARISON SUMMARY") + print(f"{'='*60}") + + print(f"\n{'Provider':<15} {'Duration':<10} {'Speakers':<10} {'Words':<10}") + print("-" * 45) + for r in results: + print(f"{r.provider:<15} {r.duration:<10.2f} {len(r.speakers):<10} {len(r.words):<10}") + + # Compare speaker counts + speaker_counts = {r.provider: len(r.speakers) for r in results} + if len(set(speaker_counts.values())) > 1: + print(f"\nNote: Providers detected different speaker counts: {speaker_counts}") + + +async def main() -> int: + parser = argparse.ArgumentParser( + description="STT Benchmark - Compare transcription providers", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram + python -m evals.stt.benchmark audio/multi_speaker.m4a --keyterms "Dograh" "API" + """, + ) + parser.add_argument( + "audio_file", + type=str, + help="Path to audio file (relative to evals/stt/ or absolute)", + ) + parser.add_argument( + "--providers", + nargs="+", + default=["deepgram", "speechmatics"], + choices=["deepgram", "deepgram-flux", "speechmatics", "local-smart-turn"], + help="Providers to test (default: all)", + ) + parser.add_argument( + "--diarize", + action="store_true", + help="Enable speaker diarization", + ) + parser.add_argument( + "--keyterms", + nargs="+", + help="Keywords to boost (Deepgram only)", + ) + parser.add_argument( + "--language", + default="en", + help="Language code (default: en)", + ) + parser.add_argument( + "--sample-rate", + type=int, + default=8000, + help="Audio sample rate for streaming (default: 8000)", + ) + parser.add_argument( + "--show-words", + action="store_true", + help="Show individual word timings", + ) + parser.add_argument( + "--save", + action="store_true", + help="Save results to JSON file", + ) + parser.add_argument( + "--output-dir", + type=str, + default="results", + help="Output directory for results (default: results)", + ) + + args = parser.parse_args() + + # Resolve audio path + script_dir = Path(__file__).parent + audio_path = Path(args.audio_file) + if not audio_path.is_absolute(): + audio_path = script_dir / audio_path + + if not audio_path.exists(): + print(f"Error: Audio file not found: {audio_path}") + return 1 + + print(f"Audio file: {audio_path}") + print(f"Providers: {args.providers}") + print(f"Diarization: {args.diarize}") + print(f"Sample rate: {args.sample_rate} Hz") + if args.keyterms: + print(f"Keyterms: {args.keyterms}") + + results: list[TranscriptionResult] = [] + + for provider_name in args.providers: + try: + provider = get_provider(provider_name) + result = await run_transcription( + provider, + audio_path, + diarize=args.diarize, + keyterms=args.keyterms, + language=args.language, + sample_rate=args.sample_rate, + ) + print_result(result, show_words=args.show_words) + results.append(result) + except Exception as e: + print(f"\nFailed to run {provider_name}: {e}") + continue + + if len(results) > 1: + compare_results(results) + + if args.save and results: + output_dir = script_dir / args.output_dir + save_results(results, output_dir, audio_path.stem) + + return 0 + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/evals/stt/providers/__init__.py b/evals/stt/providers/__init__.py new file mode 100644 index 0000000..12bf3ab --- /dev/null +++ b/evals/stt/providers/__init__.py @@ -0,0 +1,15 @@ +from .base import STTProvider, TranscriptionResult, Word +from .deepgram_provider import DeepgramProvider +from .deepgram_flux_provider import DeepgramFluxProvider +from .speechmatics_provider import SpeechmaticsProvider +from .local_smart_turn_provider import LocalSmartTurnProvider + +__all__ = [ + "STTProvider", + "TranscriptionResult", + "Word", + "DeepgramProvider", + "DeepgramFluxProvider", + "SpeechmaticsProvider", + "LocalSmartTurnProvider", +] diff --git a/evals/stt/providers/base.py b/evals/stt/providers/base.py new file mode 100644 index 0000000..cdb4af9 --- /dev/null +++ b/evals/stt/providers/base.py @@ -0,0 +1,123 @@ +"""Base classes for STT providers.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class Word: + """Represents a transcribed word with metadata.""" + + word: str + start: float + end: float + confidence: float + speaker: str | None = None + speaker_confidence: float | None = None + + def to_dict(self) -> dict[str, Any]: + return { + "word": self.word, + "start": self.start, + "end": self.end, + "confidence": self.confidence, + "speaker": self.speaker, + "speaker_confidence": self.speaker_confidence, + } + + +@dataclass +class TranscriptionResult: + """Result from STT transcription.""" + + provider: str + transcript: str + words: list[Word] + speakers: list[str] + duration: float + raw_response: dict[str, Any] = field(default_factory=dict) + params: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "provider": self.provider, + "transcript": self.transcript, + "words": [w.to_dict() for w in self.words], + "speakers": self.speakers, + "duration": self.duration, + "params": self.params, + } + + def get_speaker_segments(self) -> list[dict[str, Any]]: + """Get transcript segmented by speaker.""" + if not self.words: + return [] + + segments = [] + current_speaker = None + current_text = [] + segment_start = 0.0 + + for word in self.words: + if word.speaker != current_speaker: + if current_text: + segments.append( + { + "speaker": current_speaker, + "text": " ".join(current_text), + "start": segment_start, + "end": self.words[len(segments) - 1].end + if segments + else word.start, + } + ) + current_speaker = word.speaker + current_text = [word.word] + segment_start = word.start + else: + current_text.append(word.word) + + if current_text: + segments.append( + { + "speaker": current_speaker, + "text": " ".join(current_text), + "start": segment_start, + "end": self.words[-1].end if self.words else 0.0, + } + ) + + return segments + + +class STTProvider(ABC): + """Abstract base class for STT providers.""" + + @property + @abstractmethod + def name(self) -> str: + """Provider name.""" + pass + + @abstractmethod + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe audio file. + + Args: + audio_path: Path to the audio file + diarize: Enable speaker diarization + keyterms: List of keywords to boost (if supported) + **kwargs: Provider-specific parameters + + Returns: + TranscriptionResult with transcript and metadata + """ + pass diff --git a/evals/stt/providers/deepgram_flux_provider.py b/evals/stt/providers/deepgram_flux_provider.py new file mode 100644 index 0000000..bd7d16d --- /dev/null +++ b/evals/stt/providers/deepgram_flux_provider.py @@ -0,0 +1,225 @@ +"""Deepgram Flux STT provider with WebSocket streaming. + +Flux is Deepgram's conversational AI model with built-in turn detection. +It has a different API than Nova models - no language/punctuate/diarize params. +""" + +import asyncio +import json +import os +from pathlib import Path +from typing import Any +from urllib.parse import urlencode + +from loguru import logger + +from ..audio_streamer import AudioConfig, AudioStreamer +from .base import STTProvider, TranscriptionResult, Word + +try: + from websockets.asyncio.client import connect as websocket_connect +except ImportError: + raise ImportError("websockets required: pip install websockets") + + +class DeepgramFluxProvider(STTProvider): + """Deepgram Flux Speech-to-Text provider with WebSocket streaming. + + Flux is optimized for conversational AI with built-in turn detection. + + Key differences from Nova: + - Uses v2 API endpoint + - Only supports English (flux-general-en) + - No punctuate, diarize, or language params + - Has turn detection events (StartOfTurn, EndOfTurn, EagerEndOfTurn) + - Supports keyterm boosting + + API Docs: https://developers.deepgram.com/docs/ + """ + + WS_URL = "wss://api.deepgram.com/v2/listen" + + def __init__(self, api_key: str | None = None): + self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY") + if not self.api_key: + raise ValueError( + "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key." + ) + + @property + def name(self) -> str: + return "deepgram-flux" + + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, # Ignored - Flux doesn't support diarization + keyterms: list[str] | None = None, + model: str = "flux-general-en", + sample_rate: int = 16000, + eot_threshold: float | None = None, + eot_timeout_ms: int | None = None, + eager_eot_threshold: float | None = None, + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe audio using Deepgram Flux WebSocket streaming. + + Args: + audio_path: Path to audio file + diarize: IGNORED - Flux does not support diarization + keyterms: List of keywords to boost recognition + model: Flux model (default: flux-general-en) + sample_rate: Audio sample rate (default: 16000 for Flux) + eot_threshold: End-of-turn confidence threshold (0-1, default 0.7) + eot_timeout_ms: Timeout in ms to force end of turn (default 5000) + eager_eot_threshold: Threshold for eager end-of-turn events + **kwargs: Additional Flux parameters + + Returns: + TranscriptionResult with transcript (no speaker info - Flux doesn't support diarization) + """ + if diarize: + logger.warning("Flux does not support diarization - ignoring diarize=True") + + # Build query params - Flux only supports specific params + params: dict[str, Any] = { + "model": model, + "encoding": "linear16", + "sample_rate": sample_rate, + } + + # Flux-specific turn detection params + if eot_threshold is not None: + params["eot_threshold"] = eot_threshold + if eot_timeout_ms is not None: + params["eot_timeout_ms"] = eot_timeout_ms + if eager_eot_threshold is not None: + params["eager_eot_threshold"] = eager_eot_threshold + + # Build URL with params + url_parts = [f"{k}={v}" for k, v in params.items()] + + # Add keyterms (repeated params) + if keyterms: + for term in keyterms: + url_parts.append(urlencode({"keyterm": term})) + + ws_url = f"{self.WS_URL}?{'&'.join(url_parts)}" + logger.debug(f"Flux WebSocket URL: {ws_url}") + + # Setup audio streamer + audio_config = AudioConfig(sample_rate=sample_rate) + streamer = AudioStreamer(audio_config) + + # Collect results + all_transcripts: list[dict[str, Any]] = [] + final_transcript = "" + duration = 0.0 + connected = asyncio.Event() + + async with websocket_connect( + ws_url, + additional_headers={"Authorization": f"Token {self.api_key}"}, + ) as ws: + + async def send_audio(): + """Send audio chunks to Deepgram Flux.""" + await connected.wait() + + chunk_no = 0 + async for chunk in streamer.stream_file(audio_path): + logger.debug(f"[deepgram-flux] Sent audio chunk {chunk_no}") + await ws.send(chunk) + chunk_no += 1 + + async def receive_messages(): + """Receive and collect Flux messages.""" + nonlocal all_transcripts, final_transcript, duration + + async for message in ws: + if isinstance(message, str): + data = json.loads(message) + msg_type = data.get("type") + logger.debug(f"[deepgram-flux] Received {msg_type}: {data}") + + if msg_type == "Connected": + logger.info("[deepgram-flux] Connected") + connected.set() + + elif msg_type == "TurnInfo": + event = data.get("event") + transcript = data.get("transcript", "") + words = data.get("words", []) + + if event == "EndOfTurn": + if transcript: + final_transcript += transcript + " " + if words: + all_transcripts.append({ + "transcript": transcript, + "words": words, + }) + # Get duration from last word + if words: + last_word = words[-1] + duration = max(duration, last_word.get("end", 0)) + + elif event == "TurnResumed": + logger.debug("TurnResumed") + + elif msg_type == "Error": + raise Exception(f"Deepgram Flux error: {data}") + + # Run send and receive concurrently + send_task = asyncio.create_task(send_audio()) + receive_task = asyncio.create_task(receive_messages()) + + await send_task + + logger.debug("[deepgram-flux] Send task done") + try: + await asyncio.wait_for(receive_task, timeout=10.0) + except asyncio.TimeoutError: + pass + + return self._parse_results( + all_transcripts, final_transcript.strip(), duration, params, keyterms + ) + + def _parse_results( + self, + transcripts: list[dict[str, Any]], + final_transcript: str, + duration: float, + params: dict[str, Any], + keyterms: list[str] | None, + ) -> TranscriptionResult: + """Parse collected Flux results into TranscriptionResult.""" + words = [] + + for turn in transcripts: + for w in turn.get("words", []): + words.append( + Word( + word=w.get("word", ""), + start=w.get("start", 0.0), + end=w.get("end", 0.0), + confidence=w.get("confidence", 0.0), + speaker=None, # Flux doesn't support diarization + speaker_confidence=None, + ) + ) + + stored_params = dict(params) + if keyterms: + stored_params["keyterms"] = keyterms + + return TranscriptionResult( + provider=self.name, + transcript=final_transcript, + words=words, + speakers=[], # Flux doesn't support diarization + duration=duration, + raw_response={"transcripts": transcripts}, + params=stored_params, + ) diff --git a/evals/stt/providers/deepgram_provider.py b/evals/stt/providers/deepgram_provider.py new file mode 100644 index 0000000..8bb59eb --- /dev/null +++ b/evals/stt/providers/deepgram_provider.py @@ -0,0 +1,225 @@ +"""Deepgram STT provider with WebSocket streaming.""" + +import asyncio +import json +import os +from pathlib import Path +from typing import Any +from urllib.parse import urlencode + +from ..audio_streamer import AudioConfig, AudioStreamer +from .base import STTProvider, TranscriptionResult, Word +from loguru import logger + +try: + from websockets.asyncio.client import connect as websocket_connect +except ImportError: + raise ImportError("websockets required: pip install websockets") + + +class DeepgramProvider(STTProvider): + """Deepgram Nova Speech-to-Text provider with WebSocket streaming. + + API Docs: https://developers.deepgram.com/docs/ + + Supports: + - Speaker diarization via `diarize=true` + - Keyterm boosting via `keyterm` parameter + - Real-time streaming via WebSocket + - Multiple languages + - Punctuation + + For Flux models, use DeepgramFluxProvider instead. + """ + + WS_URL = "wss://api.deepgram.com/v1/listen" + + def __init__(self, api_key: str | None = None): + self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY") + if not self.api_key: + raise ValueError( + "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key." + ) + + @property + def name(self) -> str: + return "deepgram" + + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + model: str = "nova-3-general", + language: str = "en", + sample_rate: int = 8000, + punctuate: bool = True, + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe audio using Deepgram Nova WebSocket streaming. + + Args: + audio_path: Path to audio file + diarize: Enable speaker diarization + keyterms: List of keywords to boost recognition + model: Deepgram Nova model (nova-3, nova-2, etc.) + language: Language code + sample_rate: Audio sample rate for streaming + punctuate: Add punctuation + **kwargs: Additional Deepgram parameters + + Returns: + TranscriptionResult with transcript and speaker info + """ + # Build query params + params: dict[str, Any] = { + "model": model, + "language": language, + "punctuate": str(punctuate).lower(), + "encoding": "linear16", + "sample_rate": sample_rate, + "channels": 1, + "interim_results": "true", + "smart_format": "true", + "profanity_filter": "true", + "vad_events": "true" + } + + if diarize: + params["diarize"] = "true" + + # Build URL with params + url_parts = [f"{k}={v}" for k, v in params.items()] + + # Add keyterms (repeated params) + if keyterms: + for term in keyterms: + url_parts.append(urlencode({"keyterm": term})) + + # Add extra kwargs + for k, v in kwargs.items(): + url_parts.append(f"{k}={v}") + + ws_url = f"{self.WS_URL}?{'&'.join(url_parts)}" + logger.debug(f"Deepgram WebSocket URL: {ws_url}") + + # Setup audio streamer + audio_config = AudioConfig(sample_rate=sample_rate) + streamer = AudioStreamer(audio_config) + + # Collect results + all_words: list[dict[str, Any]] = [] + final_transcript = "" + duration = 0.0 + + try: + async with websocket_connect( + ws_url, + additional_headers={"Authorization": f"Token {self.api_key}"}, + ) as ws: + # Create tasks for sending and receiving + send_complete = asyncio.Event() + + async def send_audio(): + """Send audio chunks to Deepgram.""" + chunk_no = 0 + async for chunk in streamer.stream_file(audio_path): + logger.debug(f"[deepgram] Sent audio chunk {chunk_no}") + await ws.send(chunk) + chunk_no += 1 + # Send close message + logger.debug(f"[deepgram] Sending CloseStream after {chunk_no} chunks") + await ws.send(json.dumps({"type": "CloseStream"})) + send_complete.set() + + async def receive_transcripts(): + """Receive and collect transcription results.""" + nonlocal all_words, final_transcript, duration + + async for message in ws: + if isinstance(message, str): + data = json.loads(message) + msg_type = data.get("type") + logger.debug(f"[deepgram] Received {msg_type}: {data}") + + if msg_type == "Results": + # Nova-style response + channel = data.get("channel", {}) + alternatives = channel.get("alternatives", []) + if alternatives: + alt = alternatives[0] + words = alt.get("words", []) + all_words.extend(words) + + # Check if final + if data.get("is_final"): + final_transcript += alt.get("transcript", "") + " " + duration = max( + duration, data.get("duration", 0) + data.get("start", 0) + ) + + elif msg_type == "Metadata": + # Get duration from metadata + duration = data.get("duration", duration) + + elif msg_type == "Error": + raise Exception(f"Deepgram error: {data}") + + # Run send and receive concurrently + send_task = asyncio.create_task(send_audio()) + receive_task = asyncio.create_task(receive_transcripts()) + + # Wait for send to complete, then wait a bit for final results + await send_task + try: + await asyncio.wait_for(receive_task, timeout=5.0) + except asyncio.TimeoutError: + pass # Normal - websocket closes after final results + except Exception as e: + logger.debug(e) + + return self._parse_results( + all_words, final_transcript.strip(), duration, params, keyterms + ) + + def _parse_results( + self, + raw_words: list[dict[str, Any]], + transcript: str, + duration: float, + params: dict[str, Any], + keyterms: list[str] | None, + ) -> TranscriptionResult: + """Parse collected results into TranscriptionResult.""" + words = [] + speakers_set: set[str] = set() + + for w in raw_words: + speaker = str(w.get("speaker", "")) if "speaker" in w else None + if speaker: + speakers_set.add(speaker) + + words.append( + Word( + word=w.get("word", ""), + start=w.get("start", 0.0), + end=w.get("end", 0.0), + confidence=w.get("confidence", 0.0), + speaker=speaker, + speaker_confidence=w.get("speaker_confidence"), + ) + ) + + stored_params = dict(params) + if keyterms: + stored_params["keyterms"] = keyterms + + return TranscriptionResult( + provider=self.name, + transcript=transcript, + words=words, + speakers=sorted(speakers_set), + duration=duration, + raw_response={"words": raw_words}, + params=stored_params, + ) diff --git a/evals/stt/providers/local_smart_turn_provider.py b/evals/stt/providers/local_smart_turn_provider.py new file mode 100644 index 0000000..3116ca0 --- /dev/null +++ b/evals/stt/providers/local_smart_turn_provider.py @@ -0,0 +1,285 @@ +"""Local Smart Turn provider for benchmarking end-of-turn detection. + +Uses the pipecat smart-turn-v3 ONNX model for local ML-based turn detection. +This is NOT an STT provider - it only detects when a speaker has finished talking. +""" + +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +from loguru import logger + +from ..audio_streamer import AudioConfig, AudioStreamer +from .base import STTProvider, TranscriptionResult, Word + +try: + import onnxruntime as ort + from transformers import WhisperFeatureExtractor +except ImportError: + raise ImportError( + "onnxruntime and transformers required: pip install onnxruntime transformers" + ) + + +@dataclass +class TurnEvent: + """Represents a detected turn event.""" + timestamp: float # Time in audio when turn was detected + probability: float # Model confidence + prediction: int # 1=complete, 0=incomplete + inference_time_ms: float + + +class LocalSmartTurnProvider(STTProvider): + """Local Smart Turn provider for end-of-turn detection benchmarking. + + Uses the smart-turn-v3 ONNX model to detect when speakers finish talking. + This is useful for comparing turn detection accuracy against cloud services + like Deepgram Flux's built-in turn detection. + + NOTE: This provider does NOT produce transcripts - only turn detection events. + """ + + # Smart turn model requires 16kHz audio + REQUIRED_SAMPLE_RATE = 16000 + # Model analyzes 8 seconds of audio + WINDOW_SECONDS = 8 + + def __init__( + self, + model_path: str | None = None, + cpu_count: int = 1, + ): + """Initialize the local smart turn provider. + + Args: + model_path: Path to ONNX model file. If None, uses bundled model. + cpu_count: Number of CPUs for inference (default: 1) + """ + self.model_path = model_path + self.cpu_count = cpu_count + self._session = None + self._feature_extractor = None + + def _load_model(self): + """Lazy load the ONNX model.""" + if self._session is not None: + return + + model_path = self.model_path + + if not model_path: + # Try to load bundled model from pipecat + model_name = "smart-turn-v3.1-cpu.onnx" + package_path = "pipecat.audio.turn.smart_turn.data" + + try: + import importlib_resources as impresources + model_path = str(impresources.files(package_path).joinpath(model_name)) + except Exception: + from importlib import resources as impresources + try: + with impresources.path(package_path, model_name) as f: + model_path = str(f) + except Exception: + model_path = str(impresources.files(package_path).joinpath(model_name)) + + logger.info(f"[local-smart-turn] Loading model from {model_path}") + + # Configure ONNX runtime + so = ort.SessionOptions() + so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + so.inter_op_num_threads = 1 + so.intra_op_num_threads = self.cpu_count + so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + + self._feature_extractor = WhisperFeatureExtractor(chunk_length=8) + self._session = ort.InferenceSession(model_path, sess_options=so) + + logger.info("[local-smart-turn] Model loaded") + + @property + def name(self) -> str: + return "local-smart-turn" + + def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]: + """Predict end-of-turn using the ONNX model. + + Args: + audio_array: Audio samples as float32 numpy array (16kHz) + + Returns: + Dict with prediction (0/1) and probability + """ + # Truncate to last 8 seconds or pad to 8 seconds + max_samples = self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE + if len(audio_array) > max_samples: + audio_array = audio_array[-max_samples:] + elif len(audio_array) < max_samples: + padding = max_samples - len(audio_array) + audio_array = np.pad(audio_array, (padding, 0), mode="constant", constant_values=0) + + # Process using Whisper's feature extractor + inputs = self._feature_extractor( + audio_array, + sampling_rate=self.REQUIRED_SAMPLE_RATE, + return_tensors="np", + padding="max_length", + max_length=self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE, + truncation=True, + do_normalize=True, + ) + + # Extract features for ONNX + input_features = inputs.input_features.squeeze(0).astype(np.float32) + input_features = np.expand_dims(input_features, axis=0) + + # Run inference + start_time = time.perf_counter() + outputs = self._session.run(None, {"input_features": input_features}) + inference_time = (time.perf_counter() - start_time) * 1000 + + # Extract probability (model returns sigmoid probabilities) + probability = outputs[0][0].item() + prediction = 1 if probability > 0.5 else 0 + + return { + "prediction": prediction, + "probability": probability, + "inference_time_ms": inference_time, + } + + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, # Ignored - not applicable + keyterms: list[str] | None = None, # Ignored - not applicable + sample_rate: int = 16000, # Must be 16kHz for smart turn + analysis_interval_ms: int = 500, # How often to check for turn completion + **kwargs: Any, + ) -> TranscriptionResult: + """Analyze audio for turn detection events. + + NOTE: This does NOT produce transcripts. It detects when speakers + finish talking using ML-based turn detection. + + Args: + audio_path: Path to audio file + diarize: Ignored (not applicable for turn detection) + keyterms: Ignored (not applicable for turn detection) + sample_rate: Must be 16000 Hz for smart turn model + analysis_interval_ms: How often to run turn detection (ms) + **kwargs: Additional parameters (ignored) + + Returns: + TranscriptionResult with turn detection events in raw_response + """ + if sample_rate != self.REQUIRED_SAMPLE_RATE: + logger.warning( + f"[local-smart-turn] Sample rate must be {self.REQUIRED_SAMPLE_RATE}Hz, " + f"overriding {sample_rate}Hz" + ) + sample_rate = self.REQUIRED_SAMPLE_RATE + + # Load model if not already loaded + self._load_model() + + # Setup audio streamer at 16kHz + audio_config = AudioConfig(sample_rate=sample_rate) + streamer = AudioStreamer(audio_config) + + # Get audio duration + duration = streamer.get_duration(audio_path) + logger.info(f"[local-smart-turn] Processing {audio_path} ({duration:.2f}s)") + + # Collect all audio first (smart turn needs to analyze segments) + pcm_data = streamer.convert_to_pcm16(audio_path) + + # Convert to float32 for model + audio_int16 = np.frombuffer(pcm_data, dtype=np.int16) + audio_float32 = audio_int16.astype(np.float32) / 32768.0 + + # Analyze at intervals + turn_events: list[TurnEvent] = [] + samples_per_interval = int(sample_rate * analysis_interval_ms / 1000) + window_samples = self.WINDOW_SECONDS * sample_rate + + chunk_no = 0 + for end_sample in range(samples_per_interval, len(audio_float32), samples_per_interval): + # Get window of audio ending at current position + start_sample = max(0, end_sample - window_samples) + audio_window = audio_float32[start_sample:end_sample] + + current_time = end_sample / sample_rate + logger.debug(f"[local-smart-turn] Analyzing chunk {chunk_no} at {current_time:.2f}s") + + result = self._predict_endpoint(audio_window) + + turn_events.append(TurnEvent( + timestamp=current_time, + probability=result["probability"], + prediction=result["prediction"], + inference_time_ms=result["inference_time_ms"], + )) + + if result["prediction"] == 1: + logger.info( + f"[local-smart-turn] Turn complete at {current_time:.2f}s " + f"(prob={result['probability']:.3f})" + f"(inf time ms={result["inference_time_ms"]})" + ) + + chunk_no += 1 + + # Create result + # Convert turn events to word-like format for compatibility + words = [] + for event in turn_events: + if event.prediction == 1: + words.append(Word( + word=f"[END_OF_TURN prob={event.probability:.2f}]", + start=event.timestamp - 0.1, + end=event.timestamp, + confidence=event.probability, + speaker=None, + speaker_confidence=None, + )) + + # Count completed turns + completed_turns = sum(1 for e in turn_events if e.prediction == 1) + + params = { + "sample_rate": sample_rate, + "analysis_interval_ms": analysis_interval_ms, + "window_seconds": self.WINDOW_SECONDS, + } + + return TranscriptionResult( + provider=self.name, + transcript=f"[Turn detection only - {completed_turns} turns detected]", + words=words, + speakers=[], # Not applicable + duration=duration, + raw_response={ + "turn_events": [ + { + "timestamp": e.timestamp, + "probability": e.probability, + "prediction": e.prediction, + "inference_time_ms": e.inference_time_ms, + } + for e in turn_events + ], + "completed_turns": completed_turns, + "total_analyses": len(turn_events), + "avg_inference_time_ms": ( + sum(e.inference_time_ms for e in turn_events) / len(turn_events) + if turn_events else 0 + ), + }, + params=params, + ) diff --git a/evals/stt/providers/speechmatics_provider.py b/evals/stt/providers/speechmatics_provider.py new file mode 100644 index 0000000..3f84a5c --- /dev/null +++ b/evals/stt/providers/speechmatics_provider.py @@ -0,0 +1,248 @@ +"""Speechmatics STT provider with WebSocket streaming.""" + +import asyncio +import json +import os +from pathlib import Path +from typing import Any + +from loguru import logger + +from ..audio_streamer import AudioConfig, AudioStreamer +from .base import STTProvider, TranscriptionResult, Word + +try: + from websockets.asyncio.client import connect as websocket_connect +except ImportError: + raise ImportError("websockets required: pip install websockets") + + +class SpeechmaticsProvider(STTProvider): + """Speechmatics Speech-to-Text provider with WebSocket streaming. + + API Docs: https://docs.speechmatics.com/ + + Supports: + - Speaker diarization via `diarization: "speaker"` config + - Speaker sensitivity tuning + - Real-time streaming via WebSocket + """ + + def __init__(self, api_key: str | None = None, region: str = "eu2"): + self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY") + if not self.api_key: + raise ValueError( + "Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key." + ) + # Set region-specific endpoint + self.ws_url = f"wss://{region}.rt.speechmatics.com/v2" + + @property + def name(self) -> str: + return "speechmatics" + + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + language: str = "en", + operating_point: str = "enhanced", + sample_rate: int = 8000, + speaker_sensitivity: float | None = None, + max_speakers: int | None = None, + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe audio using Speechmatics WebSocket streaming. + + Args: + audio_path: Path to audio file + diarize: Enable speaker diarization + keyterms: Additional vocabulary (limited support) + language: Language code + operating_point: "standard" or "enhanced" + sample_rate: Audio sample rate for streaming + speaker_sensitivity: 0.0-1.0, higher = more speakers detected + max_speakers: Maximum number of speakers to detect + **kwargs: Additional config parameters + + Returns: + TranscriptionResult with transcript and speaker info + """ + # Build transcription config for StartRecognition message + transcription_config: dict[str, Any] = { + "language": language, + "operating_point": operating_point, + "enable_partials": False, + } + + if diarize: + transcription_config["diarization"] = "speaker" + if speaker_sensitivity is not None: + transcription_config["speaker_diarization_config"] = { + "speaker_sensitivity": speaker_sensitivity + } + if max_speakers is not None: + if "speaker_diarization_config" not in transcription_config: + transcription_config["speaker_diarization_config"] = {} + transcription_config["speaker_diarization_config"]["max_speakers"] = max_speakers + + # Add additional vocabulary if provided + if keyterms: + transcription_config["additional_vocab"] = [{"content": term} for term in keyterms] + + # Audio format config + audio_format = { + "type": "raw", + "encoding": "pcm_s16le", + "sample_rate": sample_rate, + } + + # Store params for result + params = { + "diarize": diarize, + "language": language, + "operating_point": operating_point, + "sample_rate": sample_rate, + "speaker_sensitivity": speaker_sensitivity, + "max_speakers": max_speakers, + } + + # Setup audio streamer + audio_config = AudioConfig(sample_rate=sample_rate) + streamer = AudioStreamer(audio_config) + + # Collect results + all_results: list[dict[str, Any]] = [] + recognition_started = asyncio.Event() + transcription_complete = asyncio.Event() + + async with websocket_connect( + self.ws_url, + additional_headers={"Authorization": f"Bearer {self.api_key}"}, + ) as ws: + # Send StartRecognition message + start_msg = { + "message": "StartRecognition", + "transcription_config": transcription_config, + "audio_format": audio_format, + } + await ws.send(json.dumps(start_msg)) + + async def send_audio(): + """Send audio chunks after recognition starts.""" + await recognition_started.wait() + + chunk_no = 0 + async for chunk in streamer.stream_file(audio_path): + logger.debug(f"[speechmatics] Sent audio chunk {chunk_no}") + await ws.send(chunk) + chunk_no += 1 + + # Signal end of audio with last sequence number + logger.debug(f"[speechmatics] Sending EndOfStream after {chunk_no} chunks") + await ws.send(json.dumps({"message": "EndOfStream", "last_seq_no": chunk_no})) + + async def receive_messages(): + """Receive and process messages.""" + nonlocal all_results + + async for message in ws: + if isinstance(message, str): + data = json.loads(message) + msg_type = data.get("message") + logger.debug(f"[speechmatics] Received {msg_type}: {data}") + + if msg_type == "RecognitionStarted": + logger.info("[speechmatics] Connected") + recognition_started.set() + + elif msg_type == "AddTranscript": + # Final transcript segment + results = data.get("results", []) + all_results.extend(results) + + elif msg_type == "EndOfTranscript": + transcription_complete.set() + return + + elif msg_type == "Error": + raise Exception(f"Speechmatics error: {data}") + + elif msg_type == "Warning": + logger.warning(f"[speechmatics] Warning: {data.get('reason')}") + + # Run send and receive concurrently + send_task = asyncio.create_task(send_audio()) + receive_task = asyncio.create_task(receive_messages()) + + # Wait for completion + await send_task + try: + await asyncio.wait_for(transcription_complete.wait(), timeout=30.0) + except asyncio.TimeoutError: + pass + + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + + return self._parse_results(all_results, params) + + def _parse_results( + self, + results: list[dict[str, Any]], + params: dict[str, Any], + ) -> TranscriptionResult: + """Parse Speechmatics results.""" + words = [] + speakers_set: set[str] = set() + transcript_parts = [] + duration = 0.0 + + for item in results: + item_type = item.get("type") + alternatives = item.get("alternatives", []) + + if not alternatives: + continue + + alt = alternatives[0] + content = alt.get("content", "") + speaker = alt.get("speaker") + + if speaker: + speakers_set.add(speaker) + + end_time = item.get("end_time", 0.0) + duration = max(duration, end_time) + + if item_type == "word": + words.append( + Word( + word=content, + start=item.get("start_time", 0.0), + end=end_time, + confidence=alt.get("confidence", 0.0), + speaker=speaker, + speaker_confidence=None, + ) + ) + transcript_parts.append(content) + elif item_type == "punctuation": + if transcript_parts: + transcript_parts[-1] += content + + transcript = " ".join(transcript_parts) + + return TranscriptionResult( + provider=self.name, + transcript=transcript, + words=words, + speakers=sorted(speakers_set), + duration=duration, + raw_response={"results": results}, + params=params, + )