diff --git a/evals/stt/README.md b/evals/stt/README.md new file mode 100644 index 0000000..b176039 --- /dev/null +++ b/evals/stt/README.md @@ -0,0 +1,100 @@ +# STT Evaluation Benchmark + +Benchmark for comparing Speech-to-Text providers with focus on: +- **Speaker diarization** - identifying who said what +- **Keyterm boosting** - improving recognition of specific terms (Deepgram) + +## Providers + +| Provider | Diarization | Keyterm Boost | Notes | +|----------|-------------|---------------|-------| +| Deepgram | Yes | Yes | `diarize=true`, `keyterm` param | +| Speechmatics | Yes | No | `diarization: "speaker"` config | + +## Setup + +```bash +# Install dependencies (httpx is required) +pip install httpx + +# Set API keys +export DEEPGRAM_API_KEY="your-key" +export SPEECHMATICS_API_KEY="your-key" +``` + +## Usage + +Run from the project root directory: + +```bash +# Test both providers with diarization +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize + +# Test only Deepgram +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram + +# Test with keyterm boosting (Deepgram only) +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat" + +# Show word-level timings +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words + +# Save results to JSON +python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save +``` + +## CLI Options + +| Option | Description | +|--------|-------------| +| `audio_file` | Path to audio file (relative to evals/stt/ or absolute) | +| `--providers` | Providers to test: `deepgram`, `speechmatics` (default: both) | +| `--diarize` | Enable speaker diarization | +| `--keyterms` | Keywords to boost (Deepgram only) | +| `--language` | Language code (default: en) | +| `--show-words` | Show individual word timings | +| `--save` | Save results to JSON in `results/` | + +## Directory Structure + +``` +evals/stt/ +├── audio/ # Audio test files +│ └── multi_speaker.m4a +├── results/ # Saved benchmark results (JSON) +├── providers/ # STT provider implementations +│ ├── base.py # Base classes +│ ├── deepgram_provider.py +│ └── speechmatics_provider.py +├── benchmark.py # Main runner script +└── README.md +``` + +## Output Example + +``` +Provider: DEEPGRAM +Duration: 45.32s +Speakers detected: 2 - ['0', '1'] + +Transcript: +Hello, welcome to the demo... + +--- Speaker Segments --- +[0.0s] Speaker 0: Hello, welcome to the demo. +[2.5s] Speaker 1: Thanks for having me. +... +``` + +## Adding New Providers + +1. Create a new file in `providers/` (e.g., `whisper_provider.py`) +2. Implement the `STTProvider` abstract class +3. Add to `providers/__init__.py` +4. Add to `benchmark.py` provider choices + +## API Documentation + +- Deepgram Diarization: https://developers.deepgram.com/docs/diarization +- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm +- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization diff --git a/evals/stt/__init__.py b/evals/stt/__init__.py new file mode 100644 index 0000000..8c12ae8 --- /dev/null +++ b/evals/stt/__init__.py @@ -0,0 +1 @@ +# STT Evaluation Benchmark diff --git a/evals/stt/audio/multi_speaker.m4a b/evals/stt/audio/multi_speaker.m4a new file mode 100644 index 0000000..d1ab1f5 Binary files /dev/null and b/evals/stt/audio/multi_speaker.m4a differ diff --git a/evals/stt/benchmark.py b/evals/stt/benchmark.py new file mode 100644 index 0000000..c1b5fed --- /dev/null +++ b/evals/stt/benchmark.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +"""STT Benchmark Runner. + +Compare speech-to-text transcription across providers with focus on: +- Speaker diarization accuracy +- Keyword/keyterm recognition +- Transcription quality + +Usage: + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat" +""" + +import argparse +import asyncio +import json +import sys +from datetime import datetime +from pathlib import Path +from typing import Any + +from evals.stt.providers import DeepgramProvider, SpeechmaticsProvider, STTProvider, TranscriptionResult + + +def get_provider(name: str) -> STTProvider: + """Get provider instance by name.""" + providers = { + "deepgram": DeepgramProvider, + "speechmatics": SpeechmaticsProvider, + } + if name not in providers: + raise ValueError(f"Unknown provider: {name}. Available: {list(providers.keys())}") + return providers[name]() + + +async def run_transcription( + provider: STTProvider, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + **kwargs: Any, +) -> TranscriptionResult: + """Run transcription with a provider.""" + print(f"\n{'='*60}") + print(f"Provider: {provider.name.upper()}") + print(f"{'='*60}") + + try: + result = await provider.transcribe( + audio_path, + diarize=diarize, + keyterms=keyterms, + **kwargs, + ) + return result + except Exception as e: + print(f"Error with {provider.name}: {e}") + raise + + +def print_result(result: TranscriptionResult, show_words: bool = False) -> None: + """Print transcription result.""" + print(f"\nDuration: {result.duration:.2f}s") + print(f"Speakers detected: {len(result.speakers)} - {result.speakers}") + print(f"\nTranscript:\n{result.transcript}") + + if result.speakers: + print(f"\n--- Speaker Segments ---") + for segment in result.get_speaker_segments(): + speaker = segment["speaker"] or "?" + text = segment["text"] + start = segment["start"] + print(f"[{start:.1f}s] Speaker {speaker}: {text}") + + if show_words: + print(f"\n--- Words ---") + for word in result.words[:50]: # First 50 words + speaker_info = f" (S{word.speaker})" if word.speaker else "" + print(f" {word.start:.2f}s: {word.word}{speaker_info} [{word.confidence:.2f}]") + if len(result.words) > 50: + print(f" ... and {len(result.words) - 50} more words") + + +def save_results( + results: list[TranscriptionResult], + output_dir: Path, + audio_name: str, +) -> Path: + """Save results to JSON file.""" + output_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = output_dir / f"{audio_name}_{timestamp}.json" + + output_data = { + "timestamp": timestamp, + "audio_file": audio_name, + "results": [r.to_dict() for r in results], + } + + with open(output_file, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"\nResults saved to: {output_file}") + return output_file + + +def compare_results(results: list[TranscriptionResult]) -> None: + """Compare results across providers.""" + if len(results) < 2: + return + + print(f"\n{'='*60}") + print("COMPARISON SUMMARY") + print(f"{'='*60}") + + print(f"\n{'Provider':<15} {'Duration':<10} {'Speakers':<10} {'Words':<10}") + print("-" * 45) + for r in results: + print(f"{r.provider:<15} {r.duration:<10.2f} {len(r.speakers):<10} {len(r.words):<10}") + + # Compare speaker counts + speaker_counts = {r.provider: len(r.speakers) for r in results} + if len(set(speaker_counts.values())) > 1: + print(f"\nNote: Providers detected different speaker counts: {speaker_counts}") + + +async def main() -> int: + parser = argparse.ArgumentParser( + description="STT Benchmark - Compare transcription providers", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize + python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram + python -m evals.stt.benchmark audio/multi_speaker.m4a --keyterms "Dograh" "API" + """, + ) + parser.add_argument( + "audio_file", + type=str, + help="Path to audio file (relative to evals/stt/ or absolute)", + ) + parser.add_argument( + "--providers", + nargs="+", + default=["deepgram", "speechmatics"], + choices=["deepgram", "speechmatics"], + help="Providers to test (default: all)", + ) + parser.add_argument( + "--diarize", + action="store_true", + help="Enable speaker diarization", + ) + parser.add_argument( + "--keyterms", + nargs="+", + help="Keywords to boost (Deepgram only)", + ) + parser.add_argument( + "--language", + default="en", + help="Language code (default: en)", + ) + parser.add_argument( + "--show-words", + action="store_true", + help="Show individual word timings", + ) + parser.add_argument( + "--save", + action="store_true", + help="Save results to JSON file", + ) + parser.add_argument( + "--output-dir", + type=str, + default="results", + help="Output directory for results (default: results)", + ) + + args = parser.parse_args() + + # Resolve audio path + script_dir = Path(__file__).parent + audio_path = Path(args.audio_file) + if not audio_path.is_absolute(): + audio_path = script_dir / audio_path + + if not audio_path.exists(): + print(f"Error: Audio file not found: {audio_path}") + return 1 + + print(f"Audio file: {audio_path}") + print(f"Providers: {args.providers}") + print(f"Diarization: {args.diarize}") + if args.keyterms: + print(f"Keyterms: {args.keyterms}") + + results: list[TranscriptionResult] = [] + + for provider_name in args.providers: + try: + provider = get_provider(provider_name) + result = await run_transcription( + provider, + audio_path, + diarize=args.diarize, + keyterms=args.keyterms, + language=args.language, + ) + print_result(result, show_words=args.show_words) + results.append(result) + except Exception as e: + print(f"\nFailed to run {provider_name}: {e}") + continue + + if len(results) > 1: + compare_results(results) + + if args.save and results: + output_dir = script_dir / args.output_dir + save_results(results, output_dir, audio_path.stem) + + return 0 + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/evals/stt/providers/__init__.py b/evals/stt/providers/__init__.py new file mode 100644 index 0000000..d73a2b8 --- /dev/null +++ b/evals/stt/providers/__init__.py @@ -0,0 +1,11 @@ +from .base import STTProvider, TranscriptionResult, Word +from .deepgram_provider import DeepgramProvider +from .speechmatics_provider import SpeechmaticsProvider + +__all__ = [ + "STTProvider", + "TranscriptionResult", + "Word", + "DeepgramProvider", + "SpeechmaticsProvider", +] diff --git a/evals/stt/providers/base.py b/evals/stt/providers/base.py new file mode 100644 index 0000000..cdb4af9 --- /dev/null +++ b/evals/stt/providers/base.py @@ -0,0 +1,123 @@ +"""Base classes for STT providers.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class Word: + """Represents a transcribed word with metadata.""" + + word: str + start: float + end: float + confidence: float + speaker: str | None = None + speaker_confidence: float | None = None + + def to_dict(self) -> dict[str, Any]: + return { + "word": self.word, + "start": self.start, + "end": self.end, + "confidence": self.confidence, + "speaker": self.speaker, + "speaker_confidence": self.speaker_confidence, + } + + +@dataclass +class TranscriptionResult: + """Result from STT transcription.""" + + provider: str + transcript: str + words: list[Word] + speakers: list[str] + duration: float + raw_response: dict[str, Any] = field(default_factory=dict) + params: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "provider": self.provider, + "transcript": self.transcript, + "words": [w.to_dict() for w in self.words], + "speakers": self.speakers, + "duration": self.duration, + "params": self.params, + } + + def get_speaker_segments(self) -> list[dict[str, Any]]: + """Get transcript segmented by speaker.""" + if not self.words: + return [] + + segments = [] + current_speaker = None + current_text = [] + segment_start = 0.0 + + for word in self.words: + if word.speaker != current_speaker: + if current_text: + segments.append( + { + "speaker": current_speaker, + "text": " ".join(current_text), + "start": segment_start, + "end": self.words[len(segments) - 1].end + if segments + else word.start, + } + ) + current_speaker = word.speaker + current_text = [word.word] + segment_start = word.start + else: + current_text.append(word.word) + + if current_text: + segments.append( + { + "speaker": current_speaker, + "text": " ".join(current_text), + "start": segment_start, + "end": self.words[-1].end if self.words else 0.0, + } + ) + + return segments + + +class STTProvider(ABC): + """Abstract base class for STT providers.""" + + @property + @abstractmethod + def name(self) -> str: + """Provider name.""" + pass + + @abstractmethod + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe audio file. + + Args: + audio_path: Path to the audio file + diarize: Enable speaker diarization + keyterms: List of keywords to boost (if supported) + **kwargs: Provider-specific parameters + + Returns: + TranscriptionResult with transcript and metadata + """ + pass diff --git a/evals/stt/providers/deepgram_provider.py b/evals/stt/providers/deepgram_provider.py new file mode 100644 index 0000000..c3b1ebd --- /dev/null +++ b/evals/stt/providers/deepgram_provider.py @@ -0,0 +1,174 @@ +"""Deepgram STT provider.""" + +import os +from pathlib import Path +from typing import Any + +import httpx + +from .base import STTProvider, TranscriptionResult, Word + + +class DeepgramProvider(STTProvider): + """Deepgram Speech-to-Text provider. + + API Docs: https://developers.deepgram.com/docs/ + + Supports: + - Speaker diarization via `diarize=true` + - Keyterm boosting via `keyterm` parameter (Nova-3 and Flux models) + """ + + API_URL = "https://api.deepgram.com/v1/listen" + + def __init__(self, api_key: str | None = None): + self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY") + if not self.api_key: + raise ValueError( + "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key." + ) + + @property + def name(self) -> str: + return "deepgram" + + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + model: str = "nova-3", + language: str = "en", + punctuate: bool = True, + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe audio using Deepgram API. + + Args: + audio_path: Path to audio file + diarize: Enable speaker diarization + keyterms: List of keywords to boost recognition + model: Deepgram model (nova-3, nova-2, etc.) + language: Language code + punctuate: Add punctuation + **kwargs: Additional Deepgram parameters + + Returns: + TranscriptionResult with transcript and speaker info + """ + params: dict[str, Any] = { + "model": model, + "language": language, + "punctuate": str(punctuate).lower(), + } + + if diarize: + params["diarize"] = "true" + + # Add keyterms (Deepgram uses repeated keyterm params) + if keyterms: + params["keyterm"] = keyterms + + # Add any extra kwargs + params.update(kwargs) + + # Read audio file + audio_data = audio_path.read_bytes() + + # Determine content type + suffix = audio_path.suffix.lower() + content_types = { + ".wav": "audio/wav", + ".mp3": "audio/mpeg", + ".m4a": "audio/mp4", + ".flac": "audio/flac", + ".ogg": "audio/ogg", + ".webm": "audio/webm", + } + content_type = content_types.get(suffix, "audio/wav") + + headers = { + "Authorization": f"Token {self.api_key}", + "Content-Type": content_type, + } + + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + self.API_URL, + params=params, + headers=headers, + content=audio_data, + ) + response.raise_for_status() + data = response.json() + + return self._parse_response(data, params) + + def _parse_response( + self, data: dict[str, Any], params: dict[str, Any] + ) -> TranscriptionResult: + """Parse Deepgram API response.""" + results = data.get("results", {}) + channels = results.get("channels", []) + + if not channels: + return TranscriptionResult( + provider=self.name, + transcript="", + words=[], + speakers=[], + duration=0.0, + raw_response=data, + params=params, + ) + + # Get first channel, first alternative + channel = channels[0] + alternatives = channel.get("alternatives", []) + if not alternatives: + return TranscriptionResult( + provider=self.name, + transcript="", + words=[], + speakers=[], + duration=0.0, + raw_response=data, + params=params, + ) + + alt = alternatives[0] + transcript = alt.get("transcript", "") + + # Parse words with speaker info + words = [] + speakers_set: set[str] = set() + + for w in alt.get("words", []): + speaker = str(w.get("speaker", "")) if "speaker" in w else None + if speaker: + speakers_set.add(speaker) + + words.append( + Word( + word=w.get("word", ""), + start=w.get("start", 0.0), + end=w.get("end", 0.0), + confidence=w.get("confidence", 0.0), + speaker=speaker, + speaker_confidence=w.get("speaker_confidence"), + ) + ) + + # Get duration from metadata + metadata = results.get("metadata", {}) + duration = metadata.get("duration", 0.0) + + return TranscriptionResult( + provider=self.name, + transcript=transcript, + words=words, + speakers=sorted(speakers_set), + duration=duration, + raw_response=data, + params=params, + ) diff --git a/evals/stt/providers/speechmatics_provider.py b/evals/stt/providers/speechmatics_provider.py new file mode 100644 index 0000000..8f6ee63 --- /dev/null +++ b/evals/stt/providers/speechmatics_provider.py @@ -0,0 +1,210 @@ +"""Speechmatics STT provider.""" + +import os +from pathlib import Path +from typing import Any + +import httpx + +from .base import STTProvider, TranscriptionResult, Word + + +class SpeechmaticsProvider(STTProvider): + """Speechmatics Speech-to-Text provider. + + API Docs: https://docs.speechmatics.com/ + + Supports: + - Speaker diarization via `diarization: "speaker"` config + - Speaker sensitivity tuning + """ + + # EU and US endpoints available + API_URL = "https://asr.api.speechmatics.com/v2/jobs" + + def __init__(self, api_key: str | None = None, region: str = "eu1"): + self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY") + if not self.api_key: + raise ValueError( + "Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key." + ) + # Set region-specific endpoint + if region == "eu1": + self.api_url = "https://eu1.asr.api.speechmatics.com/v2/jobs" + else: + self.api_url = "https://asr.api.speechmatics.com/v2/jobs" + + @property + def name(self) -> str: + return "speechmatics" + + async def transcribe( + self, + audio_path: Path, + diarize: bool = False, + keyterms: list[str] | None = None, + language: str = "en", + operating_point: str = "enhanced", + speaker_sensitivity: float | None = None, + **kwargs: Any, + ) -> TranscriptionResult: + """Transcribe audio using Speechmatics API. + + Args: + audio_path: Path to audio file + diarize: Enable speaker diarization + keyterms: Not directly supported by Speechmatics (ignored) + language: Language code + operating_point: "standard" or "enhanced" + speaker_sensitivity: 0.0-1.0, higher = more speakers detected + **kwargs: Additional config parameters + + Returns: + TranscriptionResult with transcript and speaker info + """ + # Build transcription config + transcription_config: dict[str, Any] = { + "language": language, + "operating_point": operating_point, + } + + if diarize: + transcription_config["diarization"] = "speaker" + if speaker_sensitivity is not None: + transcription_config["speaker_diarization_config"] = { + "speaker_sensitivity": speaker_sensitivity + } + + # Add any extra config + transcription_config.update(kwargs) + + config = { + "type": "transcription", + "transcription_config": transcription_config, + } + + # Store params for result + params = { + "diarize": diarize, + "language": language, + "operating_point": operating_point, + "speaker_sensitivity": speaker_sensitivity, + } + + headers = { + "Authorization": f"Bearer {self.api_key}", + } + + # Create job with multipart form + async with httpx.AsyncClient(timeout=300.0) as client: + # Submit job + with open(audio_path, "rb") as f: + files = { + "data_file": (audio_path.name, f, "audio/mpeg"), + "config": (None, str(config).replace("'", '"'), "application/json"), + } + response = await client.post( + self.api_url, + headers=headers, + files=files, + ) + response.raise_for_status() + job_data = response.json() + + job_id = job_data.get("id") + if not job_id: + raise ValueError(f"No job ID in response: {job_data}") + + # Poll for completion + result_data = await self._wait_for_job(client, job_id, headers) + + return self._parse_response(result_data, params) + + async def _wait_for_job( + self, client: httpx.AsyncClient, job_id: str, headers: dict[str, str] + ) -> dict[str, Any]: + """Poll job status until complete.""" + import asyncio + + job_url = f"{self.api_url}/{job_id}" + transcript_url = f"{job_url}/transcript?format=json-v2" + + max_attempts = 120 # 10 minutes with 5s intervals + for _ in range(max_attempts): + # Check job status + status_response = await client.get(job_url, headers=headers) + status_response.raise_for_status() + status_data = status_response.json() + + job_status = status_data.get("job", {}).get("status") + + if job_status == "done": + # Get transcript + transcript_response = await client.get(transcript_url, headers=headers) + transcript_response.raise_for_status() + return transcript_response.json() + elif job_status == "rejected": + raise ValueError(f"Job rejected: {status_data}") + elif job_status == "deleted": + raise ValueError(f"Job deleted: {status_data}") + + await asyncio.sleep(5) + + raise TimeoutError(f"Job {job_id} did not complete in time") + + def _parse_response( + self, data: dict[str, Any], params: dict[str, Any] + ) -> TranscriptionResult: + """Parse Speechmatics API response.""" + results = data.get("results", []) + + words = [] + speakers_set: set[str] = set() + transcript_parts = [] + + for item in results: + item_type = item.get("type") + alternatives = item.get("alternatives", []) + + if not alternatives: + continue + + alt = alternatives[0] + content = alt.get("content", "") + speaker = alt.get("speaker") + + if speaker: + speakers_set.add(speaker) + + if item_type == "word": + words.append( + Word( + word=content, + start=item.get("start_time", 0.0), + end=item.get("end_time", 0.0), + confidence=alt.get("confidence", 0.0), + speaker=speaker, + speaker_confidence=None, # Not provided by Speechmatics + ) + ) + transcript_parts.append(content) + elif item_type == "punctuation": + # Append punctuation to last word in transcript + if transcript_parts: + transcript_parts[-1] += content + + # Get metadata + metadata = data.get("metadata", {}) + duration = metadata.get("duration", 0.0) + + transcript = " ".join(transcript_parts) + + return TranscriptionResult( + provider=self.name, + transcript=transcript, + words=words, + speakers=sorted(speakers_set), + duration=duration, + raw_response=data, + params=params, + )