mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add stt evals
This commit is contained in:
parent
d41f696f3f
commit
6d589b7452
8 changed files with 849 additions and 0 deletions
100
evals/stt/README.md
Normal file
100
evals/stt/README.md
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
# STT Evaluation Benchmark
|
||||
|
||||
Benchmark for comparing Speech-to-Text providers with focus on:
|
||||
- **Speaker diarization** - identifying who said what
|
||||
- **Keyterm boosting** - improving recognition of specific terms (Deepgram)
|
||||
|
||||
## Providers
|
||||
|
||||
| Provider | Diarization | Keyterm Boost | Notes |
|
||||
|----------|-------------|---------------|-------|
|
||||
| Deepgram | Yes | Yes | `diarize=true`, `keyterm` param |
|
||||
| Speechmatics | Yes | No | `diarization: "speaker"` config |
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install dependencies (httpx is required)
|
||||
pip install httpx
|
||||
|
||||
# Set API keys
|
||||
export DEEPGRAM_API_KEY="your-key"
|
||||
export SPEECHMATICS_API_KEY="your-key"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Run from the project root directory:
|
||||
|
||||
```bash
|
||||
# Test both providers with diarization
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
|
||||
|
||||
# Test only Deepgram
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
|
||||
|
||||
# Test with keyterm boosting (Deepgram only)
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
|
||||
|
||||
# Show word-level timings
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words
|
||||
|
||||
# Save results to JSON
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save
|
||||
```
|
||||
|
||||
## CLI Options
|
||||
|
||||
| Option | Description |
|
||||
|--------|-------------|
|
||||
| `audio_file` | Path to audio file (relative to evals/stt/ or absolute) |
|
||||
| `--providers` | Providers to test: `deepgram`, `speechmatics` (default: both) |
|
||||
| `--diarize` | Enable speaker diarization |
|
||||
| `--keyterms` | Keywords to boost (Deepgram only) |
|
||||
| `--language` | Language code (default: en) |
|
||||
| `--show-words` | Show individual word timings |
|
||||
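| `--save` | Save results to JSON (see `--output-dir`) |
| `--output-dir` | Directory for saved results, relative to `evals/stt/` (default: `results`) |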
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
evals/stt/
|
||||
├── audio/ # Audio test files
|
||||
│ └── multi_speaker.m4a
|
||||
├── results/ # Saved benchmark results (JSON)
|
||||
├── providers/ # STT provider implementations
|
||||
│ ├── base.py # Base classes
|
||||
│ ├── deepgram_provider.py
|
||||
│ └── speechmatics_provider.py
|
||||
├── benchmark.py # Main runner script
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Output Example
|
||||
|
||||
```
|
||||
Provider: DEEPGRAM
|
||||
Duration: 45.32s
|
||||
Speakers detected: 2 - ['0', '1']
|
||||
|
||||
Transcript:
|
||||
Hello, welcome to the demo...
|
||||
|
||||
--- Speaker Segments ---
|
||||
[0.0s] Speaker 0: Hello, welcome to the demo.
|
||||
[2.5s] Speaker 1: Thanks for having me.
|
||||
...
|
||||
```
|
||||
|
||||
## Adding New Providers
|
||||
|
||||
1. Create a new file in `providers/` (e.g., `whisper_provider.py`)
|
||||
2. Implement the `STTProvider` abstract class
|
||||
3. Add to `providers/__init__.py`
|
||||
4. Register it in `benchmark.py`: add it to the `get_provider` mapping and to the `--providers` choices and default
|
||||
|
||||
## API Documentation
|
||||
|
||||
- Deepgram Diarization: https://developers.deepgram.com/docs/diarization
|
||||
- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm
|
||||
- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization
|
||||
1
evals/stt/__init__.py
Normal file
1
evals/stt/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# STT Evaluation Benchmark
|
||||
BIN
evals/stt/audio/multi_speaker.m4a
Normal file
BIN
evals/stt/audio/multi_speaker.m4a
Normal file
Binary file not shown.
230
evals/stt/benchmark.py
Normal file
230
evals/stt/benchmark.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
#!/usr/bin/env python3
|
||||
"""STT Benchmark Runner.
|
||||
|
||||
Compare speech-to-text transcription across providers with focus on:
|
||||
- Speaker diarization accuracy
|
||||
- Keyword/keyterm recognition
|
||||
- Transcription quality
|
||||
|
||||
Usage:
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from evals.stt.providers import DeepgramProvider, SpeechmaticsProvider, STTProvider, TranscriptionResult
|
||||
|
||||
|
||||
def get_provider(name: str) -> STTProvider:
    """Instantiate the STT provider registered under ``name``.

    Args:
        name: Provider key, either ``"deepgram"`` or ``"speechmatics"``.

    Returns:
        A newly constructed provider instance.

    Raises:
        ValueError: If ``name`` is not a registered provider key.
    """
    registry = {
        "deepgram": DeepgramProvider,
        "speechmatics": SpeechmaticsProvider,
    }
    provider_cls = registry.get(name)
    if provider_cls is None:
        raise ValueError(f"Unknown provider: {name}. Available: {list(registry.keys())}")
    return provider_cls()
|
||||
|
||||
|
||||
async def run_transcription(
    provider: STTProvider,
    audio_path: Path,
    diarize: bool = False,
    keyterms: list[str] | None = None,
    **kwargs: Any,
) -> TranscriptionResult:
    """Print a banner for ``provider`` and transcribe ``audio_path`` with it.

    Args:
        provider: Provider used for the transcription.
        audio_path: Audio file to transcribe.
        diarize: Forwarded to the provider to request speaker diarization.
        keyterms: Forwarded to the provider as terms to boost.
        **kwargs: Extra keyword arguments forwarded to ``provider.transcribe``.

    Returns:
        Whatever ``provider.transcribe`` returns.

    Raises:
        Exception: Any provider error is printed and then re-raised unchanged.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Provider: {provider.name.upper()}")
    print(f"{banner}")

    try:
        return await provider.transcribe(
            audio_path,
            diarize=diarize,
            keyterms=keyterms,
            **kwargs,
        )
    except Exception as e:
        # Report which provider failed, then let the caller decide what to do.
        print(f"Error with {provider.name}: {e}")
        raise
|
||||
|
||||
|
||||
def print_result(result: TranscriptionResult, show_words: bool = False) -> None:
    """Write a human-readable report of one transcription to stdout.

    Args:
        result: Transcription to display.
        show_words: When true, also list up to the first 50 words with their
            start time, speaker tag (if any) and confidence.
    """
    speakers = result.speakers
    print(f"\nDuration: {result.duration:.2f}s")
    print(f"Speakers detected: {len(speakers)} - {speakers}")
    print(f"\nTranscript:\n{result.transcript}")

    # Per-speaker segments are only shown when at least one speaker was found.
    if speakers:
        print("\n--- Speaker Segments ---")
        for seg in result.get_speaker_segments():
            label = seg["speaker"] or "?"
            print(f"[{seg['start']:.1f}s] Speaker {label}: {seg['text']}")

    if not show_words:
        return

    preview_limit = 50
    all_words = result.words
    print("\n--- Words ---")
    for w in all_words[:preview_limit]:  # First 50 words
        tag = f" (S{w.speaker})" if w.speaker else ""
        print(f" {w.start:.2f}s: {w.word}{tag} [{w.confidence:.2f}]")

    hidden = len(all_words) - preview_limit
    if hidden > 0:
        print(f" ... and {hidden} more words")
|
||||
|
||||
|
||||
def save_results(
    results: list[TranscriptionResult],
    output_dir: Path,
    audio_name: str,
) -> Path:
    """Dump all results into a timestamped JSON file under ``output_dir``.

    Args:
        results: Transcription results; each is serialised via ``to_dict()``.
        output_dir: Target directory, created (with parents) if missing.
        audio_name: Name used as the file-name prefix and stored in the file.

    Returns:
        Path of the JSON file that was written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    destination = output_dir / f"{audio_name}_{stamp}.json"

    payload = {
        "timestamp": stamp,
        "audio_file": audio_name,
        "results": [item.to_dict() for item in results],
    }

    with open(destination, "w") as handle:
        json.dump(payload, handle, indent=2)

    print(f"\nResults saved to: {destination}")
    return destination
|
||||
|
||||
|
||||
def compare_results(results: list[TranscriptionResult]) -> None:
    """Print a side-by-side summary table for two or more results.

    Does nothing when fewer than two results are given. After the table, a
    note is printed if the providers disagree on the number of speakers.

    Args:
        results: Transcription results, one per provider.
    """
    if len(results) < 2:
        return

    rule = "=" * 60
    print(f"\n{rule}")
    print("COMPARISON SUMMARY")
    print(f"{rule}")

    print(f"\n{'Provider':<15} {'Duration':<10} {'Speakers':<10} {'Words':<10}")
    print("-" * 45)

    # Collect the speaker counts while printing the rows.
    speaker_counts = {}
    for entry in results:
        n_speakers = len(entry.speakers)
        print(f"{entry.provider:<15} {entry.duration:<10.2f} {n_speakers:<10} {len(entry.words):<10}")
        speaker_counts[entry.provider] = n_speakers

    distinct_counts = set(speaker_counts.values())
    if len(distinct_counts) > 1:
        print(f"\nNote: Providers detected different speaker counts: {speaker_counts}")
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark runner."""
    parser = argparse.ArgumentParser(
        description="STT Benchmark - Compare transcription providers",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --keyterms "Dograh" "API"
""",
    )
    parser.add_argument(
        "audio_file",
        type=str,
        help="Path to audio file (relative to evals/stt/ or absolute)",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=["deepgram", "speechmatics"],
        choices=["deepgram", "speechmatics"],
        help="Providers to test (default: all)",
    )
    parser.add_argument(
        "--diarize",
        action="store_true",
        help="Enable speaker diarization",
    )
    parser.add_argument(
        "--keyterms",
        nargs="+",
        help="Keywords to boost (Deepgram only)",
    )
    parser.add_argument(
        "--language",
        default="en",
        help="Language code (default: en)",
    )
    parser.add_argument(
        "--show-words",
        action="store_true",
        help="Show individual word timings",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save results to JSON file",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="results",
        help="Output directory for results (default: results)",
    )
    return parser


async def main() -> int:
    """Run the benchmark CLI.

    Parses ``sys.argv``, runs every requested provider in turn on the audio
    file, prints each result, prints a comparison when more than one provider
    succeeded, and optionally saves the results as JSON.

    Relative audio paths and the output directory are resolved against the
    directory containing this script, not the current working directory.

    Returns:
        Process exit code: ``0`` when at least one provider produced a
        result, ``1`` when the audio file is missing or every provider failed.
    """
    args = _build_parser().parse_args()

    # Resolve audio path
    script_dir = Path(__file__).parent
    audio_path = Path(args.audio_file)
    if not audio_path.is_absolute():
        audio_path = script_dir / audio_path

    if not audio_path.exists():
        print(f"Error: Audio file not found: {audio_path}")
        return 1

    print(f"Audio file: {audio_path}")
    print(f"Providers: {args.providers}")
    print(f"Diarization: {args.diarize}")
    if args.keyterms:
        print(f"Keyterms: {args.keyterms}")

    results: list[TranscriptionResult] = []

    for provider_name in args.providers:
        # One failing provider (missing API key, HTTP error, ...) must not
        # stop the remaining providers from running.
        try:
            provider = get_provider(provider_name)
            result = await run_transcription(
                provider,
                audio_path,
                diarize=args.diarize,
                keyterms=args.keyterms,
                language=args.language,
            )
            print_result(result, show_words=args.show_words)
            results.append(result)
        except Exception as e:
            print(f"\nFailed to run {provider_name}: {e}")
            continue

    if not results:
        # Previously this path returned 0, so scripts could not detect a
        # run in which nothing was transcribed.
        print("\nError: no provider produced a result")
        return 1

    if len(results) > 1:
        compare_results(results)

    if args.save:
        output_dir = script_dir / args.output_dir
        save_results(results, output_dir, audio_path.stem)

    return 0
|
||||
|
||||
|
||||
# Script entry point: run the async CLI and use its return value as the
# process exit code.
if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
|
||||
11
evals/stt/providers/__init__.py
Normal file
11
evals/stt/providers/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from .base import STTProvider, TranscriptionResult, Word
|
||||
from .deepgram_provider import DeepgramProvider
|
||||
from .speechmatics_provider import SpeechmaticsProvider
|
||||
|
||||
# Names re-exported from the submodules imported above; this is what
# `from evals.stt.providers import *` exposes.
__all__ = [
    "STTProvider",
    "TranscriptionResult",
    "Word",
    "DeepgramProvider",
    "SpeechmaticsProvider",
]
|
||||
123
evals/stt/providers/base.py
Normal file
123
evals/stt/providers/base.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Base classes for STT providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class Word:
|
||||
"""Represents a transcribed word with metadata."""
|
||||
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
confidence: float
|
||||
speaker: str | None = None
|
||||
speaker_confidence: float | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"word": self.word,
|
||||
"start": self.start,
|
||||
"end": self.end,
|
||||
"confidence": self.confidence,
|
||||
"speaker": self.speaker,
|
||||
"speaker_confidence": self.speaker_confidence,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionResult:
|
||||
"""Result from STT transcription."""
|
||||
|
||||
provider: str
|
||||
transcript: str
|
||||
words: list[Word]
|
||||
speakers: list[str]
|
||||
duration: float
|
||||
raw_response: dict[str, Any] = field(default_factory=dict)
|
||||
params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"provider": self.provider,
|
||||
"transcript": self.transcript,
|
||||
"words": [w.to_dict() for w in self.words],
|
||||
"speakers": self.speakers,
|
||||
"duration": self.duration,
|
||||
"params": self.params,
|
||||
}
|
||||
|
||||
def get_speaker_segments(self) -> list[dict[str, Any]]:
|
||||
"""Get transcript segmented by speaker."""
|
||||
if not self.words:
|
||||
return []
|
||||
|
||||
segments = []
|
||||
current_speaker = None
|
||||
current_text = []
|
||||
segment_start = 0.0
|
||||
|
||||
for word in self.words:
|
||||
if word.speaker != current_speaker:
|
||||
if current_text:
|
||||
segments.append(
|
||||
{
|
||||
"speaker": current_speaker,
|
||||
"text": " ".join(current_text),
|
||||
"start": segment_start,
|
||||
"end": self.words[len(segments) - 1].end
|
||||
if segments
|
||||
else word.start,
|
||||
}
|
||||
)
|
||||
current_speaker = word.speaker
|
||||
current_text = [word.word]
|
||||
segment_start = word.start
|
||||
else:
|
||||
current_text.append(word.word)
|
||||
|
||||
if current_text:
|
||||
segments.append(
|
||||
{
|
||||
"speaker": current_speaker,
|
||||
"text": " ".join(current_text),
|
||||
"start": segment_start,
|
||||
"end": self.words[-1].end if self.words else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
class STTProvider(ABC):
    """Interface that every speech-to-text backend has to implement."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of the provider."""

    @abstractmethod
    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe the audio file at ``audio_path``.

        Args:
            audio_path: Location of the audio file.
            diarize: Whether speaker diarization is requested.
            keyterms: Terms whose recognition should be boosted, where the
                provider supports it.
            **kwargs: Parameters specific to the concrete provider.

        Returns:
            A TranscriptionResult holding the transcript and its metadata.
        """
|
||||
174
evals/stt/providers/deepgram_provider.py
Normal file
174
evals/stt/providers/deepgram_provider.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""Deepgram STT provider."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import STTProvider, TranscriptionResult, Word
|
||||
|
||||
|
||||
class DeepgramProvider(STTProvider):
    """Deepgram Speech-to-Text provider.

    API Docs: https://developers.deepgram.com/docs/

    Supports:
    - Speaker diarization via `diarize=true`
    - Keyterm boosting via `keyterm` parameter (Nova-3 and Flux models)
    """

    API_URL = "https://api.deepgram.com/v1/listen"

    # Content-Type header sent per audio file extension; extensions not listed
    # here fall back to "audio/wav".
    _CONTENT_TYPES = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".m4a": "audio/mp4",
        ".flac": "audio/flac",
        ".ogg": "audio/ogg",
        ".webm": "audio/webm",
    }

    def __init__(self, api_key: str | None = None):
        """Create the provider.

        Args:
            api_key: Deepgram API key; when omitted, the ``DEEPGRAM_API_KEY``
                environment variable is used.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        """Provider name."""
        return "deepgram"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        model: str = "nova-3",
        language: str = "en",
        punctuate: bool = True,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram API.

        Args:
            audio_path: Path to audio file
            diarize: Enable speaker diarization
            keyterms: List of keywords to boost recognition
            model: Deepgram model (nova-3, nova-2, etc.)
            language: Language code
            punctuate: Add punctuation
            **kwargs: Additional Deepgram parameters; these are merged last
                and therefore override the query parameters built here

        Returns:
            TranscriptionResult with transcript and speaker info

        Raises:
            httpx.HTTPStatusError: If Deepgram answers with an error status.
        """
        params: dict[str, Any] = {
            "model": model,
            "language": language,
            "punctuate": str(punctuate).lower(),
        }

        if diarize:
            params["diarize"] = "true"

        # Add keyterms (Deepgram uses repeated keyterm params)
        if keyterms:
            params["keyterm"] = keyterms

        # Add any extra kwargs
        params.update(kwargs)

        # The whole file is sent as the raw request body.
        audio_data = audio_path.read_bytes()
        content_type = self._CONTENT_TYPES.get(audio_path.suffix.lower(), "audio/wav")

        headers = {
            "Authorization": f"Token {self.api_key}",
            "Content-Type": content_type,
        }

        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                self.API_URL,
                params=params,
                headers=headers,
                content=audio_data,
            )
            response.raise_for_status()
            data = response.json()

        return self._parse_response(data, params)

    def _parse_response(
        self, data: dict[str, Any], params: dict[str, Any]
    ) -> TranscriptionResult:
        """Parse Deepgram API response.

        Only the first alternative of the first channel is used.

        Args:
            data: Decoded JSON body of the Deepgram response.
            params: Query parameters that were sent, stored on the result.

        Returns:
            The parsed result; transcript and words are empty when the
            response contains no channel or no alternative.
        """
        results = data.get("results", {})
        channels = results.get("channels", [])
        alternatives = channels[0].get("alternatives", []) if channels else []

        # Deepgram reports the audio duration in the top-level "metadata"
        # object of the response. The lookup under "results" is kept only as a
        # fallback for payloads shaped the way this parser originally assumed.
        metadata = data.get("metadata") or results.get("metadata") or {}
        duration = metadata.get("duration", 0.0)

        if not alternatives:
            return TranscriptionResult(
                provider=self.name,
                transcript="",
                words=[],
                speakers=[],
                duration=duration,
                raw_response=data,
                params=params,
            )

        # Get first channel, first alternative
        alt = alternatives[0]
        transcript = alt.get("transcript", "")

        # Parse words with speaker info
        words = []
        speakers_set: set[str] = set()

        for w in alt.get("words", []):
            # Convert the label with str() so that a value of 0 becomes the
            # truthy string "0" and is not mistaken for "no speaker".
            speaker = str(w["speaker"]) if "speaker" in w else None
            if speaker:
                speakers_set.add(speaker)

            words.append(
                Word(
                    word=w.get("word", ""),
                    start=w.get("start", 0.0),
                    end=w.get("end", 0.0),
                    confidence=w.get("confidence", 0.0),
                    speaker=speaker,
                    speaker_confidence=w.get("speaker_confidence"),
                )
            )

        return TranscriptionResult(
            provider=self.name,
            transcript=transcript,
            words=words,
            speakers=sorted(speakers_set),
            duration=duration,
            raw_response=data,
            params=params,
        )
|
||||
210
evals/stt/providers/speechmatics_provider.py
Normal file
210
evals/stt/providers/speechmatics_provider.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Speechmatics STT provider."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import STTProvider, TranscriptionResult, Word
|
||||
|
||||
|
||||
class SpeechmaticsProvider(STTProvider):
    """Speechmatics Speech-to-Text provider.

    API Docs: https://docs.speechmatics.com/

    Supports:
    - Speaker diarization via `diarization: "speaker"` config
    - Speaker sensitivity tuning
    """

    # EU and US endpoints available
    API_URL = "https://asr.api.speechmatics.com/v2/jobs"

    def __init__(self, api_key: str | None = None, region: str = "eu1"):
        """Create the provider.

        Args:
            api_key: Speechmatics API key; when omitted, the
                ``SPEECHMATICS_API_KEY`` environment variable is used.
            region: ``"eu1"`` selects the eu1 endpoint; any other value
                selects the default endpoint in ``API_URL``.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key."
            )
        # Set region-specific endpoint
        if region == "eu1":
            self.api_url = "https://eu1.asr.api.speechmatics.com/v2/jobs"
        else:
            self.api_url = self.API_URL

    @property
    def name(self) -> str:
        """Provider name."""
        return "speechmatics"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        language: str = "en",
        operating_point: str = "enhanced",
        speaker_sensitivity: float | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Speechmatics API.

        Submits a batch job with the audio file, polls until the job is done
        and parses the json-v2 transcript.

        Args:
            audio_path: Path to audio file
            diarize: Enable speaker diarization
            keyterms: Not directly supported by Speechmatics (ignored)
            language: Language code
            operating_point: "standard" or "enhanced"
            speaker_sensitivity: 0.0-1.0, higher = more speakers detected;
                only sent when ``diarize`` is true
            **kwargs: Additional config parameters, merged into
                ``transcription_config``

        Returns:
            TranscriptionResult with transcript and speaker info

        Raises:
            httpx.HTTPStatusError: If any API call returns an error status.
            ValueError: If no job id is returned, or the job is rejected or deleted.
            TimeoutError: If the job does not finish within the polling window.
        """
        # Local import, matching the function-scope import style already used
        # in _wait_for_job.
        import json

        # Build transcription config
        transcription_config: dict[str, Any] = {
            "language": language,
            "operating_point": operating_point,
        }

        if diarize:
            transcription_config["diarization"] = "speaker"
            if speaker_sensitivity is not None:
                transcription_config["speaker_diarization_config"] = {
                    "speaker_sensitivity": speaker_sensitivity
                }

        # Add any extra config
        transcription_config.update(kwargs)

        config = {
            "type": "transcription",
            "transcription_config": transcription_config,
        }

        # Store params for result
        params = {
            "diarize": diarize,
            "language": language,
            "operating_point": operating_point,
            "speaker_sensitivity": speaker_sensitivity,
        }

        headers = {
            "Authorization": f"Bearer {self.api_key}",
        }

        # Create job with multipart form
        async with httpx.AsyncClient(timeout=300.0) as client:
            # Submit job; the file must stay open until the upload finishes.
            with open(audio_path, "rb") as f:
                files = {
                    "data_file": (audio_path.name, f, "audio/mpeg"),
                    # json.dumps yields valid JSON for booleans, None and
                    # strings containing quotes, which swapping quote
                    # characters in repr(config) did not.
                    "config": (None, json.dumps(config), "application/json"),
                }
                response = await client.post(
                    self.api_url,
                    headers=headers,
                    files=files,
                )
            response.raise_for_status()
            job_data = response.json()

            job_id = job_data.get("id")
            if not job_id:
                raise ValueError(f"No job ID in response: {job_data}")

            # Poll for completion
            result_data = await self._wait_for_job(client, job_id, headers)

        return self._parse_response(result_data, params)

    async def _wait_for_job(
        self, client: httpx.AsyncClient, job_id: str, headers: dict[str, str]
    ) -> dict[str, Any]:
        """Poll job status until complete.

        Args:
            client: Open HTTP client to reuse for the polling requests.
            job_id: Identifier returned when the job was submitted.
            headers: Authorization headers sent with every request.

        Returns:
            The decoded json-v2 transcript once the job status is ``done``.

        Raises:
            ValueError: If the job status becomes ``rejected`` or ``deleted``.
            TimeoutError: If the job is not done after all polling attempts.
        """
        import asyncio

        job_url = f"{self.api_url}/{job_id}"
        transcript_url = f"{job_url}/transcript?format=json-v2"

        max_attempts = 120  # 10 minutes with 5s intervals
        for _ in range(max_attempts):
            # Check job status
            status_response = await client.get(job_url, headers=headers)
            status_response.raise_for_status()
            status_data = status_response.json()

            job_status = status_data.get("job", {}).get("status")

            if job_status == "done":
                # Get transcript
                transcript_response = await client.get(transcript_url, headers=headers)
                transcript_response.raise_for_status()
                return transcript_response.json()
            elif job_status == "rejected":
                raise ValueError(f"Job rejected: {status_data}")
            elif job_status == "deleted":
                raise ValueError(f"Job deleted: {status_data}")

            await asyncio.sleep(5)

        raise TimeoutError(f"Job {job_id} did not complete in time")

    def _parse_response(
        self, data: dict[str, Any], params: dict[str, Any]
    ) -> TranscriptionResult:
        """Parse Speechmatics API response.

        Args:
            data: Decoded json-v2 transcript.
            params: Request parameters, stored on the result.

        Returns:
            The parsed result. Only ``word`` items become ``Word`` entries;
            ``punctuation`` items are glued onto the preceding word in the
            transcript text.
        """
        results = data.get("results", [])

        words = []
        speakers_set: set[str] = set()
        transcript_parts = []

        for item in results:
            item_type = item.get("type")
            alternatives = item.get("alternatives", [])

            if not alternatives:
                continue

            alt = alternatives[0]
            content = alt.get("content", "")
            speaker = alt.get("speaker")

            if speaker:
                speakers_set.add(speaker)

            if item_type == "word":
                words.append(
                    Word(
                        word=content,
                        start=item.get("start_time", 0.0),
                        end=item.get("end_time", 0.0),
                        confidence=alt.get("confidence", 0.0),
                        speaker=speaker,
                        speaker_confidence=None,  # Not provided by Speechmatics
                    )
                )
                transcript_parts.append(content)
            elif item_type == "punctuation":
                # Append punctuation to last word in transcript
                if transcript_parts:
                    transcript_parts[-1] += content

        # The json-v2 transcript carries the audio duration in its "job"
        # section. "metadata" is still checked first so a payload that does
        # provide a duration there keeps its previous result.
        duration = data.get("metadata", {}).get("duration")
        if duration is None:
            duration = data.get("job", {}).get("duration", 0.0)

        transcript = " ".join(transcript_parts)

        return TranscriptionResult(
            provider=self.name,
            transcript=transcript,
            words=words,
            speakers=sorted(speakers_set),
            duration=duration,
            raw_response=data,
            params=params,
        )
|
||||
Loading…
Add table
Add a link
Reference in a new issue