feat: add stt evals

This commit is contained in:
Abhishek Kumar 2026-01-06 13:18:13 +05:30
parent d41f696f3f
commit 6d589b7452
8 changed files with 849 additions and 0 deletions

100
evals/stt/README.md Normal file
View file

@ -0,0 +1,100 @@
# STT Evaluation Benchmark
Benchmark for comparing Speech-to-Text providers with focus on:
- **Speaker diarization** - identifying who said what
- **Keyterm boosting** - improving recognition of specific terms (Deepgram)
## Providers
| Provider | Diarization | Keyterm Boost | Notes |
|----------|-------------|---------------|-------|
| Deepgram | Yes | Yes | `diarize=true`, `keyterm` param |
| Speechmatics | Yes | No | `diarization: "speaker"` config |
## Setup
```bash
# Install dependencies (httpx is required)
pip install httpx
# Set API keys
export DEEPGRAM_API_KEY="your-key"
export SPEECHMATICS_API_KEY="your-key"
```
## Usage
Run from the project root directory:
```bash
# Test both providers with diarization
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
# Test only Deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
# Test with keyterm boosting (Deepgram only)
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
# Show word-level timings
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words
# Save results to JSON
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save
```
## CLI Options
| Option | Description |
|--------|-------------|
| `audio_file` | Path to audio file (relative to evals/stt/ or absolute) |
| `--providers` | Providers to test: `deepgram`, `speechmatics` (default: both) |
| `--diarize` | Enable speaker diarization |
| `--keyterms` | Keywords to boost (Deepgram only) |
| `--language` | Language code (default: en) |
| `--show-words` | Show individual word timings |
| `--save` | Save results to JSON in `results/` |
## Directory Structure
```
evals/stt/
├── audio/ # Audio test files
│ └── multi_speaker.m4a
├── results/ # Saved benchmark results (JSON)
├── providers/ # STT provider implementations
│ ├── base.py # Base classes
│ ├── deepgram_provider.py
│ └── speechmatics_provider.py
├── benchmark.py # Main runner script
└── README.md
```
## Output Example
```
Provider: DEEPGRAM
Duration: 45.32s
Speakers detected: 2 - ['0', '1']
Transcript:
Hello, welcome to the demo...
--- Speaker Segments ---
[0.0s] Speaker 0: Hello, welcome to the demo.
[2.5s] Speaker 1: Thanks for having me.
...
```
## Adding New Providers
1. Create a new file in `providers/` (e.g., `whisper_provider.py`)
2. Implement the `STTProvider` abstract class
3. Add to `providers/__init__.py`
4. Add to `benchmark.py` provider choices
## API Documentation
- Deepgram Diarization: https://developers.deepgram.com/docs/diarization
- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm
- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization

1
evals/stt/__init__.py Normal file
View file

@ -0,0 +1 @@
# STT Evaluation Benchmark

Binary file not shown.

230
evals/stt/benchmark.py Normal file
View file

