mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
Merge branch 'feat/add-stt-evals' into update-pipecat-0.99
This commit is contained in:
commit
fcf718d4e9
12 changed files with 1631 additions and 0 deletions
135
evals/stt/README.md
Normal file
@@ -0,0 +1,135 @@
# STT Evaluation Benchmark

Benchmark for comparing Speech-to-Text providers using **WebSocket streaming**, with a focus on:
- **Speaker diarization** - identifying who said what
- **Keyterm boosting** - improving recognition of specific terms (Deepgram)

## Providers

| Provider | Diarization | Keyterm Boost | Streaming |
|----------|-------------|---------------|-----------|
| Deepgram | Yes | Yes | WebSocket (v1/v2) |
| Speechmatics | Yes | Additional vocab | WebSocket RT |
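
As a rough illustration of how keyterm boosting reaches Deepgram: the Deepgram provider in this benchmark appends one `keyterm` query parameter per term to the streaming URL. The sketch below mirrors that idea under a few assumptions: `build_listen_url` is a hypothetical helper name that does not exist in the repository, and the parameter values are examples only.

```python
from urllib.parse import urlencode

# Endpoint used by the Nova provider in this benchmark.
DEEPGRAM_WS_URL = "wss://api.deepgram.com/v1/listen"


def build_listen_url(params: dict[str, object], keyterms: list[str]) -> str:
    """Build a streaming URL with one `keyterm` parameter per boosted term.

    Hypothetical helper for illustration; not part of the repository.
    """
    pairs = [(key, str(value)) for key, value in params.items()]
    # Repeating the same key is how several keyterms are passed at once.
    pairs += [("keyterm", term) for term in keyterms]
    return f"{DEEPGRAM_WS_URL}?{urlencode(pairs)}"


url = build_listen_url(
    {"model": "nova-3-general", "diarize": "true", "sample_rate": 8000},
    ["Dograh", "Pipecat"],
)
print(url)
```

Passing a list of pairs to `urlencode` keeps the repeated key and also escapes terms that contain spaces or punctuation.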

## Setup

```bash
# Install dependencies
pip install websockets loguru

# Set API keys
export DEEPGRAM_API_KEY="your-key"
export SPEECHMATICS_API_KEY="your-key"
```

**Note:** Requires `ffmpeg` (and `ffprobe`, which ships with it) on your `PATH` for audio conversion to PCM16 and duration lookup.
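
A quick way to confirm the external tools are available before running the benchmark is to look them up on `PATH`. This is a minimal sketch using only the standard library; `missing_tools` is a hypothetical helper name and is not part of the benchmark.

```python
import shutil


def missing_tools(tools: tuple[str, ...] = ("ffmpeg", "ffprobe")) -> list[str]:
    """Return the names of required executables that are not on PATH."""
    return [tool for tool in tools if shutil.which(tool) is None]


missing = missing_tools()
if missing:
    print(f"Missing required tools: {', '.join(missing)}")
else:
    print("ffmpeg and ffprobe found")
```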

## Usage

Run from the project root directory:

```bash
# Test both providers with diarization
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize

# Test only Deepgram
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram

# Test with keyterm boosting (Deepgram)
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"

# Use different sample rate (default: 8000 Hz)
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --sample-rate 16000

# Show word-level timings
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words

# Save results to JSON
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save
```

## CLI Options

| Option | Description |
|--------|-------------|
| `audio_file` | Path to audio file (relative to `evals/stt/` or absolute) |
| `--providers` | Providers to test: `deepgram`, `deepgram-flux`, `speechmatics`, `local-smart-turn` (default: `deepgram` and `speechmatics`) |
| `--diarize` | Enable speaker diarization |
| `--keyterms` | Keywords to boost (Deepgram) / additional vocab (Speechmatics) |
| `--language` | Language code (default: `en`) |
| `--sample-rate` | Audio sample rate in Hz for streaming (default: 8000) |
| `--show-words` | Show individual word timings (first 50 words) |
| `--save` | Save results to a timestamped JSON file in the output directory |
| `--output-dir` | Output directory for saved results, relative to `evals/stt/` (default: `results`) |
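
Files written with `--save` contain a timestamp, the audio file stem, and one entry per provider. The sketch below shows one way to summarize such a file. The payload is made-up sample data in the shape produced by the benchmark's `save_results` and `to_dict` methods, and `summarize` is a hypothetical helper, not part of the repository.

```python
import json

# Made-up payload in the shape written by save_results(): a timestamp, the
# audio file stem, and one to_dict() entry per provider.
SAMPLE = """
{
  "timestamp": "20260101_120000",
  "audio_file": "multi_speaker",
  "results": [
    {
      "provider": "deepgram",
      "transcript": "Hello there",
      "words": [
        {"word": "Hello", "start": 0.0, "end": 0.4, "confidence": 0.98,
         "speaker": "0", "speaker_confidence": null},
        {"word": "there", "start": 0.5, "end": 0.9, "confidence": 0.95,
         "speaker": "1", "speaker_confidence": null}
      ],
      "speakers": ["0", "1"],
      "duration": 0.9,
      "params": {"diarize": "true"}
    }
  ]
}
"""


def summarize(payload: dict) -> list[tuple[str, int, int]]:
    """Return (provider, speaker count, word count) for each saved result."""
    return [
        (entry["provider"], len(entry["speakers"]), len(entry["words"]))
        for entry in payload["results"]
    ]


for provider, speakers, words in summarize(json.loads(SAMPLE)):
    print(f"{provider}: {speakers} speakers, {words} words")
```

With a real results file, replace `json.loads(SAMPLE)` with `json.load` on the opened file.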

## Directory Structure

```
evals/stt/
├── audio/                          # Audio test files
│   ├── multi_speaker.m4a
│   └── vad.m4a
├── results/                        # Saved benchmark results (JSON)
├── providers/                      # STT provider implementations
│   ├── __init__.py
│   ├── base.py                     # Base classes
│   ├── deepgram_provider.py        # WebSocket streaming (Nova, v1 API)
│   ├── deepgram_flux_provider.py   # WebSocket streaming (Flux, v2 API)
│   ├── speechmatics_provider.py    # WebSocket streaming
│   └── local_smart_turn_provider.py
├── __init__.py
├── audio_streamer.py               # PCM16 audio file streamer
├── benchmark.py                    # Main runner script
└── README.md
```

## How It Works

1. **Audio Conversion**: The `AudioStreamer` converts the input audio file to raw PCM16 using ffmpeg
2. **WebSocket Connection**: Each provider connects to its own WebSocket API
3. **Streaming**: Audio is sent in 80 ms chunks, paced in real time (configurable sample rate, default 8 kHz)
4. **Result Collection**: Transcripts and speaker info are collected from WebSocket responses
5. **Comparison**: Results are parsed into a common format for comparison
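
The chunking in step 3 comes down to simple arithmetic: an 80 ms chunk of 16-bit mono audio at 8 kHz is 640 samples, or 1280 bytes. The sketch below reproduces that calculation and the slicing loop. The function names are illustrative and are not the ones used in `audio_streamer.py`.

```python
from collections.abc import Iterator


def chunk_size_bytes(
    sample_rate: int,
    chunk_ms: int = 80,
    channels: int = 1,
    sample_width: int = 2,  # 16-bit PCM = 2 bytes per sample
) -> int:
    """Bytes needed to hold `chunk_ms` milliseconds of PCM audio."""
    samples_per_chunk = int(sample_rate * chunk_ms / 1000)
    return samples_per_chunk * channels * sample_width


def iter_chunks(pcm: bytes, size: int) -> Iterator[bytes]:
    """Yield consecutive slices of `pcm`; the last one may be shorter."""
    for offset in range(0, len(pcm), size):
        yield pcm[offset : offset + size]


size = chunk_size_bytes(8000)
chunks = list(iter_chunks(bytes(3000), size))
print(size, [len(chunk) for chunk in chunks])
```

In the benchmark, the streamer also sleeps for the chunk duration between sends so that the provider receives audio at roughly real-time speed.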

## Output Example

```
Audio file: /path/to/audio/multi_speaker.m4a
Providers: ['deepgram', 'speechmatics']
Diarization: True
Sample rate: 8000 Hz

============================================================
Provider: DEEPGRAM
============================================================

Duration: 45.32s
Speakers detected: 2 - ['0', '1']

Transcript:
Hello, welcome to the demo...

--- Speaker Segments ---
[0.0s] Speaker 0: Hello, welcome to the demo.
[2.5s] Speaker 1: Thanks for having me.
...

============================================================
COMPARISON SUMMARY
============================================================

Provider        Duration   Speakers   Words
---------------------------------------------
deepgram        45.32      2          312
speechmatics    45.32      2          308
```
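
The speaker segments in this output are runs of consecutive words attributed to the same speaker. A minimal sketch of that grouping follows, assuming word dictionaries with `word`, `start`, `end`, and `speaker` keys. The sample words are made up, and the sketch uses `itertools.groupby` rather than the explicit loop in `providers/base.py`.

```python
from itertools import groupby

# Made-up word-level results for illustration.
words = [
    {"word": "Hello,", "start": 0.0, "end": 0.4, "speaker": "0"},
    {"word": "welcome.", "start": 0.5, "end": 1.0, "speaker": "0"},
    {"word": "Thanks", "start": 2.5, "end": 2.9, "speaker": "1"},
    {"word": "everyone.", "start": 3.0, "end": 3.6, "speaker": "1"},
    {"word": "Great.", "start": 4.0, "end": 4.4, "speaker": "0"},
]


def speaker_segments(items: list[dict]) -> list[dict]:
    """Merge consecutive words from the same speaker into one segment."""
    segments = []
    for speaker, group in groupby(items, key=lambda item: item["speaker"]):
        run = list(group)
        segments.append(
            {
                "speaker": speaker,
                "text": " ".join(item["word"] for item in run),
                "start": run[0]["start"],
                "end": run[-1]["end"],
            }
        )
    return segments


for segment in speaker_segments(words):
    print(f"[{segment['start']:.1f}s] Speaker {segment['speaker']}: {segment['text']}")
```

Because `groupby` only merges adjacent items, a speaker who talks again later starts a new segment, which matches the turn-by-turn listing shown above.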

## Adding New Providers

1. Create a new file in `providers/` (e.g., `whisper_provider.py`)
2. Implement the `STTProvider` abstract class with WebSocket streaming
3. Use `AudioStreamer` for PCM16 conversion
4. Export the new class from `providers/__init__.py`
5. Register it in `benchmark.py`: add it to the `get_provider` mapping and to the `--providers` choices

## API Documentation

- Deepgram Streaming: https://developers.deepgram.com/docs/live-streaming-audio
- Deepgram Diarization: https://developers.deepgram.com/docs/diarization
- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm
- Speechmatics RT API: https://docs.speechmatics.com/rt-api-ref
- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization
1
evals/stt/__init__.py
Normal file
@@ -0,0 +1 @@
# STT Evaluation Benchmark
BIN
evals/stt/audio/multi_speaker.m4a
Normal file
Binary file not shown.
BIN
evals/stt/audio/vad.m4a
Normal file
Binary file not shown.
127
evals/stt/audio_streamer.py
Normal file
@@ -0,0 +1,127 @@
"""Audio file streamer - converts audio files to PCM16 streams."""

import asyncio
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncIterator


@dataclass
class AudioConfig:
    """Audio streaming configuration."""

    sample_rate: int = 8000
    channels: int = 1
    sample_width: int = 2  # 16-bit = 2 bytes
    chunk_duration_ms: int = 80  # Send chunks every 80ms

    @property
    def chunk_size(self) -> int:
        """Bytes per chunk based on duration."""
        samples_per_chunk = int(self.sample_rate * self.chunk_duration_ms / 1000)
        return samples_per_chunk * self.channels * self.sample_width


class AudioStreamer:
    """Streams audio files as PCM16 chunks.

    Converts any audio format to raw PCM16 using ffmpeg and streams
    in real-time chunks to simulate live audio.
    """

    def __init__(self, config: AudioConfig | None = None):
        self.config = config or AudioConfig()

    def convert_to_pcm16(self, audio_path: Path) -> bytes:
        """Convert audio file to raw PCM16 bytes using ffmpeg.

        Args:
            audio_path: Path to input audio file

        Returns:
            Raw PCM16 audio bytes
        """
        cmd = [
            "ffmpeg",
            "-i",
            str(audio_path),
            "-f",
            "s16le",  # signed 16-bit little-endian
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(self.config.sample_rate),
            "-ac",
            str(self.config.channels),
            "-",  # output to stdout
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            check=True,
        )
        return result.stdout

    async def stream_file(
        self,
        audio_path: Path,
        realtime: bool = True,
    ) -> AsyncIterator[bytes]:
        """Stream audio file as PCM16 chunks.

        Args:
            audio_path: Path to audio file
            realtime: If True, add delays to simulate real-time streaming

        Yields:
            PCM16 audio chunks
        """
        # Convert entire file to PCM16
        pcm_data = self.convert_to_pcm16(audio_path)

        chunk_size = self.config.chunk_size
        delay = self.config.chunk_duration_ms / 1000.0 if realtime else 0

        # Stream in chunks
        for i in range(0, len(pcm_data), chunk_size):
            chunk = pcm_data[i : i + chunk_size]
            if chunk:
                yield chunk
                if realtime and delay > 0:
                    await asyncio.sleep(delay)

    async def stream_file_fast(self, audio_path: Path) -> AsyncIterator[bytes]:
        """Stream audio file as fast as possible (no real-time delay).

        Args:
            audio_path: Path to audio file

        Yields:
            PCM16 audio chunks
        """
        async for chunk in self.stream_file(audio_path, realtime=False):
            yield chunk

    def get_duration(self, audio_path: Path) -> float:
        """Get audio file duration in seconds.

        Args:
            audio_path: Path to audio file

        Returns:
            Duration in seconds
        """
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            str(audio_path),
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        return float(result.stdout.strip())
247
evals/stt/benchmark.py
Normal file
@@ -0,0 +1,247 @@
#!/usr/bin/env python3
"""STT Benchmark Runner.

Compare speech-to-text transcription across providers with focus on:
- Speaker diarization accuracy
- Keyword/keyterm recognition
- Transcription quality

Usage:
    python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
    python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
    python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
"""

import argparse
import asyncio
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

from evals.stt.providers import (
    DeepgramProvider,
    DeepgramFluxProvider,
    SpeechmaticsProvider,
    LocalSmartTurnProvider,
    STTProvider,
    TranscriptionResult,
)


def get_provider(name: str) -> STTProvider:
    """Get provider instance by name."""
    providers = {
        "deepgram": DeepgramProvider,
        "deepgram-flux": DeepgramFluxProvider,
        "speechmatics": SpeechmaticsProvider,
        "local-smart-turn": LocalSmartTurnProvider,
    }
    if name not in providers:
        raise ValueError(f"Unknown provider: {name}. Available: {list(providers.keys())}")
    return providers[name]()


async def run_transcription(
    provider: STTProvider,
    audio_path: Path,
    diarize: bool = False,
    keyterms: list[str] | None = None,
    **kwargs: Any,
) -> TranscriptionResult:
    """Run transcription with a provider."""
    print(f"\n{'='*60}")
    print(f"Provider: {provider.name.upper()}")
    print(f"{'='*60}")

    try:
        result = await provider.transcribe(
            audio_path,
            diarize=diarize,
            keyterms=keyterms,
            **kwargs,
        )
        return result
    except Exception as e:
        print(f"Error with {provider.name}: {e}")
        raise


def print_result(result: TranscriptionResult, show_words: bool = False) -> None:
    """Print transcription result."""
    print(f"\nDuration: {result.duration:.2f}s")
    print(f"Speakers detected: {len(result.speakers)} - {result.speakers}")
    print(f"\nTranscript:\n{result.transcript}")

    if result.speakers:
        print("\n--- Speaker Segments ---")
        for segment in result.get_speaker_segments():
            speaker = segment["speaker"] or "?"
            text = segment["text"]
            start = segment["start"]
            print(f"[{start:.1f}s] Speaker {speaker}: {text}")

    if show_words:
        print("\n--- Words ---")
        for word in result.words[:50]:  # First 50 words
            speaker_info = f" (S{word.speaker})" if word.speaker else ""
            print(f"  {word.start:.2f}s: {word.word}{speaker_info} [{word.confidence:.2f}]")
        if len(result.words) > 50:
            print(f"  ... and {len(result.words) - 50} more words")


def save_results(
    results: list[TranscriptionResult],
    output_dir: Path,
    audio_name: str,
) -> Path:
    """Save results to JSON file."""
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = output_dir / f"{audio_name}_{timestamp}.json"

    output_data = {
        "timestamp": timestamp,
        "audio_file": audio_name,
        "results": [r.to_dict() for r in results],
    }

    with open(output_file, "w") as f:
        json.dump(output_data, f, indent=2)

    print(f"\nResults saved to: {output_file}")
    return output_file


def compare_results(results: list[TranscriptionResult]) -> None:
    """Compare results across providers."""
    if len(results) < 2:
        return

    print(f"\n{'='*60}")
    print("COMPARISON SUMMARY")
    print(f"{'='*60}")

    print(f"\n{'Provider':<15} {'Duration':<10} {'Speakers':<10} {'Words':<10}")
    print("-" * 45)
    for r in results:
        print(f"{r.provider:<15} {r.duration:<10.2f} {len(r.speakers):<10} {len(r.words):<10}")

    # Compare speaker counts
    speaker_counts = {r.provider: len(r.speakers) for r in results}
    if len(set(speaker_counts.values())) > 1:
        print(f"\nNote: Providers detected different speaker counts: {speaker_counts}")


async def main() -> int:
    parser = argparse.ArgumentParser(
        description="STT Benchmark - Compare transcription providers",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
    python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
    python -m evals.stt.benchmark audio/multi_speaker.m4a --keyterms "Dograh" "API"
""",
    )
    parser.add_argument(
        "audio_file",
        type=str,
        help="Path to audio file (relative to evals/stt/ or absolute)",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=["deepgram", "speechmatics"],
        choices=["deepgram", "deepgram-flux", "speechmatics", "local-smart-turn"],
        help="Providers to test (default: deepgram speechmatics)",
    )
    parser.add_argument(
        "--diarize",
        action="store_true",
        help="Enable speaker diarization",
    )
    parser.add_argument(
        "--keyterms",
        nargs="+",
        help="Keywords to boost (Deepgram only)",
    )
    parser.add_argument(
        "--language",
        default="en",
        help="Language code (default: en)",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=8000,
        help="Audio sample rate for streaming (default: 8000)",
    )
    parser.add_argument(
        "--show-words",
        action="store_true",
        help="Show individual word timings",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save results to JSON file",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="results",
        help="Output directory for results (default: results)",
    )

    args = parser.parse_args()

    # Resolve audio path relative to this script's directory (evals/stt/)
    script_dir = Path(__file__).parent
    audio_path = Path(args.audio_file)
    if not audio_path.is_absolute():
        audio_path = script_dir / audio_path

    if not audio_path.exists():
        print(f"Error: Audio file not found: {audio_path}")
        return 1

    print(f"Audio file: {audio_path}")
    print(f"Providers: {args.providers}")
    print(f"Diarization: {args.diarize}")
    print(f"Sample rate: {args.sample_rate} Hz")
    if args.keyterms:
        print(f"Keyterms: {args.keyterms}")

    results: list[TranscriptionResult] = []

    for provider_name in args.providers:
        try:
            provider = get_provider(provider_name)
            result = await run_transcription(
                provider,
                audio_path,
                diarize=args.diarize,
                keyterms=args.keyterms,
                language=args.language,
                sample_rate=args.sample_rate,
            )
            print_result(result, show_words=args.show_words)
            results.append(result)
        except Exception as e:
            # Keep going so one failing provider does not abort the whole run
            print(f"\nFailed to run {provider_name}: {e}")
            continue

    if len(results) > 1:
        compare_results(results)

    if args.save and results:
        output_dir = script_dir / args.output_dir
        save_results(results, output_dir, audio_path.stem)

    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
15
evals/stt/providers/__init__.py
Normal file
@@ -0,0 +1,15 @@
from .base import STTProvider, TranscriptionResult, Word
from .deepgram_provider import DeepgramProvider
from .deepgram_flux_provider import DeepgramFluxProvider
from .speechmatics_provider import SpeechmaticsProvider
from .local_smart_turn_provider import LocalSmartTurnProvider

__all__ = [
    "STTProvider",
    "TranscriptionResult",
    "Word",
    "DeepgramProvider",
    "DeepgramFluxProvider",
    "SpeechmaticsProvider",
    "LocalSmartTurnProvider",
]
123
evals/stt/providers/base.py
Normal file
@@ -0,0 +1,123 @@
"""Base classes for STT providers."""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass
class Word:
    """Represents a transcribed word with metadata."""

    word: str
    start: float
    end: float
    confidence: float
    speaker: str | None = None
    speaker_confidence: float | None = None

    def to_dict(self) -> dict[str, Any]:
        return {
            "word": self.word,
            "start": self.start,
            "end": self.end,
            "confidence": self.confidence,
            "speaker": self.speaker,
            "speaker_confidence": self.speaker_confidence,
        }


@dataclass
class TranscriptionResult:
    """Result from STT transcription."""

    provider: str
    transcript: str
    words: list[Word]
    speakers: list[str]
    duration: float
    raw_response: dict[str, Any] = field(default_factory=dict)
    params: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            "provider": self.provider,
            "transcript": self.transcript,
            "words": [w.to_dict() for w in self.words],
            "speakers": self.speakers,
            "duration": self.duration,
            "params": self.params,
        }

    def get_speaker_segments(self) -> list[dict[str, Any]]:
        """Get transcript segmented by speaker."""
        if not self.words:
            return []

        segments: list[dict[str, Any]] = []
        current_speaker = None
        current_text: list[str] = []
        segment_start = self.words[0].start
        # End time of the last word added to the running segment, so each
        # segment closes at its own final word.
        segment_end = self.words[0].start

        for word in self.words:
            if word.speaker != current_speaker:
                if current_text:
                    segments.append(
                        {
                            "speaker": current_speaker,
                            "text": " ".join(current_text),
                            "start": segment_start,
                            "end": segment_end,
                        }
                    )
                current_speaker = word.speaker
                current_text = [word.word]
                segment_start = word.start
            else:
                current_text.append(word.word)
            segment_end = word.end

        if current_text:
            segments.append(
                {
                    "speaker": current_speaker,
                    "text": " ".join(current_text),
                    "start": segment_start,
                    "end": segment_end,
                }
            )

        return segments


class STTProvider(ABC):
    """Abstract base class for STT providers."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name."""
        pass

    @abstractmethod
    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio file.

        Args:
            audio_path: Path to the audio file
            diarize: Enable speaker diarization
            keyterms: List of keywords to boost (if supported)
            **kwargs: Provider-specific parameters

        Returns:
            TranscriptionResult with transcript and metadata
        """
        pass
225
evals/stt/providers/deepgram_flux_provider.py
Normal file
@@ -0,0 +1,225 @@
"""Deepgram Flux STT provider with WebSocket streaming.

Flux is Deepgram's conversational AI model with built-in turn detection.
It has a different API than Nova models - no language/punctuate/diarize params.
"""

import asyncio
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlencode

from loguru import logger

from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word

try:
    from websockets.asyncio.client import connect as websocket_connect
except ImportError:
    raise ImportError("websockets required: pip install websockets")


class DeepgramFluxProvider(STTProvider):
    """Deepgram Flux Speech-to-Text provider with WebSocket streaming.

    Flux is optimized for conversational AI with built-in turn detection.

    Key differences from Nova:
    - Uses v2 API endpoint
    - Only supports English (flux-general-en)
    - No punctuate, diarize, or language params
    - Has turn detection events (StartOfTurn, EndOfTurn, EagerEndOfTurn)
    - Supports keyterm boosting

    API Docs: https://developers.deepgram.com/docs/
    """

    WS_URL = "wss://api.deepgram.com/v2/listen"

    def __init__(self, api_key: str | None = None):
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        return "deepgram-flux"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - Flux doesn't support diarization
        keyterms: list[str] | None = None,
        model: str = "flux-general-en",
        sample_rate: int = 16000,
        eot_threshold: float | None = None,
        eot_timeout_ms: int | None = None,
        eager_eot_threshold: float | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram Flux WebSocket streaming.

        Args:
            audio_path: Path to audio file
            diarize: IGNORED - Flux does not support diarization
            keyterms: List of keywords to boost recognition
            model: Flux model (default: flux-general-en)
            sample_rate: Audio sample rate (default: 16000 for Flux)
            eot_threshold: End-of-turn confidence threshold (0-1, default 0.7)
            eot_timeout_ms: Timeout in ms to force end of turn (default 5000)
            eager_eot_threshold: Threshold for eager end-of-turn events
            **kwargs: Additional Flux parameters

        Returns:
            TranscriptionResult with transcript (no speaker info - Flux doesn't support diarization)
        """
        if diarize:
            logger.warning("Flux does not support diarization - ignoring diarize=True")

        # Build query params - Flux only supports specific params
        params: dict[str, Any] = {
            "model": model,
            "encoding": "linear16",
            "sample_rate": sample_rate,
        }

        # Flux-specific turn detection params
        if eot_threshold is not None:
            params["eot_threshold"] = eot_threshold
        if eot_timeout_ms is not None:
            params["eot_timeout_ms"] = eot_timeout_ms
        if eager_eot_threshold is not None:
            params["eager_eot_threshold"] = eager_eot_threshold

        # Build URL with params
        url_parts = [f"{k}={v}" for k, v in params.items()]

        # Add keyterms (repeated params)
        if keyterms:
            for term in keyterms:
                url_parts.append(urlencode({"keyterm": term}))

        ws_url = f"{self.WS_URL}?{'&'.join(url_parts)}"
        logger.debug(f"Flux WebSocket URL: {ws_url}")

        # Setup audio streamer
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Collect results
        all_transcripts: list[dict[str, Any]] = []
        final_transcript = ""
        duration = 0.0
        connected = asyncio.Event()

        async with websocket_connect(
            ws_url,
            additional_headers={"Authorization": f"Token {self.api_key}"},
        ) as ws:

            async def send_audio():
                """Send audio chunks to Deepgram Flux."""
                await connected.wait()

                chunk_no = 0
                async for chunk in streamer.stream_file(audio_path):
                    logger.debug(f"[deepgram-flux] Sent audio chunk {chunk_no}")
                    await ws.send(chunk)
                    chunk_no += 1

            async def receive_messages():
                """Receive and collect Flux messages."""
                nonlocal all_transcripts, final_transcript, duration

                async for message in ws:
                    if isinstance(message, str):
                        data = json.loads(message)
                        msg_type = data.get("type")
                        logger.debug(f"[deepgram-flux] Received {msg_type}: {data}")

                        if msg_type == "Connected":
                            logger.info("[deepgram-flux] Connected")
                            connected.set()

                        elif msg_type == "TurnInfo":
                            event = data.get("event")
                            transcript = data.get("transcript", "")
                            words = data.get("words", [])

                            if event == "EndOfTurn":
                                if transcript:
                                    final_transcript += transcript + " "
                                if words:
                                    all_transcripts.append({
                                        "transcript": transcript,
                                        "words": words,
                                    })
                                # Get duration from last word
                                if words:
                                    last_word = words[-1]
                                    duration = max(duration, last_word.get("end", 0))

                            elif event == "TurnResumed":
                                logger.debug("TurnResumed")

                        elif msg_type == "Error":
                            raise RuntimeError(f"Deepgram Flux error: {data}")

            # Run send and receive concurrently
            send_task = asyncio.create_task(send_audio())
            receive_task = asyncio.create_task(receive_messages())

            # Wait for whichever task finishes first. If the receiver ends or
            # fails before "Connected" arrives, the sender would otherwise
            # wait on the event forever.
            done, _ = await asyncio.wait(
                {send_task, receive_task}, return_when=asyncio.FIRST_COMPLETED
            )
            if send_task in done:
                send_task.result()  # Propagate any sender error
            else:
                send_task.cancel()
                receive_task.result()  # Propagate any receiver error

            logger.debug("[deepgram-flux] Send task done")
            try:
                await asyncio.wait_for(receive_task, timeout=10.0)
            except asyncio.TimeoutError:
                pass

        return self._parse_results(
            all_transcripts, final_transcript.strip(), duration, params, keyterms
        )

    def _parse_results(
        self,
        transcripts: list[dict[str, Any]],
        final_transcript: str,
        duration: float,
        params: dict[str, Any],
        keyterms: list[str] | None,
    ) -> TranscriptionResult:
        """Parse collected Flux results into TranscriptionResult."""
        words = []

        for turn in transcripts:
            for w in turn.get("words", []):
                words.append(
                    Word(
                        word=w.get("word", ""),
                        start=w.get("start", 0.0),
                        end=w.get("end", 0.0),
                        confidence=w.get("confidence", 0.0),
                        speaker=None,  # Flux doesn't support diarization
                        speaker_confidence=None,
                    )
                )

        stored_params = dict(params)
        if keyterms:
            stored_params["keyterms"] = keyterms

        return TranscriptionResult(
            provider=self.name,
            transcript=final_transcript,
            words=words,
            speakers=[],  # Flux doesn't support diarization
            duration=duration,
            raw_response={"transcripts": transcripts},
            params=stored_params,
        )
225
evals/stt/providers/deepgram_provider.py
Normal file
@@ -0,0 +1,225 @@
"""Deepgram STT provider with WebSocket streaming."""

import asyncio
import json
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlencode

from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word
from loguru import logger

try:
    from websockets.asyncio.client import connect as websocket_connect
except ImportError:
    raise ImportError("websockets required: pip install websockets")


class DeepgramProvider(STTProvider):
    """Deepgram Nova Speech-to-Text provider with WebSocket streaming.

    API Docs: https://developers.deepgram.com/docs/

    Supports:
    - Speaker diarization via `diarize=true`
    - Keyterm boosting via `keyterm` parameter
    - Real-time streaming via WebSocket
    - Multiple languages
    - Punctuation

    For Flux models, use DeepgramFluxProvider instead.
    """

    WS_URL = "wss://api.deepgram.com/v1/listen"

    def __init__(self, api_key: str | None = None):
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        return "deepgram"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        model: str = "nova-3-general",
        language: str = "en",
        sample_rate: int = 8000,
        punctuate: bool = True,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram Nova WebSocket streaming.

        Args:
            audio_path: Path to audio file
            diarize: Enable speaker diarization
            keyterms: List of keywords to boost recognition
            model: Deepgram Nova model (nova-3, nova-2, etc.)
            language: Language code
            sample_rate: Audio sample rate for streaming
            punctuate: Add punctuation
            **kwargs: Additional Deepgram parameters

        Returns:
            TranscriptionResult with transcript and speaker info
        """
        # Build query params
        params: dict[str, Any] = {
            "model": model,
            "language": language,
            "punctuate": str(punctuate).lower(),
            "encoding": "linear16",
            "sample_rate": sample_rate,
            "channels": 1,
            "interim_results": "true",
            "smart_format": "true",
            "profanity_filter": "true",
            "vad_events": "true",
        }

        if diarize:
            params["diarize"] = "true"

        # Build URL with params
        url_parts = [f"{k}={v}" for k, v in params.items()]

        # Add keyterms (repeated params)
        if keyterms:
            for term in keyterms:
                url_parts.append(urlencode({"keyterm": term}))

        # Add extra kwargs
        for k, v in kwargs.items():
            url_parts.append(f"{k}={v}")

        ws_url = f"{self.WS_URL}?{'&'.join(url_parts)}"
        logger.debug(f"Deepgram WebSocket URL: {ws_url}")

        # Setup audio streamer
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Collect results
        all_words: list[dict[str, Any]] = []
        final_transcript = ""
        duration = 0.0

        try:
            async with websocket_connect(
                ws_url,
                additional_headers={"Authorization": f"Token {self.api_key}"},
            ) as ws:
                # Create tasks for sending and receiving
                send_complete = asyncio.Event()

                async def send_audio():
                    """Send audio chunks to Deepgram."""
                    chunk_no = 0
                    async for chunk in streamer.stream_file(audio_path):
                        logger.debug(f"[deepgram] Sent audio chunk {chunk_no}")
                        await ws.send(chunk)
                        chunk_no += 1
                    # Send close message
                    logger.debug(f"[deepgram] Sending CloseStream after {chunk_no} chunks")
                    await ws.send(json.dumps({"type": "CloseStream"}))
                    send_complete.set()

                async def receive_transcripts():
                    """Receive and collect transcription results."""
                    nonlocal all_words, final_transcript, duration

                    async for message in ws:
                        if isinstance(message, str):
                            data = json.loads(message)
                            msg_type = data.get("type")
                            logger.debug(f"[deepgram] Received {msg_type}: {data}")
|
||||
|
||||
if msg_type == "Results":
|
||||
# Nova-style response
|
||||
channel = data.get("channel", {})
|
||||
alternatives = channel.get("alternatives", [])
|
||||
if alternatives:
|
||||
alt = alternatives[0]
|
||||
words = alt.get("words", [])
|
||||
all_words.extend(words)
|
||||
|
||||
# Check if final
|
||||
if data.get("is_final"):
|
||||
final_transcript += alt.get("transcript", "") + " "
|
||||
duration = max(
|
||||
duration, data.get("duration", 0) + data.get("start", 0)
|
||||
)
|
||||
|
||||
elif msg_type == "Metadata":
|
||||
# Get duration from metadata
|
||||
duration = data.get("duration", duration)
|
||||
|
||||
elif msg_type == "Error":
|
||||
raise Exception(f"Deepgram error: {data}")
|
||||
|
||||
# Run send and receive concurrently
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
receive_task = asyncio.create_task(receive_transcripts())
|
||||
|
||||
# Wait for send to complete, then wait a bit for final results
|
||||
await send_task
|
||||
try:
|
||||
await asyncio.wait_for(receive_task, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal - websocket closes after final results
|
||||
except Exception as e:
|
||||
logger.debug(e)
|
||||
|
||||
return self._parse_results(
|
||||
all_words, final_transcript.strip(), duration, params, keyterms
|
||||
)
|
||||
|
||||
def _parse_results(
|
||||
self,
|
||||
raw_words: list[dict[str, Any]],
|
||||
transcript: str,
|
||||
duration: float,
|
||||
params: dict[str, Any],
|
||||
keyterms: list[str] | None,
|
||||
) -> TranscriptionResult:
|
||||
"""Parse collected results into TranscriptionResult."""
|
||||
words = []
|
||||
speakers_set: set[str] = set()
|
||||
|
||||
for w in raw_words:
|
||||
speaker = str(w.get("speaker", "")) if "speaker" in w else None
|
||||
if speaker:
|
||||
speakers_set.add(speaker)
|
||||
|
||||
words.append(
|
||||
Word(
|
||||
word=w.get("word", ""),
|
||||
start=w.get("start", 0.0),
|
||||
end=w.get("end", 0.0),
|
||||
confidence=w.get("confidence", 0.0),
|
||||
speaker=speaker,
|
||||
speaker_confidence=w.get("speaker_confidence"),
|
||||
)
|
||||
)
|
||||
|
||||
stored_params = dict(params)
|
||||
if keyterms:
|
||||
stored_params["keyterms"] = keyterms
|
||||
|
||||
return TranscriptionResult(
|
||||
provider=self.name,
|
||||
transcript=transcript,
|
||||
words=words,
|
||||
speakers=sorted(speakers_set),
|
||||
duration=duration,
|
||||
raw_response={"words": raw_words},
|
||||
params=stored_params,
|
||||
)
|
||||
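The Deepgram provider above assembles its WebSocket URL from a parameter dict plus one repeated `keyterm` parameter per boosted term. As a minimal sketch of that step in isolation, the snippet below uses only the standard library; `build_listen_url` is a hypothetical helper name introduced for illustration and is not part of the benchmark code.

```python
from urllib.parse import urlencode


def build_listen_url(base_url, params, keyterms=None):
    """Return base_url plus a percent-encoded query string.

    Hypothetical helper, shown only to illustrate the URL-building step.
    """
    pairs = [(key, str(value)) for key, value in params.items()]
    # One `keyterm` pair per term, so the key is repeated in the query string.
    pairs.extend(("keyterm", term) for term in keyterms or [])
    return f"{base_url}?{urlencode(pairs)}"


url = build_listen_url(
    "wss://api.deepgram.com/v1/listen",
    {"model": "nova-3-general", "sample_rate": 8000},
    keyterms=["Dograh", "voice agent"],
)
print(url)
```

Passing a list of pairs to `urlencode` keeps the repeated key and encodes the space in a multi-word term, which plain f-string concatenation would not do.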
285
evals/stt/providers/local_smart_turn_provider.py
Normal file
@ -0,0 +1,285 @@
"""Local Smart Turn provider for benchmarking end-of-turn detection.

Uses the pipecat smart-turn-v3 ONNX model for local ML-based turn detection.
This is NOT an STT provider - it only detects when a speaker has finished talking.
"""

import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
from loguru import logger

from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word

try:
    import onnxruntime as ort
    from transformers import WhisperFeatureExtractor
except ImportError as exc:
    raise ImportError(
        "onnxruntime and transformers required: pip install onnxruntime transformers"
    ) from exc


@dataclass
class TurnEvent:
    """Represents a detected turn event."""

    timestamp: float  # Time in audio when turn was detected
    probability: float  # Model confidence
    prediction: int  # 1=complete, 0=incomplete
    inference_time_ms: float


class LocalSmartTurnProvider(STTProvider):
    """Local Smart Turn provider for end-of-turn detection benchmarking.

    Uses the smart-turn-v3 ONNX model to detect when speakers finish talking.
    This is useful for comparing turn detection accuracy against cloud services
    like Deepgram Flux's built-in turn detection.

    NOTE: This provider does NOT produce transcripts - only turn detection events.
    """

    # Smart turn model requires 16kHz audio
    REQUIRED_SAMPLE_RATE = 16000
    # Model analyzes 8 seconds of audio
    WINDOW_SECONDS = 8

    def __init__(
        self,
        model_path: str | None = None,
        cpu_count: int = 1,
    ):
        """Initialize the local smart turn provider.

        Args:
            model_path: Path to ONNX model file. If None, uses bundled model.
            cpu_count: Number of CPUs for inference (default: 1)
        """
        self.model_path = model_path
        self.cpu_count = cpu_count
        self._session = None
        self._feature_extractor = None

    def _load_model(self):
        """Lazy load the ONNX model."""
        if self._session is not None:
            return

        model_path = self.model_path

        if not model_path:
            # Try to load bundled model from pipecat
            model_name = "smart-turn-v3.1-cpu.onnx"
            package_path = "pipecat.audio.turn.smart_turn.data"

            try:
                import importlib_resources as impresources

                model_path = str(impresources.files(package_path).joinpath(model_name))
            except Exception:
                from importlib import resources as impresources

                try:
                    with impresources.path(package_path, model_name) as f:
                        model_path = str(f)
                except Exception:
                    model_path = str(impresources.files(package_path).joinpath(model_name))

        logger.info(f"[local-smart-turn] Loading model from {model_path}")

        # Configure ONNX runtime
        so = ort.SessionOptions()
        so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        so.inter_op_num_threads = 1
        so.intra_op_num_threads = self.cpu_count
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
        self._session = ort.InferenceSession(model_path, sess_options=so)

        logger.info("[local-smart-turn] Model loaded")

    @property
    def name(self) -> str:
        return "local-smart-turn"

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Predict end-of-turn using the ONNX model.

        Args:
            audio_array: Audio samples as float32 numpy array (16kHz)

        Returns:
            Dict with prediction (0/1), probability and inference_time_ms
        """
        # Keep the last 8 seconds, or left-pad shorter audio with zeros to 8 seconds
        max_samples = self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE
        if len(audio_array) > max_samples:
            audio_array = audio_array[-max_samples:]
        elif len(audio_array) < max_samples:
            padding = max_samples - len(audio_array)
            audio_array = np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)

        # Process using Whisper's feature extractor
        inputs = self._feature_extractor(
            audio_array,
            sampling_rate=self.REQUIRED_SAMPLE_RATE,
            return_tensors="np",
            padding="max_length",
            max_length=self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE,
            truncation=True,
            do_normalize=True,
        )

        # Extract features for ONNX
        input_features = inputs.input_features.squeeze(0).astype(np.float32)
        input_features = np.expand_dims(input_features, axis=0)

        # Run inference
        start_time = time.perf_counter()
        outputs = self._session.run(None, {"input_features": input_features})
        inference_time = (time.perf_counter() - start_time) * 1000

        # Extract probability (model returns sigmoid probabilities)
        probability = outputs[0][0].item()
        prediction = 1 if probability > 0.5 else 0

        return {
            "prediction": prediction,
            "probability": probability,
            "inference_time_ms": inference_time,
        }

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - not applicable
        keyterms: list[str] | None = None,  # Ignored - not applicable
        sample_rate: int = 16000,  # Must be 16kHz for smart turn
        analysis_interval_ms: int = 500,  # How often to check for turn completion
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Analyze audio for turn detection events.

        NOTE: This does NOT produce transcripts. It detects when speakers
        finish talking using ML-based turn detection.

        Args:
            audio_path: Path to audio file
            diarize: Ignored (not applicable for turn detection)
            keyterms: Ignored (not applicable for turn detection)
            sample_rate: Must be 16000 Hz for smart turn model
            analysis_interval_ms: How often to run turn detection (ms)
            **kwargs: Additional parameters (ignored)

        Returns:
            TranscriptionResult with turn detection events in raw_response
        """
        if sample_rate != self.REQUIRED_SAMPLE_RATE:
            logger.warning(
                f"[local-smart-turn] Sample rate must be {self.REQUIRED_SAMPLE_RATE}Hz, "
                f"overriding {sample_rate}Hz"
            )
            sample_rate = self.REQUIRED_SAMPLE_RATE

        # Load model if not already loaded
        self._load_model()

        # Setup audio streamer at 16kHz
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Get audio duration
        duration = streamer.get_duration(audio_path)
        logger.info(f"[local-smart-turn] Processing {audio_path} ({duration:.2f}s)")

        # Collect all audio first (smart turn needs to analyze segments)
        pcm_data = streamer.convert_to_pcm16(audio_path)

        # Convert to float32 for model
        audio_int16 = np.frombuffer(pcm_data, dtype=np.int16)
        audio_float32 = audio_int16.astype(np.float32) / 32768.0

        # Analyze at intervals
        turn_events: list[TurnEvent] = []
        samples_per_interval = int(sample_rate * analysis_interval_ms / 1000)
        window_samples = self.WINDOW_SECONDS * sample_rate

        chunk_no = 0
        for end_sample in range(samples_per_interval, len(audio_float32), samples_per_interval):
            # Get window of audio ending at current position
            start_sample = max(0, end_sample - window_samples)
            audio_window = audio_float32[start_sample:end_sample]

            current_time = end_sample / sample_rate
            logger.debug(f"[local-smart-turn] Analyzing chunk {chunk_no} at {current_time:.2f}s")

            result = self._predict_endpoint(audio_window)

            turn_events.append(TurnEvent(
                timestamp=current_time,
                probability=result["probability"],
                prediction=result["prediction"],
                inference_time_ms=result["inference_time_ms"],
            ))

            if result["prediction"] == 1:
                logger.info(
                    f"[local-smart-turn] Turn complete at {current_time:.2f}s "
                    f"(prob={result['probability']:.3f}, "
                    f"inference={result['inference_time_ms']:.1f}ms)"
                )

            chunk_no += 1

        # Create result
        # Convert turn events to word-like format for compatibility
        words = []
        for event in turn_events:
            if event.prediction == 1:
                words.append(Word(
                    word=f"[END_OF_TURN prob={event.probability:.2f}]",
                    start=event.timestamp - 0.1,
                    end=event.timestamp,
                    confidence=event.probability,
                    speaker=None,
                    speaker_confidence=None,
                ))

        # Count completed turns
        completed_turns = sum(1 for e in turn_events if e.prediction == 1)

        params = {
            "sample_rate": sample_rate,
            "analysis_interval_ms": analysis_interval_ms,
            "window_seconds": self.WINDOW_SECONDS,
        }

        return TranscriptionResult(
            provider=self.name,
            transcript=f"[Turn detection only - {completed_turns} turns detected]",
            words=words,
            speakers=[],  # Not applicable
            duration=duration,
            raw_response={
                "turn_events": [
                    {
                        "timestamp": e.timestamp,
                        "probability": e.probability,
                        "prediction": e.prediction,
                        "inference_time_ms": e.inference_time_ms,
                    }
                    for e in turn_events
                ],
                "completed_turns": completed_turns,
                "total_analyses": len(turn_events),
                "avg_inference_time_ms": (
                    sum(e.inference_time_ms for e in turn_events) / len(turn_events)
                    if turn_events else 0
                ),
            },
            params=params,
        )
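The smart-turn provider above runs the model every `analysis_interval_ms` on a window that ends at the current position and reaches back at most `WINDOW_SECONDS`. The sketch below reproduces only that index arithmetic in plain Python so the schedule can be inspected without loading a model; `analysis_windows` is a hypothetical name used for illustration.

```python
def analysis_windows(total_samples, sample_rate=16000, interval_ms=500, window_seconds=8):
    """Yield (start, end) sample indices for each analysis step.

    Hypothetical helper mirroring the loop bounds used in the provider.
    """
    step = int(sample_rate * interval_ms / 1000)
    window = window_seconds * sample_rate
    # range() stops before total_samples, so a trailing partial step is skipped.
    for end in range(step, total_samples, step):
        yield max(0, end - window), end


# Ten seconds of 16 kHz audio, analysed every 500 ms.
windows = list(analysis_windows(total_samples=10 * 16000))
print(len(windows), windows[0], windows[-1])
```

Early windows are shorter than eight seconds, which is why the provider left-pads them with zeros before feature extraction.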
248
evals/stt/providers/speechmatics_provider.py
Normal file
@ -0,0 +1,248 @@
"""Speechmatics STT provider with WebSocket streaming."""

import asyncio
import json
import os
from pathlib import Path
from typing import Any

from loguru import logger

from ..audio_streamer import AudioConfig, AudioStreamer
from .base import STTProvider, TranscriptionResult, Word

try:
    from websockets.asyncio.client import connect as websocket_connect
except ImportError as exc:
    raise ImportError("websockets required: pip install websockets") from exc


class SpeechmaticsProvider(STTProvider):
    """Speechmatics Speech-to-Text provider with WebSocket streaming.

    API Docs: https://docs.speechmatics.com/

    Supports:
    - Speaker diarization via `diarization: "speaker"` config
    - Speaker sensitivity tuning
    - Real-time streaming via WebSocket
    """

    def __init__(self, api_key: str | None = None, region: str = "eu2"):
        self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key."
            )
        # Set region-specific endpoint
        self.ws_url = f"wss://{region}.rt.speechmatics.com/v2"

    @property
    def name(self) -> str:
        return "speechmatics"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,
        keyterms: list[str] | None = None,
        language: str = "en",
        operating_point: str = "enhanced",
        sample_rate: int = 8000,
        speaker_sensitivity: float | None = None,
        max_speakers: int | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Speechmatics WebSocket streaming.

        Args:
            audio_path: Path to audio file
            diarize: Enable speaker diarization
            keyterms: Additional vocabulary (limited support)
            language: Language code
            operating_point: "standard" or "enhanced"
            sample_rate: Audio sample rate for streaming
            speaker_sensitivity: 0.0-1.0, higher = more speakers detected
            max_speakers: Maximum number of speakers to detect
            **kwargs: Additional config parameters

        Returns:
            TranscriptionResult with transcript and speaker info
        """
        # Build transcription config for StartRecognition message
        transcription_config: dict[str, Any] = {
            "language": language,
            "operating_point": operating_point,
            "enable_partials": False,
        }

        if diarize:
            transcription_config["diarization"] = "speaker"
            if speaker_sensitivity is not None:
                transcription_config["speaker_diarization_config"] = {
                    "speaker_sensitivity": speaker_sensitivity
                }
            if max_speakers is not None:
                if "speaker_diarization_config" not in transcription_config:
                    transcription_config["speaker_diarization_config"] = {}
                transcription_config["speaker_diarization_config"]["max_speakers"] = max_speakers

        # Add additional vocabulary if provided
        if keyterms:
            transcription_config["additional_vocab"] = [{"content": term} for term in keyterms]

        # Audio format config
        audio_format = {
            "type": "raw",
            "encoding": "pcm_s16le",
            "sample_rate": sample_rate,
        }

        # Store params for result
        params = {
            "diarize": diarize,
            "language": language,
            "operating_point": operating_point,
            "sample_rate": sample_rate,
            "speaker_sensitivity": speaker_sensitivity,
            "max_speakers": max_speakers,
        }

        # Setup audio streamer
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Collect results
        all_results: list[dict[str, Any]] = []
        recognition_started = asyncio.Event()
        transcription_complete = asyncio.Event()

        async with websocket_connect(
            self.ws_url,
            additional_headers={"Authorization": f"Bearer {self.api_key}"},
        ) as ws:
            # Send StartRecognition message
            start_msg = {
                "message": "StartRecognition",
                "transcription_config": transcription_config,
                "audio_format": audio_format,
            }
            await ws.send(json.dumps(start_msg))

            async def send_audio():
                """Send audio chunks after recognition starts."""
                await recognition_started.wait()

                chunk_no = 0
                async for chunk in streamer.stream_file(audio_path):
                    await ws.send(chunk)
                    logger.debug(f"[speechmatics] Sent audio chunk {chunk_no}")
                    chunk_no += 1

                # Signal end of audio; last_seq_no is the number of chunks sent
                logger.debug(f"[speechmatics] Sending EndOfStream after {chunk_no} chunks")
                await ws.send(json.dumps({"message": "EndOfStream", "last_seq_no": chunk_no}))

            async def receive_messages():
                """Receive and process messages."""
                nonlocal all_results

                async for message in ws:
                    if isinstance(message, str):
                        data = json.loads(message)
                        msg_type = data.get("message")
                        logger.debug(f"[speechmatics] Received {msg_type}: {data}")

                        if msg_type == "RecognitionStarted":
                            logger.info("[speechmatics] Connected")
                            recognition_started.set()

                        elif msg_type == "AddTranscript":
                            # Final transcript segment
                            results = data.get("results", [])
                            all_results.extend(results)

                        elif msg_type == "EndOfTranscript":
                            transcription_complete.set()
                            return

                        elif msg_type == "Error":
                            raise RuntimeError(f"Speechmatics error: {data}")

                        elif msg_type == "Warning":
                            logger.warning(f"[speechmatics] Warning: {data.get('reason')}")

            # Run send and receive concurrently
            send_task = asyncio.create_task(send_audio())
            receive_task = asyncio.create_task(receive_messages())

            # Wait for completion
            await send_task
            try:
                await asyncio.wait_for(transcription_complete.wait(), timeout=30.0)
            except asyncio.TimeoutError:
                # Continue with the segments received so far
                logger.warning("[speechmatics] Timed out waiting for EndOfTranscript")

            receive_task.cancel()
            try:
                await receive_task
            except asyncio.CancelledError:
                pass

        return self._parse_results(all_results, params)

    def _parse_results(
        self,
        results: list[dict[str, Any]],
        params: dict[str, Any],
    ) -> TranscriptionResult:
        """Parse Speechmatics results."""
        words = []
        speakers_set: set[str] = set()
        transcript_parts = []
        duration = 0.0

        for item in results:
            item_type = item.get("type")
            alternatives = item.get("alternatives", [])

            if not alternatives:
                continue

            alt = alternatives[0]
            content = alt.get("content", "")
            speaker = alt.get("speaker")

            if speaker:
                speakers_set.add(speaker)

            end_time = item.get("end_time", 0.0)
            duration = max(duration, end_time)

            if item_type == "word":
                words.append(
                    Word(
                        word=content,
                        start=item.get("start_time", 0.0),
                        end=end_time,
                        confidence=alt.get("confidence", 0.0),
                        speaker=speaker,
                        speaker_confidence=None,
                    )
                )
                transcript_parts.append(content)
            elif item_type == "punctuation":
                # Attach punctuation to the preceding word
                if transcript_parts:
                    transcript_parts[-1] += content

        transcript = " ".join(transcript_parts)

        return TranscriptionResult(
            provider=self.name,
            transcript=transcript,
            words=words,
            speakers=sorted(speakers_set),
            duration=duration,
            raw_response={"results": results},
            params=params,
        )
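`_parse_results` above builds the transcript by appending each `word` item and gluing each `punctuation` item onto the word before it. The standalone sketch below shows that joining rule on a small hand-written payload; the sample items are illustrative and trimmed to the fields the rule reads, and `join_transcript` is a hypothetical name rather than a function from the repository.

```python
def join_transcript(results):
    """Join word items with spaces and attach punctuation to the previous word."""
    parts = []
    for item in results:
        alternatives = item.get("alternatives", [])
        if not alternatives:
            continue
        content = alternatives[0].get("content", "")
        if item.get("type") == "word":
            parts.append(content)
        elif item.get("type") == "punctuation" and parts:
            # No space before punctuation: extend the last word in place.
            parts[-1] += content
    return " ".join(parts)


# Illustrative payload, not captured from a real Speechmatics session.
sample = [
    {"type": "word", "alternatives": [{"content": "Hello"}]},
    {"type": "punctuation", "alternatives": [{"content": ","}]},
    {"type": "word", "alternatives": [{"content": "world"}]},
    {"type": "punctuation", "alternatives": [{"content": "."}]},
]
transcript = join_transcript(sample)
print(transcript)  # Hello, world.
```

Punctuation that arrives before any word is dropped, which matches the guard on `transcript_parts` in the provider.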