diff --git a/evals/stt/README.md b/evals/stt/README.md index b176039..6e2802b 100644 --- a/evals/stt/README.md +++ b/evals/stt/README.md @@ -1,27 +1,29 @@ # STT Evaluation Benchmark -Benchmark for comparing Speech-to-Text providers with focus on: +Benchmark for comparing Speech-to-Text providers using **WebSocket streaming** with focus on: - **Speaker diarization** - identifying who said what - **Keyterm boosting** - improving recognition of specific terms (Deepgram) ## Providers -| Provider | Diarization | Keyterm Boost | Notes | -|----------|-------------|---------------|-------| -| Deepgram | Yes | Yes | `diarize=true`, `keyterm` param | -| Speechmatics | Yes | No | `diarization: "speaker"` config | +| Provider | Diarization | Keyterm Boost | Streaming | +|----------|-------------|---------------|-----------| +| Deepgram | Yes | Yes | WebSocket (v1/v2) | +| Speechmatics | Yes | Additional vocab | WebSocket RT | ## Setup ```bash -# Install dependencies (httpx is required) -pip install httpx +# Install dependencies +pip install websockets # Set API keys export DEEPGRAM_API_KEY="your-key" export SPEECHMATICS_API_KEY="your-key" ``` +**Note:** Requires `ffmpeg` installed for audio conversion to PCM16. 
+ ## Usage Run from the project root directory: @@ -33,9 +35,12 @@ python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize # Test only Deepgram python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram -# Test with keyterm boosting (Deepgram only) +# Test with keyterm boosting (Deepgram) python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat" +# Use different sample rate (default: 8000 Hz) +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --sample-rate 16000 + # Show word-level timings python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words @@ -50,8 +55,9 @@ python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save | `audio_file` | Path to audio file (relative to evals/stt/ or absolute) | | `--providers` | Providers to test: `deepgram`, `speechmatics` (default: both) | | `--diarize` | Enable speaker diarization | -| `--keyterms` | Keywords to boost (Deepgram only) | +| `--keyterms` | Keywords to boost (Deepgram) / additional vocab (Speechmatics) | | `--language` | Language code (default: en) | +| `--sample-rate` | Audio sample rate for streaming (default: 8000) | | `--show-words` | Show individual word timings | | `--save` | Save results to JSON in `results/` | @@ -64,16 +70,33 @@ evals/stt/ ├── results/ # Saved benchmark results (JSON) ├── providers/ # STT provider implementations │ ├── base.py # Base classes -│ ├── deepgram_provider.py -│ └── speechmatics_provider.py +│ ├── deepgram_provider.py # WebSocket streaming +│ └── speechmatics_provider.py # WebSocket streaming +├── audio_streamer.py # PCM16 audio file streamer ├── benchmark.py # Main runner script └── README.md ``` +## How It Works + +1. **Audio Conversion**: The `AudioStreamer` converts any audio file to raw PCM16 using ffmpeg +2. **WebSocket Connection**: Providers connect to their respective WebSocket APIs +3. 
**Streaming**: Audio is sent in chunks (configurable sample rate, default 8kHz) +4. **Result Collection**: Transcripts and speaker info are collected from WebSocket responses +5. **Comparison**: Results are parsed into a common format for comparison + ## Output Example ``` +Audio file: /path/to/audio/multi_speaker.m4a +Providers: ['deepgram', 'speechmatics'] +Diarization: True +Sample rate: 8000 Hz + +============================================================ Provider: DEEPGRAM +============================================================ + Duration: 45.32s Speakers detected: 2 - ['0', '1'] @@ -84,17 +107,29 @@ Hello, welcome to the demo... [0.0s] Speaker 0: Hello, welcome to the demo. [2.5s] Speaker 1: Thanks for having me. ... + +============================================================ +COMPARISON SUMMARY +============================================================ + +Provider Duration Speakers Words +--------------------------------------------- +deepgram 45.32 2 312 +speechmatics 45.32 2 308 ``` ## Adding New Providers 1. Create a new file in `providers/` (e.g., `whisper_provider.py`) -2. Implement the `STTProvider` abstract class -3. Add to `providers/__init__.py` -4. Add to `benchmark.py` provider choices +2. Implement the `STTProvider` abstract class with WebSocket streaming +3. Use `AudioStreamer` for PCM16 conversion +4. Add to `providers/__init__.py` +5. 
Add to `benchmark.py` provider choices ## API Documentation +- Deepgram Streaming: https://developers.deepgram.com/docs/live-streaming-audio - Deepgram Diarization: https://developers.deepgram.com/docs/diarization - Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm +- Speechmatics RT API: https://docs.speechmatics.com/rt-api-ref - Speechmatics Diarization: https://docs.speechmatics.com/features/diarization diff --git a/evals/stt/audio/vad.m4a b/evals/stt/audio/vad.m4a new file mode 100644 index 0000000..e0488c7 Binary files /dev/null and b/evals/stt/audio/vad.m4a differ diff --git a/evals/stt/audio_streamer.py b/evals/stt/audio_streamer.py new file mode 100644 index 0000000..2eebc2d --- /dev/null +++ b/evals/stt/audio_streamer.py @@ -0,0 +1,127 @@ +"""Audio file streamer - converts audio files to PCM16 streams.""" + +import asyncio +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import AsyncIterator + + +@dataclass +class AudioConfig: + """Audio streaming configuration.""" + + sample_rate: int = 8000 + channels: int = 1 + sample_width: int = 2 # 16-bit = 2 bytes + chunk_duration_ms: int = 80 # Send chunks every 80ms + + @property + def chunk_size(self) -> int: + """Bytes per chunk based on duration.""" + samples_per_chunk = int(self.sample_rate * self.chunk_duration_ms / 1000) + return samples_per_chunk * self.channels * self.sample_width + + +class AudioStreamer: + """Streams audio files as PCM16 chunks. + + Converts any audio format to raw PCM16 using ffmpeg and streams + in real-time chunks to simulate live audio. + """ + + def __init__(self, config: AudioConfig | None = None): + self.config = config or AudioConfig() + + def convert_to_pcm16(self, audio_path: Path) -> bytes: + """Convert audio file to raw PCM16 bytes using ffmpeg. 
+ + Args: + audio_path: Path to input audio file + + Returns: + Raw PCM16 audio bytes + """ + cmd = [ + "ffmpeg", + "-i", + str(audio_path), + "-f", + "s16le", # signed 16-bit little-endian + "-acodec", + "pcm_s16le", + "-ar", + str(self.config.sample_rate), + "-ac", + str(self.config.channels), + "-", # output to stdout + ] + + result = subprocess.run( + cmd, + capture_output=True, + check=True, + ) + return result.stdout + + async def stream_file( + self, + audio_path: Path, + realtime: bool = True, + ) -> AsyncIterator[bytes]: + """Stream audio file as PCM16 chunks. + + Args: + audio_path: Path to audio file + realtime: If True, add delays to simulate real-time streaming + + Yields: + PCM16 audio chunks + """ + # Convert entire file to PCM16 + pcm_data = self.convert_to_pcm16(audio_path) + + chunk_size = self.config.chunk_size + delay = self.config.chunk_duration_ms / 1000.0 if realtime else 0 + + # Stream in chunks + for i in range(0, len(pcm_data), chunk_size): + chunk = pcm_data[i : i + chunk_size] + if chunk: + yield chunk + if realtime and delay > 0: + await asyncio.sleep(delay) + + async def stream_file_fast(self, audio_path: Path) -> AsyncIterator[bytes]: + """Stream audio file as fast as possible (no real-time delay). + + Args: + audio_path: Path to audio file + + Yields: + PCM16 audio chunks + """ + async for chunk in self.stream_file(audio_path, realtime=False): + yield chunk + + def get_duration(self, audio_path: Path) -> float: + """Get audio file duration in seconds. 
+ + Args: + audio_path: Path to audio file + + Returns: + Duration in seconds + """ + cmd = [ + "ffprobe", + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(audio_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + return float(result.stdout.strip()) diff --git a/evals/stt/benchmark.py b/evals/stt/benchmark.py index c1b5fed..7740cee 100644 --- a/evals/stt/benchmark.py +++ b/evals/stt/benchmark.py @@ -20,14 +20,23 @@ from datetime import datetime from pathlib import Path from typing import Any -from evals.stt.providers import DeepgramProvider, SpeechmaticsProvider, STTProvider, TranscriptionResult +from evals.stt.providers import ( + DeepgramProvider, + DeepgramFluxProvider, + SpeechmaticsProvider, + LocalSmartTurnProvider, + STTProvider, + TranscriptionResult, +) def get_provider(name: str) -> STTProvider: """Get provider instance by name.""" providers = { "deepgram": DeepgramProvider, + "deepgram-flux": DeepgramFluxProvider, "speechmatics": SpeechmaticsProvider, + "local-smart-turn": LocalSmartTurnProvider, } if name not in providers: raise ValueError(f"Unknown provider: {name}. 
Available: {list(providers.keys())}") @@ -145,7 +154,7 @@ Examples: "--providers", nargs="+", default=["deepgram", "speechmatics"], - choices=["deepgram", "speechmatics"], + choices=["deepgram", "deepgram-flux", "speechmatics", "local-smart-turn"], help="Providers to test (default: all)", ) parser.add_argument( @@ -163,6 +172,12 @@ Examples: default="en", help="Language code (default: en)", ) + parser.add_argument( + "--sample-rate", + type=int, + default=8000, + help="Audio sample rate for streaming (default: 8000)", + ) parser.add_argument( "--show-words", action="store_true", @@ -195,6 +210,7 @@ Examples: print(f"Audio file: {audio_path}") print(f"Providers: {args.providers}") print(f"Diarization: {args.diarize}") + print(f"Sample rate: {args.sample_rate} Hz") if args.keyterms: print(f"Keyterms: {args.keyterms}") @@ -209,6 +225,7 @@ Examples: diarize=args.diarize, keyterms=args.keyterms, language=args.language, + sample_rate=args.sample_rate, ) print_result(result, show_words=args.show_words) results.append(result) diff --git a/evals/stt/providers/__init__.py b/evals/stt/providers/__init__.py index d73a2b8..12bf3ab 100644 --- a/evals/stt/providers/__init__.py +++ b/evals/stt/providers/__init__.py @@ -1,11 +1,15 @@ from .base import STTProvider, TranscriptionResult, Word from .deepgram_provider import DeepgramProvider +from .deepgram_flux_provider import DeepgramFluxProvider from .speechmatics_provider import SpeechmaticsProvider +from .local_smart_turn_provider import LocalSmartTurnProvider __all__ = [ "STTProvider", "TranscriptionResult", "Word", "DeepgramProvider", + "DeepgramFluxProvider", "SpeechmaticsProvider", + "LocalSmartTurnProvider", ] diff --git a/evals/stt/providers/deepgram_flux_provider.py b/evals/stt/providers/deepgram_flux_provider.py new file mode 100644 index 0000000..bd7d16d --- /dev/null +++ b/evals/stt/providers/deepgram_flux_provider.py @@ -0,0 +1,225 @@ +"""Deepgram Flux STT provider with WebSocket streaming. 
class DeepgramFluxProvider(STTProvider):
    """Deepgram Flux Speech-to-Text provider with WebSocket streaming.

    Flux is optimized for conversational AI with built-in turn detection.

    Key differences from Nova:
    - Uses v2 API endpoint
    - Only supports English (flux-general-en)
    - No punctuate, diarize, or language params
    - Has turn detection events (StartOfTurn, EndOfTurn, EagerEndOfTurn)
    - Supports keyterm boosting

    API Docs: https://developers.deepgram.com/docs/
    """

    WS_URL = "wss://api.deepgram.com/v2/listen"

    def __init__(self, api_key: str | None = None):
        """Store the API key, falling back to the DEEPGRAM_API_KEY env var.

        Raises:
            ValueError: If no API key is passed and the env var is unset.
        """
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        """Provider identifier used in results."""
        return "deepgram-flux"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - Flux doesn't support diarization
        keyterms: list[str] | None = None,
        model: str = "flux-general-en",
        sample_rate: int = 16000,
        eot_threshold: float | None = None,
        eot_timeout_ms: int | None = None,
        eager_eot_threshold: float | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram Flux WebSocket streaming.

        Args:
            audio_path: Path to audio file
            diarize: IGNORED - Flux does not support diarization
            keyterms: List of keywords to boost recognition
            model: Flux model (default: flux-general-en)
            sample_rate: Audio sample rate (default: 16000 for Flux)
            eot_threshold: End-of-turn confidence threshold (0-1, default 0.7)
            eot_timeout_ms: Timeout in ms to force end of turn (default 5000)
            eager_eot_threshold: Threshold for eager end-of-turn events
            **kwargs: Accepted so callers can pass options meant for other
                providers (e.g. ``language``); not sent to Flux.

        Returns:
            TranscriptionResult with transcript (no speaker info - Flux doesn't support diarization)

        Raises:
            RuntimeError: If Deepgram sends an ``Error`` message.
        """
        if diarize:
            logger.warning("Flux does not support diarization - ignoring diarize=True")

        # Build query params - Flux only supports specific params
        params: dict[str, Any] = {
            "model": model,
            "encoding": "linear16",
            "sample_rate": sample_rate,
        }

        # Flux-specific turn detection params (only sent when provided)
        optional_params = {
            "eot_threshold": eot_threshold,
            "eot_timeout_ms": eot_timeout_ms,
            "eager_eot_threshold": eager_eot_threshold,
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})

        # Encode every key/value (not just keyterms) so values containing
        # reserved characters cannot corrupt the query string. Keyterms are
        # sent as repeated `keyterm` params.
        query: list[tuple[str, Any]] = list(params.items())
        query.extend(("keyterm", term) for term in keyterms or [])

        ws_url = f"{self.WS_URL}?{urlencode(query)}"
        logger.debug(f"Flux WebSocket URL: {ws_url}")

        # Setup audio streamer
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Collect results
        all_transcripts: list[dict[str, Any]] = []
        final_transcript = ""
        duration = 0.0
        connected = asyncio.Event()

        async with websocket_connect(
            ws_url,
            additional_headers={"Authorization": f"Token {self.api_key}"},
        ) as ws:

            async def send_audio() -> None:
                """Send audio chunks to Deepgram Flux."""
                await connected.wait()

                chunk_no = 0
                async for chunk in streamer.stream_file(audio_path):
                    await ws.send(chunk)
                    logger.debug(f"[deepgram-flux] Sent audio chunk {chunk_no}")
                    chunk_no += 1

            async def receive_messages() -> None:
                """Receive and collect Flux messages."""
                nonlocal final_transcript, duration

                try:
                    async for message in ws:
                        if not isinstance(message, str):
                            continue

                        data = json.loads(message)
                        msg_type = data.get("type")
                        logger.debug(f"[deepgram-flux] Received {msg_type}: {data}")

                        if msg_type == "Connected":
                            logger.info("[deepgram-flux] Connected")
                            connected.set()

                        elif msg_type == "TurnInfo":
                            event = data.get("event")
                            transcript = data.get("transcript", "")
                            words = data.get("words", [])

                            if event == "EndOfTurn":
                                if transcript:
                                    final_transcript += transcript + " "
                                if words:
                                    all_transcripts.append({
                                        "transcript": transcript,
                                        "words": words,
                                    })
                                    # Get duration from last word
                                    duration = max(duration, words[-1].get("end", 0))

                            elif event == "TurnResumed":
                                logger.debug("TurnResumed")

                        elif msg_type == "Error":
                            raise RuntimeError(f"Deepgram Flux error: {data}")
                finally:
                    # Never leave the sender blocked on `connected`: if the
                    # socket closed or an error arrived before "Connected",
                    # the sender must wake up so the failure can surface.
                    connected.set()

            # Run send and receive concurrently
            send_task = asyncio.create_task(send_audio())
            receive_task = asyncio.create_task(receive_messages())

            try:
                await asyncio.wait(
                    {send_task, receive_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                # If the receiver already failed, raise now instead of
                # streaming the rest of the file first.
                if receive_task.done():
                    receive_task.result()

                await send_task
                logger.debug("[deepgram-flux] Send task done")

                try:
                    await asyncio.wait_for(receive_task, timeout=10.0)
                except asyncio.TimeoutError:
                    pass
            finally:
                # Reap whatever is still running so no task outlives the socket.
                for task in (send_task, receive_task):
                    if not task.done():
                        task.cancel()
                await asyncio.gather(send_task, receive_task, return_exceptions=True)

        return self._parse_results(
            all_transcripts, final_transcript.strip(), duration, params, keyterms
        )

    def _parse_results(
        self,
        transcripts: list[dict[str, Any]],
        final_transcript: str,
        duration: float,
        params: dict[str, Any],
        keyterms: list[str] | None,
    ) -> TranscriptionResult:
        """Parse collected Flux results into TranscriptionResult.

        Args:
            transcripts: One dict per completed turn, each with ``transcript``
                and ``words`` entries as collected from EndOfTurn messages.
            final_transcript: Concatenated transcript of all completed turns.
            duration: Largest word end time seen, in seconds.
            params: Query parameters sent to Flux.
            keyterms: Keyterms sent to Flux, if any; recorded under ``keyterms``.

        Returns:
            TranscriptionResult without speaker information.
        """
        words = [
            Word(
                word=w.get("word", ""),
                start=w.get("start", 0.0),
                end=w.get("end", 0.0),
                confidence=w.get("confidence", 0.0),
                speaker=None,  # Flux doesn't support diarization
                speaker_confidence=None,
            )
            for turn in transcripts
            for w in turn.get("words", [])
        ]

        stored_params = dict(params)
        if keyterms:
            stored_params["keyterms"] = keyterms

        return TranscriptionResult(
            provider=self.name,
            transcript=final_transcript,
            words=words,
            speakers=[],  # Flux doesn't support diarization
            duration=duration,
            raw_response={"transcripts": transcripts},
            params=stored_params,
        )
""" - API_URL = "https://api.deepgram.com/v1/listen" + WS_URL = "wss://api.deepgram.com/v1/listen" def __init__(self, api_key: str | None = None): self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY") @@ -37,113 +50,151 @@ class DeepgramProvider(STTProvider): audio_path: Path, diarize: bool = False, keyterms: list[str] | None = None, - model: str = "nova-3", + model: str = "nova-3-general", language: str = "en", + sample_rate: int = 8000, punctuate: bool = True, **kwargs: Any, ) -> TranscriptionResult: - """Transcribe audio using Deepgram API. + """Transcribe audio using Deepgram Nova WebSocket streaming. Args: audio_path: Path to audio file diarize: Enable speaker diarization keyterms: List of keywords to boost recognition - model: Deepgram model (nova-3, nova-2, etc.) + model: Deepgram Nova model (nova-3, nova-2, etc.) language: Language code + sample_rate: Audio sample rate for streaming punctuate: Add punctuation **kwargs: Additional Deepgram parameters Returns: TranscriptionResult with transcript and speaker info """ + # Build query params params: dict[str, Any] = { "model": model, "language": language, "punctuate": str(punctuate).lower(), + "encoding": "linear16", + "sample_rate": sample_rate, + "channels": 1, + "interim_results": "true", + "smart_format": "true", + "profanity_filter": "true", + "vad_events": "true" } if diarize: params["diarize"] = "true" - # Add keyterms (Deepgram uses repeated keyterm params) + # Build URL with params + url_parts = [f"{k}={v}" for k, v in params.items()] + + # Add keyterms (repeated params) if keyterms: - params["keyterm"] = keyterms + for term in keyterms: + url_parts.append(urlencode({"keyterm": term})) - # Add any extra kwargs - params.update(kwargs) + # Add extra kwargs + for k, v in kwargs.items(): + url_parts.append(f"{k}={v}") - # Read audio file - audio_data = audio_path.read_bytes() + ws_url = f"{self.WS_URL}?{'&'.join(url_parts)}" + logger.debug(f"Deepgram WebSocket URL: {ws_url}") - # Determine content type - 
suffix = audio_path.suffix.lower() - content_types = { - ".wav": "audio/wav", - ".mp3": "audio/mpeg", - ".m4a": "audio/mp4", - ".flac": "audio/flac", - ".ogg": "audio/ogg", - ".webm": "audio/webm", - } - content_type = content_types.get(suffix, "audio/wav") + # Setup audio streamer + audio_config = AudioConfig(sample_rate=sample_rate) + streamer = AudioStreamer(audio_config) - headers = { - "Authorization": f"Token {self.api_key}", - "Content-Type": content_type, - } + # Collect results + all_words: list[dict[str, Any]] = [] + final_transcript = "" + duration = 0.0 - async with httpx.AsyncClient(timeout=120.0) as client: - response = await client.post( - self.API_URL, - params=params, - headers=headers, - content=audio_data, - ) - response.raise_for_status() - data = response.json() + try: + async with websocket_connect( + ws_url, + additional_headers={"Authorization": f"Token {self.api_key}"}, + ) as ws: + # Create tasks for sending and receiving + send_complete = asyncio.Event() - return self._parse_response(data, params) + async def send_audio(): + """Send audio chunks to Deepgram.""" + chunk_no = 0 + async for chunk in streamer.stream_file(audio_path): + logger.debug(f"[deepgram] Sent audio chunk {chunk_no}") + await ws.send(chunk) + chunk_no += 1 + # Send close message + logger.debug(f"[deepgram] Sending CloseStream after {chunk_no} chunks") + await ws.send(json.dumps({"type": "CloseStream"})) + send_complete.set() - def _parse_response( - self, data: dict[str, Any], params: dict[str, Any] + async def receive_transcripts(): + """Receive and collect transcription results.""" + nonlocal all_words, final_transcript, duration + + async for message in ws: + if isinstance(message, str): + data = json.loads(message) + msg_type = data.get("type") + logger.debug(f"[deepgram] Received {msg_type}: {data}") + + if msg_type == "Results": + # Nova-style response + channel = data.get("channel", {}) + alternatives = channel.get("alternatives", []) + if alternatives: + alt = 
alternatives[0] + words = alt.get("words", []) + all_words.extend(words) + + # Check if final + if data.get("is_final"): + final_transcript += alt.get("transcript", "") + " " + duration = max( + duration, data.get("duration", 0) + data.get("start", 0) + ) + + elif msg_type == "Metadata": + # Get duration from metadata + duration = data.get("duration", duration) + + elif msg_type == "Error": + raise Exception(f"Deepgram error: {data}") + + # Run send and receive concurrently + send_task = asyncio.create_task(send_audio()) + receive_task = asyncio.create_task(receive_transcripts()) + + # Wait for send to complete, then wait a bit for final results + await send_task + try: + await asyncio.wait_for(receive_task, timeout=5.0) + except asyncio.TimeoutError: + pass # Normal - websocket closes after final results + except Exception as e: + logger.debug(e) + + return self._parse_results( + all_words, final_transcript.strip(), duration, params, keyterms + ) + + def _parse_results( + self, + raw_words: list[dict[str, Any]], + transcript: str, + duration: float, + params: dict[str, Any], + keyterms: list[str] | None, ) -> TranscriptionResult: - """Parse Deepgram API response.""" - results = data.get("results", {}) - channels = results.get("channels", []) - - if not channels: - return TranscriptionResult( - provider=self.name, - transcript="", - words=[], - speakers=[], - duration=0.0, - raw_response=data, - params=params, - ) - - # Get first channel, first alternative - channel = channels[0] - alternatives = channel.get("alternatives", []) - if not alternatives: - return TranscriptionResult( - provider=self.name, - transcript="", - words=[], - speakers=[], - duration=0.0, - raw_response=data, - params=params, - ) - - alt = alternatives[0] - transcript = alt.get("transcript", "") - - # Parse words with speaker info + """Parse collected results into TranscriptionResult.""" words = [] speakers_set: set[str] = set() - for w in alt.get("words", []): + for w in raw_words: speaker = 
@dataclass
class TurnEvent:
    """Represents a detected turn event."""

    timestamp: float  # Time in audio when turn was detected
    probability: float  # Model confidence
    prediction: int  # 1=complete, 0=incomplete
    inference_time_ms: float  # Wall-clock time of the ONNX inference call


class LocalSmartTurnProvider(STTProvider):
    """Local Smart Turn provider for end-of-turn detection benchmarking.

    Uses the smart-turn-v3 ONNX model to detect when speakers finish talking.
    This is useful for comparing turn detection accuracy against cloud services
    like Deepgram Flux's built-in turn detection.

    NOTE: This provider does NOT produce transcripts - only turn detection events.
    """

    # Smart turn model requires 16kHz audio
    REQUIRED_SAMPLE_RATE = 16000
    # Model analyzes 8 seconds of audio
    WINDOW_SECONDS = 8

    def __init__(
        self,
        model_path: str | None = None,
        cpu_count: int = 1,
    ):
        """Initialize the local smart turn provider.

        Args:
            model_path: Path to ONNX model file. If None, uses bundled model.
            cpu_count: Number of CPUs for inference (default: 1)
        """
        self.model_path = model_path
        self.cpu_count = cpu_count
        # Both are created lazily by _load_model().
        self._session = None
        self._feature_extractor = None

    def _load_model(self):
        """Lazy load the ONNX model (no-op once a session exists)."""
        if self._session is not None:
            return

        model_path = self.model_path

        if not model_path:
            # Try to load bundled model from pipecat
            model_name = "smart-turn-v3.1-cpu.onnx"
            package_path = "pipecat.audio.turn.smart_turn.data"

            # Prefer the importlib_resources backport, then fall back to the
            # stdlib importlib.resources APIs.
            try:
                import importlib_resources as impresources
                model_path = str(impresources.files(package_path).joinpath(model_name))
            except Exception:
                from importlib import resources as impresources
                try:
                    with impresources.path(package_path, model_name) as f:
                        model_path = str(f)
                except Exception:
                    model_path = str(impresources.files(package_path).joinpath(model_name))

        logger.info(f"[local-smart-turn] Loading model from {model_path}")

        # Configure ONNX runtime
        so = ort.SessionOptions()
        so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        so.inter_op_num_threads = 1
        so.intra_op_num_threads = self.cpu_count
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
        self._session = ort.InferenceSession(model_path, sess_options=so)

        logger.info("[local-smart-turn] Model loaded")

    @property
    def name(self) -> str:
        """Provider identifier used in results."""
        return "local-smart-turn"

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Predict end-of-turn using the ONNX model.

        Args:
            audio_array: Audio samples as float32 numpy array (16kHz)

        Returns:
            Dict with prediction (0/1), probability and inference_time_ms
        """
        # Make the method safe to call on its own; loading is idempotent.
        self._load_model()

        # Truncate to last 8 seconds or pad to 8 seconds
        max_samples = self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE
        if len(audio_array) > max_samples:
            audio_array = audio_array[-max_samples:]
        elif len(audio_array) < max_samples:
            # Zeros are added at the front so the audio stays at the end of
            # the window.
            padding = max_samples - len(audio_array)
            audio_array = np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)

        # Process using Whisper's feature extractor
        inputs = self._feature_extractor(
            audio_array,
            sampling_rate=self.REQUIRED_SAMPLE_RATE,
            return_tensors="np",
            padding="max_length",
            max_length=self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE,
            truncation=True,
            do_normalize=True,
        )

        # Extract features for ONNX
        input_features = inputs.input_features.squeeze(0).astype(np.float32)
        input_features = np.expand_dims(input_features, axis=0)

        # Run inference (only the session.run call is timed)
        start_time = time.perf_counter()
        outputs = self._session.run(None, {"input_features": input_features})
        inference_time = (time.perf_counter() - start_time) * 1000

        # Extract probability (model returns sigmoid probabilities)
        probability = outputs[0][0].item()
        prediction = 1 if probability > 0.5 else 0

        return {
            "prediction": prediction,
            "probability": probability,
            "inference_time_ms": inference_time,
        }

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - not applicable
        keyterms: list[str] | None = None,  # Ignored - not applicable
        sample_rate: int = 16000,  # Must be 16kHz for smart turn
        analysis_interval_ms: int = 500,  # How often to check for turn completion
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Analyze audio for turn detection events.

        NOTE: This does NOT produce transcripts. It detects when speakers
        finish talking using ML-based turn detection.

        Args:
            audio_path: Path to audio file
            diarize: Ignored (not applicable for turn detection)
            keyterms: Ignored (not applicable for turn detection)
            sample_rate: Must be 16000 Hz for smart turn model
            analysis_interval_ms: How often to run turn detection (ms)
            **kwargs: Additional parameters (ignored)

        Returns:
            TranscriptionResult with turn detection events in raw_response

        Raises:
            ValueError: If analysis_interval_ms is too small to cover one sample.
        """
        if sample_rate != self.REQUIRED_SAMPLE_RATE:
            logger.warning(
                f"[local-smart-turn] Sample rate must be {self.REQUIRED_SAMPLE_RATE}Hz, "
                f"overriding {sample_rate}Hz"
            )
            sample_rate = self.REQUIRED_SAMPLE_RATE

        samples_per_interval = int(sample_rate * analysis_interval_ms / 1000)
        if samples_per_interval <= 0:
            # Otherwise range() below fails with an unrelated-looking error.
            raise ValueError(
                f"analysis_interval_ms must cover at least one sample, got {analysis_interval_ms}"
            )

        # Load model if not already loaded
        self._load_model()

        # Setup audio streamer at 16kHz
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Get audio duration
        duration = streamer.get_duration(audio_path)
        logger.info(f"[local-smart-turn] Processing {audio_path} ({duration:.2f}s)")

        # Collect all audio first (smart turn needs to analyze segments)
        pcm_data = streamer.convert_to_pcm16(audio_path)

        # Convert to float32 for model (int16 scaled into [-1, 1))
        audio_int16 = np.frombuffer(pcm_data, dtype=np.int16)
        audio_float32 = audio_int16.astype(np.float32) / 32768.0

        # Analyze at intervals
        turn_events: list[TurnEvent] = []
        window_samples = self.WINDOW_SECONDS * sample_rate

        # The stop value is len + 1 so that an analysis point falling exactly
        # on the last sample is included; a trailing partial interval is
        # still left unanalyzed.
        analysis_points = range(
            samples_per_interval, len(audio_float32) + 1, samples_per_interval
        )
        for chunk_no, end_sample in enumerate(analysis_points):
            # Get window of audio ending at current position
            start_sample = max(0, end_sample - window_samples)
            audio_window = audio_float32[start_sample:end_sample]

            current_time = end_sample / sample_rate
            logger.debug(f"[local-smart-turn] Analyzing chunk {chunk_no} at {current_time:.2f}s")

            result = self._predict_endpoint(audio_window)

            turn_events.append(TurnEvent(
                timestamp=current_time,
                probability=result["probability"],
                prediction=result["prediction"],
                inference_time_ms=result["inference_time_ms"],
            ))

            if result["prediction"] == 1:
                # Single quotes inside the f-strings: reusing the enclosing
                # double quote is a SyntaxError before Python 3.12.
                logger.info(
                    f"[local-smart-turn] Turn complete at {current_time:.2f}s "
                    f"(prob={result['probability']:.3f})"
                    f"(inf time ms={result['inference_time_ms']})"
                )

        # Create result
        # Convert turn events to word-like format for compatibility
        words = [
            Word(
                word=f"[END_OF_TURN prob={event.probability:.2f}]",
                start=event.timestamp - 0.1,
                end=event.timestamp,
                confidence=event.probability,
                speaker=None,
                speaker_confidence=None,
            )
            for event in turn_events
            if event.prediction == 1
        ]

        # Count completed turns
        completed_turns = len(words)

        params = {
            "sample_rate": sample_rate,
            "analysis_interval_ms": analysis_interval_ms,
            "window_seconds": self.WINDOW_SECONDS,
        }

        return TranscriptionResult(
            provider=self.name,
            transcript=f"[Turn detection only - {completed_turns} turns detected]",
            words=words,
            speakers=[],  # Not applicable
            duration=duration,
            raw_response={
                "turn_events": [
                    {
                        "timestamp": e.timestamp,
                        "probability": e.probability,
                        "prediction": e.prediction,
                        "inference_time_ms": e.inference_time_ms,
                    }
                    for e in turn_events
                ],
                "completed_turns": completed_turns,
                "total_analyses": len(turn_events),
                "avg_inference_time_ms": (
                    sum(e.inference_time_ms for e in turn_events) / len(turn_events)
                    if turn_events else 0
                ),
            },
            params=params,
        )
json import os from pathlib import Path from typing import Any -import httpx +from loguru import logger +from ..audio_streamer import AudioConfig, AudioStreamer from .base import STTProvider, TranscriptionResult, Word +try: + from websockets.asyncio.client import connect as websocket_connect +except ImportError: + raise ImportError("websockets required: pip install websockets") + class SpeechmaticsProvider(STTProvider): - """Speechmatics Speech-to-Text provider. + """Speechmatics Speech-to-Text provider with WebSocket streaming. API Docs: https://docs.speechmatics.com/ Supports: - Speaker diarization via `diarization: "speaker"` config - Speaker sensitivity tuning + - Real-time streaming via WebSocket """ - # EU and US endpoints available - API_URL = "https://asr.api.speechmatics.com/v2/jobs" - - def __init__(self, api_key: str | None = None, region: str = "eu1"): + def __init__(self, api_key: str | None = None, region: str = "eu2"): self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY") if not self.api_key: raise ValueError( "Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key." ) # Set region-specific endpoint - if region == "eu1": - self.api_url = "https://eu1.asr.api.speechmatics.com/v2/jobs" - else: - self.api_url = "https://asr.api.speechmatics.com/v2/jobs" + self.ws_url = f"wss://{region}.rt.speechmatics.com/v2" @property def name(self) -> str: @@ -45,27 +48,32 @@ class SpeechmaticsProvider(STTProvider): keyterms: list[str] | None = None, language: str = "en", operating_point: str = "enhanced", + sample_rate: int = 8000, speaker_sensitivity: float | None = None, + max_speakers: int | None = None, **kwargs: Any, ) -> TranscriptionResult: - """Transcribe audio using Speechmatics API. + """Transcribe audio using Speechmatics WebSocket streaming. 
Args: audio_path: Path to audio file diarize: Enable speaker diarization - keyterms: Not directly supported by Speechmatics (ignored) + keyterms: Additional vocabulary (limited support) language: Language code operating_point: "standard" or "enhanced" + sample_rate: Audio sample rate for streaming speaker_sensitivity: 0.0-1.0, higher = more speakers detected + max_speakers: Maximum number of speakers to detect **kwargs: Additional config parameters Returns: TranscriptionResult with transcript and speaker info """ - # Build transcription config + # Build transcription config for StartRecognition message transcription_config: dict[str, Any] = { "language": language, "operating_point": operating_point, + "enable_partials": False, } if diarize: @@ -74,13 +82,20 @@ class SpeechmaticsProvider(STTProvider): transcription_config["speaker_diarization_config"] = { "speaker_sensitivity": speaker_sensitivity } + if max_speakers is not None: + if "speaker_diarization_config" not in transcription_config: + transcription_config["speaker_diarization_config"] = {} + transcription_config["speaker_diarization_config"]["max_speakers"] = max_speakers - # Add any extra config - transcription_config.update(kwargs) + # Add additional vocabulary if provided + if keyterms: + transcription_config["additional_vocab"] = [{"content": term} for term in keyterms] - config = { - "type": "transcription", - "transcription_config": transcription_config, + # Audio format config + audio_format = { + "type": "raw", + "encoding": "pcm_s16le", + "sample_rate": sample_rate, } # Store params for result @@ -88,79 +103,104 @@ class SpeechmaticsProvider(STTProvider): "diarize": diarize, "language": language, "operating_point": operating_point, + "sample_rate": sample_rate, "speaker_sensitivity": speaker_sensitivity, + "max_speakers": max_speakers, } - headers = { - "Authorization": f"Bearer {self.api_key}", - } + # Setup audio streamer + audio_config = AudioConfig(sample_rate=sample_rate) + streamer = 
AudioStreamer(audio_config) - # Create job with multipart form - async with httpx.AsyncClient(timeout=300.0) as client: - # Submit job - with open(audio_path, "rb") as f: - files = { - "data_file": (audio_path.name, f, "audio/mpeg"), - "config": (None, str(config).replace("'", '"'), "application/json"), - } - response = await client.post( - self.api_url, - headers=headers, - files=files, - ) - response.raise_for_status() - job_data = response.json() + # Collect results + all_results: list[dict[str, Any]] = [] + recognition_started = asyncio.Event() + transcription_complete = asyncio.Event() - job_id = job_data.get("id") - if not job_id: - raise ValueError(f"No job ID in response: {job_data}") + async with websocket_connect( + self.ws_url, + additional_headers={"Authorization": f"Bearer {self.api_key}"}, + ) as ws: + # Send StartRecognition message + start_msg = { + "message": "StartRecognition", + "transcription_config": transcription_config, + "audio_format": audio_format, + } + await ws.send(json.dumps(start_msg)) - # Poll for completion - result_data = await self._wait_for_job(client, job_id, headers) + async def send_audio(): + """Send audio chunks after recognition starts.""" + await recognition_started.wait() - return self._parse_response(result_data, params) + chunk_no = 0 + async for chunk in streamer.stream_file(audio_path): + logger.debug(f"[speechmatics] Sent audio chunk {chunk_no}") + await ws.send(chunk) + chunk_no += 1 - async def _wait_for_job( - self, client: httpx.AsyncClient, job_id: str, headers: dict[str, str] - ) -> dict[str, Any]: - """Poll job status until complete.""" - import asyncio + # Signal end of audio with last sequence number + logger.debug(f"[speechmatics] Sending EndOfStream after {chunk_no} chunks") + await ws.send(json.dumps({"message": "EndOfStream", "last_seq_no": chunk_no})) - job_url = f"{self.api_url}/{job_id}" - transcript_url = f"{job_url}/transcript?format=json-v2" + async def receive_messages(): + """Receive and process 
messages.""" + nonlocal all_results - max_attempts = 120 # 10 minutes with 5s intervals - for _ in range(max_attempts): - # Check job status - status_response = await client.get(job_url, headers=headers) - status_response.raise_for_status() - status_data = status_response.json() + async for message in ws: + if isinstance(message, str): + data = json.loads(message) + msg_type = data.get("message") + logger.debug(f"[speechmatics] Received {msg_type}: {data}") - job_status = status_data.get("job", {}).get("status") + if msg_type == "RecognitionStarted": + logger.info("[speechmatics] Connected") + recognition_started.set() - if job_status == "done": - # Get transcript - transcript_response = await client.get(transcript_url, headers=headers) - transcript_response.raise_for_status() - return transcript_response.json() - elif job_status == "rejected": - raise ValueError(f"Job rejected: {status_data}") - elif job_status == "deleted": - raise ValueError(f"Job deleted: {status_data}") + elif msg_type == "AddTranscript": + # Final transcript segment + results = data.get("results", []) + all_results.extend(results) - await asyncio.sleep(5) + elif msg_type == "EndOfTranscript": + transcription_complete.set() + return - raise TimeoutError(f"Job {job_id} did not complete in time") + elif msg_type == "Error": + raise Exception(f"Speechmatics error: {data}") - def _parse_response( - self, data: dict[str, Any], params: dict[str, Any] + elif msg_type == "Warning": + logger.warning(f"[speechmatics] Warning: {data.get('reason')}") + + # Run send and receive concurrently + send_task = asyncio.create_task(send_audio()) + receive_task = asyncio.create_task(receive_messages()) + + # Wait for completion + await send_task + try: + await asyncio.wait_for(transcription_complete.wait(), timeout=30.0) + except asyncio.TimeoutError: + pass + + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + + return self._parse_results(all_results, params) + + def 
_parse_results( + self, + results: list[dict[str, Any]], + params: dict[str, Any], ) -> TranscriptionResult: - """Parse Speechmatics API response.""" - results = data.get("results", []) - + """Parse Speechmatics results.""" words = [] speakers_set: set[str] = set() transcript_parts = [] + duration = 0.0 for item in results: item_type = item.get("type") @@ -176,27 +216,25 @@ class SpeechmaticsProvider(STTProvider): if speaker: speakers_set.add(speaker) + end_time = item.get("end_time", 0.0) + duration = max(duration, end_time) + if item_type == "word": words.append( Word( word=content, start=item.get("start_time", 0.0), - end=item.get("end_time", 0.0), + end=end_time, confidence=alt.get("confidence", 0.0), speaker=speaker, - speaker_confidence=None, # Not provided by Speechmatics + speaker_confidence=None, ) ) transcript_parts.append(content) elif item_type == "punctuation": - # Append punctuation to last word in transcript if transcript_parts: transcript_parts[-1] += content - # Get metadata - metadata = data.get("metadata", {}) - duration = metadata.get("duration", 0.0) - transcript = " ".join(transcript_parts) return TranscriptionResult( @@ -205,6 +243,6 @@ class SpeechmaticsProvider(STTProvider): words=words, speakers=sorted(speakers_set), duration=duration, - raw_response=data, + raw_response={"results": results}, params=params, )