@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""STT Benchmark Runner.
Compare speech-to-text transcription across providers with focus on:
- Speaker diarization accuracy
- Keyword/keyterm recognition
- Transcription quality
Usage:
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
"""
import argparse
import asyncio
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any
from evals.stt.providers import DeepgramProvider, SpeechmaticsProvider, STTProvider, TranscriptionResult
def get_provider(name: str) -> STTProvider:
"""Get provider instance by name."""
providers = {
"deepgram": DeepgramProvider,
"speechmatics": SpeechmaticsProvider,
}
if name not in providers:
raise ValueError(f"Unknown provider: {name}. Available: {list(providers.keys())}")
return providers[name]()
async def run_transcription(
provider: STTProvider,
audio_path: Path,
diarize: bool = False,
keyterms: list[str] | None = None,
**kwargs: Any,
) -> TranscriptionResult:
"""Run transcription with a provider."""
print(f"\n{'='*60}")
print(f"Provider: {provider.name.upper()}")
print(f"{'='*60}")
try:
result = await provider.transcribe(
audio_path,
diarize=diarize,
keyterms=keyterms,
**kwargs,
)
return result
except Exception as e:
print(f"Error with {provider.name}: {e}")
raise
def print_result(result: TranscriptionResult, show_words: bool = False) -> None:
"""Print transcription result."""
print(f"\nDuration: {result.duration:.2f}s")
print(f"Speakers detected: {len(result.speakers)} - {result.speakers}")
print(f"\nTranscript:\n{result.transcript}")
if result.speakers:
print(f"\n--- Speaker Segments ---")
for segment in result.get_speaker_segments():
speaker = segment["speaker"] or "?"
text = segment["text"]
start = segment["start"]
print(f"[{start:.1f}s] Speaker {speaker}: {text}")
if show_words:
print(f"\n--- Words ---")
for word in result.words[:50]: # First 50 words
speaker_info = f" (S{word.speaker})" if word.speaker else ""
print(f" {word.start:.2f}s: {word.word}{speaker_info} [{word.confidence:.2f}]")
if len(result.words) > 50:
print(f" ... and {len(result.words) - 50} more words")
def save_results(
results: list[TranscriptionResult],
output_dir: Path,
audio_name: str,
) -> Path:
"""Save results to JSON file."""
output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = output_dir / f"{audio_name}_{timestamp}.json"
output_data = {
"timestamp": timestamp,
"audio_file": audio_name,
"results": [r.to_dict() for r in results],
}
with open(output_file, "w") as f:
json.dump(output_data, f, indent=2)
print(f"\nResults saved to: {output_file}")
return output_file
def compare_results(results: list[TranscriptionResult]) -> None:
"""Compare results across providers."""
if len(results) < 2:
return
print(f"\n{'='*60}")
print("COMPARISON SUMMARY")
print(f"{'='*60}")
print(f"\n{'Provider':<15} {'Duration':<10} {'Speakers':<10} {'Words':<10}")
print("-" * 45)
for r in results:
print(f"{r.provider:<15} {r.duration:<10.2f} {len(r.speakers):<10} {len(r.words):<10}")
# Compare speaker counts
speaker_counts = {r.provider: len(r.speakers) for r in results}
if len(set(speaker_counts.values())) > 1:
print(f"\nNote: Providers detected different speaker counts: {speaker_counts}")
async def main() -> int:
parser = argparse.ArgumentParser(
description="STT Benchmark - Compare transcription providers",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --keyterms "Dograh" "API"
""",
)
parser.add_argument(
"audio_file",
type=str,
help="Path to audio file (relative to evals/stt/ or absolute)",
)
parser.add_argument(
"--providers",
nargs="+",
default=["deepgram", "speechmatics"],
choices=["deepgram", "speechmatics"],
help="Providers to test (default: all)",
)
parser.add_argument(
"--diarize",
action="store_true",
help="Enable speaker diarization",
)
parser.add_argument(
"--keyterms",
nargs="+",
help="Keywords to boost (Deepgram only)",
)
parser.add_argument(
"--language",
default="en",
help="Language code (default: en)",
)
parser.add_argument(
"--show-words",
action="store_true",
help="Show individual word timings",
)
parser.add_argument(
"--save",
action="store_true",
help="Save results to JSON file",
)
parser.add_argument(
"--output-dir",
type=str,
default="results",
help="Output directory for results (default: results)",
)
args = parser.parse_args()
# Resolve audio path
script_dir = Path(__file__).parent
audio_path = Path(args.audio_file)
if not audio_path.is_absolute():
audio_path = script_dir / audio_path
if not audio_path.exists():
print(f"Error: Audio file not found: {audio_path}")
return 1
print(f"Audio file: {audio_path}")
print(f"Providers: {args.providers}")
print(f"Diarization: {args.diarize}")
if args.keyterms:
print(f"Keyterms: {args.keyterms}")
results: list[TranscriptionResult] = []
for provider_name in args.providers:
try:
provider = get_provider(provider_name)
result = await run_transcription(
provider,
audio_path,
diarize=args.diarize,
keyterms=args.keyterms,
language=args.language,
)
print_result(result, show_words=args.show_words)
results.append(result)
except Exception as e:
print(f"\nFailed to run {provider_name}: {e}")
continue
if len(results) > 1:
compare_results(results)
if args.save and results:
output_dir = script_dir / args.output_dir
save_results(results, output_dir, audio_path.stem)
return 0
if __name__ == "__main__":
sys.exit(asyncio.run(main()))

View file

@ -0,0 +1,11 @@
from .base import STTProvider, TranscriptionResult, Word
from .deepgram_provider import DeepgramProvider
from .speechmatics_provider import SpeechmaticsProvider
__all__ = [
"STTProvider",
"TranscriptionResult",
"Word",
"DeepgramProvider",
"SpeechmaticsProvider",
]

123
evals/stt/providers/base.py Normal file
View file

@ -0,0 +1,123 @@
"""Base classes for STT providers."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class Word:
"""Represents a transcribed word with metadata."""
word: str
start: float
end: float
confidence: float
speaker: str | None = None
speaker_confidence: float | None = None
def to_dict(self) -> dict[str, Any]:
return {
"word": self.word,
"start": self.start,
"end": self.end,
"confidence": self.confidence,
"speaker": self.speaker,
"speaker_confidence": self.speaker_confidence,
}
@dataclass
class TranscriptionResult:
"""Result from STT transcription."""
provider: str
transcript: str
words: list[Word]
speakers: list[str]
duration: float
raw_response: dict[str, Any] = field(default_factory=dict)
params: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
return {
"provider": self.provider,
"transcript": self.transcript,
"words": [w.to_dict() for w in self.words],
"speakers": self.speakers,
"duration": self.duration,
"params": self.params,
}
def get_speaker_segments(self) -> list[dict[str, Any]]:
"""Get transcript segmented by speaker."""
if not self.words:
return []
segments = []
current_speaker = None
current_text = []
segment_start = 0.0
for word in self.words:
if word.speaker != current_speaker:
if current_text:
segments.append(
{
"speaker": current_speaker,
"text": " ".join(current_text),
"start": segment_start,
"end": self.words[len(segments) - 1].end
if segments
else word.start,
}
)
current_speaker = word.speaker
current_text = [word.word]
segment_start = word.start
else:
current_text.append(word.word)
if current_text:
segments.append(
{
"speaker": current_speaker,
"text": " ".join(current_text),
"start": segment_start,
"end": self.words[-1].end if self.words else 0.0,
}
)
return segments
class STTProvider(ABC):
"""Abstract base class for STT providers."""
@property
@abstractmethod
def name(self) -> str:
"""Provider name."""
pass
@abstractmethod
async def transcribe(
self,
audio_path: Path,
diarize: bool = False,
keyterms: list[str] | None = None,
**kwargs: Any,
) -> TranscriptionResult:
"""Transcribe audio file.
Args:
audio_path: Path to the audio file
diarize: Enable speaker diarization
keyterms: List of keywords to boost (if supported)
**kwargs: Provider-specific parameters
Returns:
TranscriptionResult with transcript and metadata
"""
pass

View file

@ -0,0 +1,174 @@
"""Deepgram STT provider."""
import os
from pathlib import Path
from typing import Any
import httpx
from .base import STTProvider, TranscriptionResult, Word
class DeepgramProvider(STTProvider):
"""Deepgram Speech-to-Text provider.
API Docs: https://developers.deepgram.com/docs/
Supports:
- Speaker diarization via `diarize=true`
- Keyterm boosting via `keyterm` parameter (Nova-3 and Flux models)
"""
API_URL = "https://api.deepgram.com/v1/listen"
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
if not self.api_key:
raise ValueError(
"Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
)
@property
def name(self) -> str:
return "deepgram"
async def transcribe(
self,
audio_path: Path,
diarize: bool = False,
keyterms: list[str] | None = None,
model: str = "nova-3",
language: str = "en",
punctuate: bool = True,
**kwargs: Any,
) -> TranscriptionResult:
"""Transcribe audio using Deepgram API.
Args:
audio_path: Path to audio file
diarize: Enable speaker diarization
keyterms: List of keywords to boost recognition
model: Deepgram model (nova-3, nova-2, etc.)
language: Language code
punctuate: Add punctuation
**kwargs: Additional Deepgram parameters
Returns:
TranscriptionResult with transcript and speaker info
"""
params: dict[str, Any] = {
"model": model,
"language": language,
"punctuate": str(punctuate).lower(),
}
if diarize:
params["diarize"] = "true"
# Add keyterms (Deepgram uses repeated keyterm params)
if keyterms:
params["keyterm"] = keyterms
# Add any extra kwargs
params.update(kwargs)
# Read audio file
audio_data = audio_path.read_bytes()
# Determine content type
suffix = audio_path.suffix.lower()
content_types = {
".wav": "audio/wav",
".mp3": "audio/mpeg",
".m4a": "audio/mp4",
".flac": "audio/flac",
".ogg": "audio/ogg",
".webm": "audio/webm",
}
content_type = content_types.get(suffix, "audio/wav")
headers = {
"Authorization": f"Token {self.api_key}",
"Content-Type": content_type,
}
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
self.API_URL,
params=params,
headers=headers,
content=audio_data,
)
response.raise_for_status()
data = response.json()
return self._parse_response(data, params)
def _parse_response(
self, data: dict[str, Any], params: dict[str, Any]
) -> TranscriptionResult:
"""Parse Deepgram API response."""
results = data.get("results", {})
channels = results.get("channels", [])
if not channels:
return TranscriptionResult(
provider=self.name,
transcript="",
words=[],
speakers=[],
duration=0.0,
raw_response=data,
params=params,
)
# Get first channel, first alternative
channel = channels[0]
alternatives = channel.get("alternatives", [])
if not alternatives:
return TranscriptionResult(
provider=self.name,
transcript="",
words=[],
speakers=[],
duration=0.0,
raw_response=data,
params=params,
)
alt = alternatives[0]
transcript = alt.get("transcript", "")
# Parse words with speaker info
words = []
speakers_set: set[str] = set()
for w in alt.get("words", []):
speaker = str(w.get("speaker", "")) if "speaker" in w else None
if speaker:
speakers_set.add(speaker)
words.append(
Word(
word=w.get("word", ""),
start=w.get("start", 0.0),
end=w.get("end", 0.0),
confidence=w.get("confidence", 0.0),
speaker=speaker,
speaker_confidence=w.get("speaker_confidence"),
)
)
# Get duration from metadata
metadata = results.get("metadata", {})
duration = metadata.get("duration", 0.0)
return TranscriptionResult(
provider=self.name,
transcript=transcript,
words=words,
speakers=sorted(speakers_set),
duration=duration,
raw_response=data,
params=params,
)

View file

@ -0,0 +1,210 @@
"""Speechmatics STT provider."""
import os
from pathlib import Path
from typing import Any
import httpx
from .base import STTProvider, TranscriptionResult, Word
class SpeechmaticsProvider(STTProvider):
"""Speechmatics Speech-to-Text provider.
API Docs: https://docs.speechmatics.com/
Supports:
- Speaker diarization via `diarization: "speaker"` config
- Speaker sensitivity tuning
"""
# EU and US endpoints available
API_URL = "https://asr.api.speechmatics.com/v2/jobs"
def __init__(self, api_key: str | None = None, region: str = "eu1"):
self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY")
if not self.api_key:
raise ValueError(
"Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key."
)
# Set region-specific endpoint
if region == "eu1":
self.api_url = "https://eu1.asr.api.speechmatics.com/v2/jobs"
else:
self.api_url = "https://asr.api.speechmatics.com/v2/jobs"
@property
def name(self) -> str:
return "speechmatics"
async def transcribe(
self,
audio_path: Path,
diarize: bool = False,
keyterms: list[str] | None = None,
language: str = "en",
operating_point: str = "enhanced",
speaker_sensitivity: float | None = None,
**kwargs: Any,
) -> TranscriptionResult:
"""Transcribe audio using Speechmatics API.
Args:
audio_path: Path to audio file
diarize: Enable speaker diarization
keyterms: Not directly supported by Speechmatics (ignored)
language: Language code
operating_point: "standard" or "enhanced"
speaker_sensitivity: 0.0-1.0, higher = more speakers detected
**kwargs: Additional config parameters
Returns:
TranscriptionResult with transcript and speaker info
"""
# Build transcription config
transcription_config: dict[str, Any] = {
"language": language,
"operating_point": operating_point,
}
if diarize:
transcription_config["diarization"] = "speaker"
if speaker_sensitivity is not None:
transcription_config["speaker_diarization_config"] = {
"speaker_sensitivity": speaker_sensitivity
}
# Add any extra config
transcription_config.update(kwargs)
config = {
"type": "transcription",
"transcription_config": transcription_config,
}
# Store params for result
params = {
"diarize": diarize,
"language": language,
"operating_point": operating_point,
"speaker_sensitivity": speaker_sensitivity,
}
headers = {
"Authorization": f"Bearer {self.api_key}",
}
# Create job with multipart form
async with httpx.AsyncClient(timeout=300.0) as client:
# Submit job
with open(audio_path, "rb") as f:
files = {
"data_file": (audio_path.name, f, "audio/mpeg"),
"config": (None, str(config).replace("'", '"'), "application/json"),
}
response = await client.post(
self.api_url,
headers=headers,
files=files,
)
response.raise_for_status()
job_data = response.json()
job_id = job_data.get("id")
if not job_id:
raise ValueError(f"No job ID in response: {job_data}")
# Poll for completion
result_data = await self._wait_for_job(client, job_id, headers)
return self._parse_response(result_data, params)
async def _wait_for_job(
self, client: httpx.AsyncClient, job_id: str, headers: dict[str, str]
) -> dict[str, Any]:
"""Poll job status until complete."""
import asyncio
job_url = f"{self.api_url}/{job_id}"
transcript_url = f"{job_url}/transcript?format=json-v2"
max_attempts = 120 # 10 minutes with 5s intervals
for _ in range(max_attempts):
# Check job status
status_response = await client.get(job_url, headers=headers)
status_response.raise_for_status()
status_data = status_response.json()
job_status = status_data.get("job", {}).get("status")
if job_status == "done":
# Get transcript
transcript_response = await client.get(transcript_url, headers=headers)
transcript_response.raise_for_status()
return transcript_response.json()
elif job_status == "rejected":
raise ValueError(f"Job rejected: {status_data}")
elif job_status == "deleted":
raise ValueError(f"Job deleted: {status_data}")
await asyncio.sleep(5)
raise TimeoutError(f"Job {job_id} did not complete in time")
def _parse_response(
self, data: dict[str, Any], params: dict[str, Any]
) -> TranscriptionResult:
"""Parse Speechmatics API response."""
results = data.get("results", [])
words = []
speakers_set: set[str] = set()
transcript_parts = []
for item in results:
item_type = item.get("type")
alternatives = item.get("alternatives", [])
if not alternatives:
continue
alt = alternatives[0]
content = alt.get("content", "")
speaker = alt.get("speaker")
if speaker:
speakers_set.add(speaker)
if item_type == "word":
words.append(
Word(
word=content,
start=item.get("start_time", 0.0),
end=item.get("end_time", 0.0),
confidence=alt.get("confidence", 0.0),
speaker=speaker,
speaker_confidence=None, # Not provided by Speechmatics
)
)
transcript_parts.append(content)
elif item_type == "punctuation":
# Append punctuation to last word in transcript
if transcript_parts:
transcript_parts[-1] += content
# Get metadata
metadata = data.get("metadata", {})
duration = metadata.get("duration", 0.0)
transcript = " ".join(transcript_parts)
return TranscriptionResult(
provider=self.name,
transcript=transcript,
words=words,
speakers=sorted(speakers_set),
duration=duration,
raw_response=data,
params=params,
)