mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
add smart turn as provider
This commit is contained in:
parent
6d589b7452
commit
e301579c31
9 changed files with 960 additions and 178 deletions
|
|
@ -1,27 +1,29 @@
|
|||
# STT Evaluation Benchmark
|
||||
|
||||
Benchmark for comparing Speech-to-Text providers with focus on:
|
||||
Benchmark for comparing Speech-to-Text providers using **WebSocket streaming** with focus on:
|
||||
- **Speaker diarization** - identifying who said what
|
||||
- **Keyterm boosting** - improving recognition of specific terms (Deepgram)
|
||||
|
||||
## Providers
|
||||
|
||||
| Provider | Diarization | Keyterm Boost | Notes |
|
||||
|----------|-------------|---------------|-------|
|
||||
| Deepgram | Yes | Yes | `diarize=true`, `keyterm` param |
|
||||
| Speechmatics | Yes | No | `diarization: "speaker"` config |
|
||||
| Provider | Diarization | Keyterm Boost | Streaming |
|
||||
|----------|-------------|---------------|-----------|
|
||||
| Deepgram | Yes | Yes | WebSocket (v1/v2) |
|
||||
| Speechmatics | Yes | Additional vocab | WebSocket RT |
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install dependencies (httpx is required)
|
||||
pip install httpx
|
||||
# Install dependencies
|
||||
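pip install websockets loguru
# Optional: only needed for the local-smart-turn provider (which also loads its bundled model from pipecat)
pip install numpy onnxruntime transformers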
|
||||
|
||||
# Set API keys
|
||||
export DEEPGRAM_API_KEY="your-key"
|
||||
export SPEECHMATICS_API_KEY="your-key"
|
||||
```
|
||||
|
||||
**Note:** Requires `ffmpeg` installed for audio conversion to PCM16.
|
||||
|
||||
## Usage
|
||||
|
||||
Run from the project root directory:
|
||||
|
|
@ -33,9 +35,12 @@ python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize
|
|||
# Test only Deepgram
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --providers deepgram
|
||||
|
||||
# Test with keyterm boosting (Deepgram only)
|
||||
# Test with keyterm boosting (Deepgram)
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --keyterms "Dograh" "Pipecat"
|
||||
|
||||
# Use different sample rate (default: 8000 Hz)
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --sample-rate 16000
|
||||
|
||||
# Show word-level timings
|
||||
python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --show-words
|
||||
|
||||
|
|
@ -50,8 +55,9 @@ python -m evals.stt.benchmark audio/multi_speaker.m4a --diarize --save
|
|||
| `audio_file` | Path to audio file (relative to evals/stt/ or absolute) |
|
||||
| `--providers` | Providers to test: `deepgram`, `deepgram-flux`, `speechmatics`, `local-smart-turn` (default: `deepgram` and `speechmatics`) |
|
||||
| `--diarize` | Enable speaker diarization |
|
||||
| `--keyterms` | Keywords to boost (Deepgram only) |
|
||||
| `--keyterms` | Keywords to boost (Deepgram) / additional vocab (Speechmatics) |
|
||||
| `--language` | Language code (default: en) |
|
||||
| `--sample-rate` | Audio sample rate for streaming (default: 8000) |
|
||||
| `--show-words` | Show individual word timings |
|
||||
| `--save` | Save results to JSON in `results/` |
|
||||
|
||||
|
|
@ -64,16 +70,33 @@ evals/stt/
|
|||
├── results/ # Saved benchmark results (JSON)
|
||||
├── providers/ # STT provider implementations
|
||||
│ ├── base.py # Base classes
|
||||
│ ├── deepgram_provider.py
|
||||
│ └── speechmatics_provider.py
|
||||
│ ├── deepgram_provider.py # WebSocket streaming
|
||||
│ └── speechmatics_provider.py # WebSocket streaming
|
||||
├── audio_streamer.py # PCM16 audio file streamer
|
||||
├── benchmark.py # Main runner script
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Audio Conversion**: The `AudioStreamer` converts any audio file to raw PCM16 using ffmpeg
|
||||
2. **WebSocket Connection**: Providers connect to their respective WebSocket APIs
|
||||
3. **Streaming**: Audio is sent in chunks (configurable sample rate, default 8kHz)
|
||||
4. **Result Collection**: Transcripts and speaker info are collected from WebSocket responses
|
||||
5. **Comparison**: Results are parsed into a common format for comparison
|
||||
|
||||
## Output Example
|
||||
|
||||
```
|
||||
Audio file: /path/to/audio/multi_speaker.m4a
|
||||
Providers: ['deepgram', 'speechmatics']
|
||||
Diarization: True
|
||||
Sample rate: 8000 Hz
|
||||
|
||||
============================================================
|
||||
Provider: DEEPGRAM
|
||||
============================================================
|
||||
|
||||
Duration: 45.32s
|
||||
Speakers detected: 2 - ['0', '1']
|
||||
|
||||
|
|
@ -84,17 +107,29 @@ Hello, welcome to the demo...
|
|||
[0.0s] Speaker 0: Hello, welcome to the demo.
|
||||
[2.5s] Speaker 1: Thanks for having me.
|
||||
...
|
||||
|
||||
============================================================
|
||||
COMPARISON SUMMARY
|
||||
============================================================
|
||||
|
||||
Provider Duration Speakers Words
|
||||
---------------------------------------------
|
||||
deepgram 45.32 2 312
|
||||
speechmatics 45.32 2 308
|
||||
```
|
||||
|
||||
## Adding New Providers
|
||||
|
||||
1. Create a new file in `providers/` (e.g., `whisper_provider.py`)
|
||||
2. Implement the `STTProvider` abstract class
|
||||
3. Add to `providers/__init__.py`
|
||||
4. Add to `benchmark.py` provider choices
|
||||
2. Implement the `STTProvider` abstract class with WebSocket streaming
|
||||
3. Use `AudioStreamer` for PCM16 conversion
|
||||
4. Add to `providers/__init__.py`
|
||||
5. Add to `benchmark.py` provider choices
|
||||
|
||||
## API Documentation
|
||||
|
||||
- Deepgram Streaming: https://developers.deepgram.com/docs/live-streaming-audio
|
||||
- Deepgram Diarization: https://developers.deepgram.com/docs/diarization
|
||||
- Deepgram Keyterms: https://developers.deepgram.com/docs/keyterm
|
||||
- Speechmatics RT API: https://docs.speechmatics.com/rt-api-ref
|
||||
- Speechmatics Diarization: https://docs.speechmatics.com/features/diarization
|
||||
|
|
|
|||
BIN
evals/stt/audio/vad.m4a
Normal file
BIN
evals/stt/audio/vad.m4a
Normal file
Binary file not shown.
127
evals/stt/audio_streamer.py
Normal file
127
evals/stt/audio_streamer.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Audio file streamer - converts audio files to PCM16 streams."""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator
|
||||
|
||||
|
||||
@dataclass
class AudioConfig:
    """Parameters describing the PCM stream produced by the streamer."""

    # Sample rate, in Hz, requested from ffmpeg for the converted audio.
    sample_rate: int = 8000
    # Number of audio channels requested from ffmpeg.
    channels: int = 1
    # Bytes per sample; 2 bytes corresponds to 16-bit PCM.
    sample_width: int = 2
    # Amount of audio carried by each streamed chunk, in milliseconds.
    chunk_duration_ms: int = 80

    @property
    def chunk_size(self) -> int:
        """Return the size, in bytes, of one chunk of ``chunk_duration_ms``."""
        bytes_per_frame = self.channels * self.sample_width
        # Fractional frames are truncated, exactly as int() does.
        frames_per_chunk = int(self.sample_rate * self.chunk_duration_ms / 1000)
        return frames_per_chunk * bytes_per_frame
|
||||
|
||||
|
||||
class AudioStreamer:
    """Stream audio files as PCM16 chunks.

    Any input format ffmpeg can decode is converted to raw signed 16-bit
    little-endian PCM and then handed out in fixed-size chunks, optionally
    paced so the stream lasts as long as the audio itself.
    """

    def __init__(self, config: AudioConfig | None = None):
        """Create a streamer.

        Args:
            config: Stream parameters; a default ``AudioConfig`` is used when
                omitted.
        """
        self.config = config or AudioConfig()

    def convert_to_pcm16(self, audio_path: Path) -> bytes:
        """Convert an audio file to raw PCM16 bytes using ffmpeg.

        This call blocks until ffmpeg exits and holds the whole decoded file
        in memory.

        Args:
            audio_path: Path to input audio file

        Returns:
            Raw PCM16 audio bytes

        Raises:
            FileNotFoundError: If the ``ffmpeg`` executable is not on PATH.
            subprocess.CalledProcessError: If ffmpeg exits with an error; its
                captured stderr is available on the exception.
        """
        cmd = [
            "ffmpeg",
            "-i",
            str(audio_path),
            "-f",
            "s16le",  # signed 16-bit little-endian
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(self.config.sample_rate),
            "-ac",
            str(self.config.channels),
            "-",  # output to stdout
        ]

        result = subprocess.run(
            cmd,
            # Do not let ffmpeg inherit (and read from) the caller's stdin.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            check=True,
        )
        return result.stdout

    async def stream_file(
        self,
        audio_path: Path,
        realtime: bool = True,
    ) -> AsyncIterator[bytes]:
        """Stream an audio file as PCM16 chunks.

        Args:
            audio_path: Path to audio file
            realtime: If True, pace the chunks so each one takes
                ``chunk_duration_ms`` of wall-clock time

        Yields:
            PCM16 audio chunks of ``config.chunk_size`` bytes (the last one
            may be shorter)

        Raises:
            ValueError: If the configuration yields a non-positive chunk size.
        """
        chunk_size = self.config.chunk_size
        if chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be positive, got {chunk_size}; "
                "check sample_rate and chunk_duration_ms"
            )

        # The ffmpeg conversion is blocking; run it in a worker thread so
        # other tasks on the event loop keep running while it decodes.
        pcm_data = await asyncio.to_thread(self.convert_to_pcm16, audio_path)

        delay = self.config.chunk_duration_ms / 1000.0 if realtime else 0.0
        paced = realtime and delay > 0

        # Pace against absolute deadlines rather than sleeping a fixed amount
        # after every chunk: time the consumer spends handling a chunk is then
        # absorbed instead of accumulating as drift.
        loop = asyncio.get_running_loop()
        next_deadline = loop.time()

        for offset in range(0, len(pcm_data), chunk_size):
            yield pcm_data[offset : offset + chunk_size]
            if paced:
                next_deadline += delay
                await asyncio.sleep(max(0.0, next_deadline - loop.time()))

    async def stream_file_fast(self, audio_path: Path) -> AsyncIterator[bytes]:
        """Stream an audio file as fast as possible (no real-time delay).

        Args:
            audio_path: Path to audio file

        Yields:
            PCM16 audio chunks
        """
        async for chunk in self.stream_file(audio_path, realtime=False):
            yield chunk

    def get_duration(self, audio_path: Path) -> float:
        """Get the audio file duration in seconds using ffprobe.

        Args:
            audio_path: Path to audio file

        Returns:
            Duration in seconds

        Raises:
            FileNotFoundError: If the ``ffprobe`` executable is not on PATH.
            subprocess.CalledProcessError: If ffprobe exits with an error.
            ValueError: If ffprobe's output is not a number.
        """
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            str(audio_path),
        ]
        result = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            check=True,
        )
        return float(result.stdout.strip())
|
||||
|
|
@ -20,14 +20,23 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from evals.stt.providers import DeepgramProvider, SpeechmaticsProvider, STTProvider, TranscriptionResult
|
||||
from evals.stt.providers import (
|
||||
DeepgramProvider,
|
||||
DeepgramFluxProvider,
|
||||
SpeechmaticsProvider,
|
||||
LocalSmartTurnProvider,
|
||||
STTProvider,
|
||||
TranscriptionResult,
|
||||
)
|
||||
|
||||
|
||||
def get_provider(name: str) -> STTProvider:
|
||||
"""Get provider instance by name."""
|
||||
providers = {
|
||||
"deepgram": DeepgramProvider,
|
||||
"deepgram-flux": DeepgramFluxProvider,
|
||||
"speechmatics": SpeechmaticsProvider,
|
||||
"local-smart-turn": LocalSmartTurnProvider,
|
||||
}
|
||||
if name not in providers:
|
||||
raise ValueError(f"Unknown provider: {name}. Available: {list(providers.keys())}")
|
||||
|
|
@ -145,7 +154,7 @@ Examples:
|
|||
"--providers",
|
||||
nargs="+",
|
||||
default=["deepgram", "speechmatics"],
|
||||
choices=["deepgram", "speechmatics"],
|
||||
choices=["deepgram", "deepgram-flux", "speechmatics", "local-smart-turn"],
|
||||
help="Providers to test (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -163,6 +172,12 @@ Examples:
|
|||
default="en",
|
||||
help="Language code (default: en)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Audio sample rate for streaming (default: 8000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-words",
|
||||
action="store_true",
|
||||
|
|
@ -195,6 +210,7 @@ Examples:
|
|||
print(f"Audio file: {audio_path}")
|
||||
print(f"Providers: {args.providers}")
|
||||
print(f"Diarization: {args.diarize}")
|
||||
print(f"Sample rate: {args.sample_rate} Hz")
|
||||
if args.keyterms:
|
||||
print(f"Keyterms: {args.keyterms}")
|
||||
|
||||
|
|
@ -209,6 +225,7 @@ Examples:
|
|||
diarize=args.diarize,
|
||||
keyterms=args.keyterms,
|
||||
language=args.language,
|
||||
sample_rate=args.sample_rate,
|
||||
)
|
||||
print_result(result, show_words=args.show_words)
|
||||
results.append(result)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
from .base import STTProvider, TranscriptionResult, Word
|
||||
from .deepgram_provider import DeepgramProvider
|
||||
from .deepgram_flux_provider import DeepgramFluxProvider
|
||||
from .speechmatics_provider import SpeechmaticsProvider
|
||||
from .local_smart_turn_provider import LocalSmartTurnProvider
|
||||
|
||||
__all__ = [
|
||||
"STTProvider",
|
||||
"TranscriptionResult",
|
||||
"Word",
|
||||
"DeepgramProvider",
|
||||
"DeepgramFluxProvider",
|
||||
"SpeechmaticsProvider",
|
||||
"LocalSmartTurnProvider",
|
||||
]
|
||||
|
|
|
|||
225
evals/stt/providers/deepgram_flux_provider.py
Normal file
225
evals/stt/providers/deepgram_flux_provider.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
"""Deepgram Flux STT provider with WebSocket streaming.
|
||||
|
||||
Flux is Deepgram's conversational AI model with built-in turn detection.
|
||||
It has a different API than Nova models - no language/punctuate/diarize params.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from ..audio_streamer import AudioConfig, AudioStreamer
|
||||
from .base import STTProvider, TranscriptionResult, Word
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ImportError:
|
||||
raise ImportError("websockets required: pip install websockets")
|
||||
|
||||
|
||||
class DeepgramFluxProvider(STTProvider):
    """Deepgram Flux Speech-to-Text provider with WebSocket streaming.

    Flux is optimized for conversational AI with built-in turn detection.

    Key differences from Nova:
    - Uses v2 API endpoint
    - Only supports English (flux-general-en)
    - No punctuate, diarize, or language params
    - Has turn detection events (StartOfTurn, EndOfTurn, EagerEndOfTurn)
    - Supports keyterm boosting

    API Docs: https://developers.deepgram.com/docs/
    """

    WS_URL = "wss://api.deepgram.com/v2/listen"

    def __init__(self, api_key: str | None = None):
        """Create the provider.

        Args:
            api_key: Deepgram API key; falls back to the ``DEEPGRAM_API_KEY``
                environment variable.

        Raises:
            ValueError: If no API key is available.
        """
        self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Deepgram API key required. Set DEEPGRAM_API_KEY env var or pass api_key."
            )

    @property
    def name(self) -> str:
        """Provider identifier used in results."""
        return "deepgram-flux"

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - Flux doesn't support diarization
        keyterms: list[str] | None = None,
        model: str = "flux-general-en",
        sample_rate: int = 16000,
        eot_threshold: float | None = None,
        eot_timeout_ms: int | None = None,
        eager_eot_threshold: float | None = None,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio using Deepgram Flux WebSocket streaming.

        Audio is streamed in real time once the server has sent its
        ``Connected`` message. The transcripts of all ``EndOfTurn`` events are
        joined into the final transcript.

        Args:
            audio_path: Path to audio file
            diarize: IGNORED - Flux does not support diarization
            keyterms: List of keywords to boost recognition
            model: Flux model (default: flux-general-en)
            sample_rate: Audio sample rate (default: 16000 for Flux)
            eot_threshold: End-of-turn confidence threshold; only sent when set
            eot_timeout_ms: Timeout in ms to force end of turn; only sent when set
            eager_eot_threshold: Threshold for eager end-of-turn events; only
                sent when set
            **kwargs: Accepted so callers can pass options meant for other
                providers (for example ``language``); they are ignored and
                never forwarded to the API.

        Returns:
            TranscriptionResult with transcript (no speaker info - Flux doesn't
            support diarization)

        Raises:
            RuntimeError: If the server reports an error, or the connection
                ends before the ``Connected`` message arrives.
        """
        if diarize:
            logger.warning("Flux does not support diarization - ignoring diarize=True")

        # Build query params - Flux only supports specific params
        params: dict[str, Any] = {
            "model": model,
            "encoding": "linear16",
            "sample_rate": sample_rate,
        }

        # Flux-specific turn detection params
        if eot_threshold is not None:
            params["eot_threshold"] = eot_threshold
        if eot_timeout_ms is not None:
            params["eot_timeout_ms"] = eot_timeout_ms
        if eager_eot_threshold is not None:
            params["eager_eot_threshold"] = eager_eot_threshold

        # URL-encode every pair (not only keyterms) so a value containing
        # reserved characters cannot corrupt the query string. Keyterms are
        # sent as repeated ``keyterm`` params.
        query: list[tuple[str, Any]] = list(params.items())
        query.extend(("keyterm", term) for term in keyterms or [])

        ws_url = f"{self.WS_URL}?{urlencode(query)}"
        logger.debug(f"Flux WebSocket URL: {ws_url}")

        # Setup audio streamer
        streamer = AudioStreamer(AudioConfig(sample_rate=sample_rate))

        # Collected results
        all_transcripts: list[dict[str, Any]] = []
        turn_texts: list[str] = []
        duration = 0.0
        connected = asyncio.Event()

        async with websocket_connect(
            ws_url,
            additional_headers={"Authorization": f"Token {self.api_key}"},
        ) as ws:

            async def send_audio():
                """Send audio chunks to Deepgram Flux once connected."""
                await connected.wait()

                chunk_no = 0
                async for chunk in streamer.stream_file(audio_path):
                    logger.debug(f"[deepgram-flux] Sent audio chunk {chunk_no}")
                    await ws.send(chunk)
                    chunk_no += 1

            async def receive_messages():
                """Receive and collect Flux messages until the socket closes."""
                nonlocal duration

                async for message in ws:
                    if not isinstance(message, str):
                        continue

                    data = json.loads(message)
                    msg_type = data.get("type")
                    logger.debug(f"[deepgram-flux] Received {msg_type}: {data}")

                    if msg_type == "Connected":
                        logger.info("[deepgram-flux] Connected")
                        connected.set()

                    elif msg_type == "TurnInfo":
                        event = data.get("event")
                        transcript = data.get("transcript", "")
                        words = data.get("words", [])

                        if event == "EndOfTurn":
                            if transcript:
                                turn_texts.append(transcript)
                            if words:
                                all_transcripts.append(
                                    {
                                        "transcript": transcript,
                                        "words": words,
                                    }
                                )
                                # Get duration from last word
                                duration = max(duration, words[-1].get("end", 0))
                            # NOTE(review): Flux TurnInfo messages appear to carry
                            # the transcribed range as audio_window_end (seconds)
                            # rather than per-word timings - confirm against the
                            # API reference. Unchanged behavior if it is absent.
                            window_end = data.get("audio_window_end")
                            if isinstance(window_end, (int, float)):
                                duration = max(duration, window_end)

                        elif event == "TurnResumed":
                            logger.debug("TurnResumed")

                    elif msg_type == "Error":
                        raise RuntimeError(f"Deepgram Flux error: {data}")

            # Run send and receive concurrently
            send_task = asyncio.create_task(send_audio())
            receive_task = asyncio.create_task(receive_messages())

            try:
                await asyncio.wait(
                    {send_task, receive_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )

                if receive_task.done() and not send_task.done():
                    # The receiver stopped while audio is still pending. If it
                    # failed, surface that now instead of after the whole file
                    # has been streamed. If the socket ended before
                    # "Connected", the sender would wait forever.
                    receiver_error = receive_task.exception()
                    if receiver_error is not None:
                        raise receiver_error
                    if not connected.is_set():
                        raise RuntimeError(
                            "Deepgram Flux connection ended before 'Connected' was received"
                        )

                # Propagates send failures (e.g. a closed connection).
                await send_task
                logger.debug("[deepgram-flux] Send task done")

                try:
                    # Give the server time to deliver the remaining turn events.
                    await asyncio.wait_for(receive_task, timeout=10.0)
                except asyncio.TimeoutError:
                    pass
            finally:
                # Never leave a task running (or its exception unretrieved)
                # once this call exits, whatever the reason.
                for task in (send_task, receive_task):
                    if not task.done():
                        task.cancel()
                await asyncio.gather(send_task, receive_task, return_exceptions=True)

        final_transcript = " ".join(turn_texts).strip()
        return self._parse_results(
            all_transcripts, final_transcript, duration, params, keyterms
        )

    def _parse_results(
        self,
        transcripts: list[dict[str, Any]],
        final_transcript: str,
        duration: float,
        params: dict[str, Any],
        keyterms: list[str] | None,
    ) -> TranscriptionResult:
        """Parse collected Flux results into TranscriptionResult.

        Args:
            transcripts: One dict per completed turn with ``transcript`` and
                ``words`` keys, as collected during streaming
            final_transcript: Concatenated text of all completed turns
            duration: Audio duration in seconds derived from the turn events
            params: Query parameters sent to the API
            keyterms: Keyterms that were boosted, if any

        Returns:
            TranscriptionResult with words flattened across turns and no
            speaker information
        """
        words = [
            Word(
                word=w.get("word", ""),
                start=w.get("start", 0.0),
                end=w.get("end", 0.0),
                confidence=w.get("confidence", 0.0),
                speaker=None,  # Flux doesn't support diarization
                speaker_confidence=None,
            )
            for turn in transcripts
            for w in turn.get("words", [])
        ]

        # Copy so the caller's params dict is not mutated.
        stored_params = dict(params)
        if keyterms:
            stored_params["keyterms"] = keyterms

        return TranscriptionResult(
            provider=self.name,
            transcript=final_transcript,
            words=words,
            speakers=[],  # Flux doesn't support diarization
            duration=duration,
            raw_response={"transcripts": transcripts},
            params=stored_params,
        )
|
||||
|
|
@ -1,25 +1,38 @@
|
|||
"""Deepgram STT provider."""
|
||||
"""Deepgram STT provider with WebSocket streaming."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from ..audio_streamer import AudioConfig, AudioStreamer
|
||||
from .base import STTProvider, TranscriptionResult, Word
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ImportError:
|
||||
raise ImportError("websockets required: pip install websockets")
|
||||
|
||||
|
||||
class DeepgramProvider(STTProvider):
|
||||
"""Deepgram Speech-to-Text provider.
|
||||
"""Deepgram Nova Speech-to-Text provider with WebSocket streaming.
|
||||
|
||||
API Docs: https://developers.deepgram.com/docs/
|
||||
|
||||
Supports:
|
||||
- Speaker diarization via `diarize=true`
|
||||
- Keyterm boosting via `keyterm` parameter (Nova-3 and Flux models)
|
||||
- Keyterm boosting via `keyterm` parameter
|
||||
- Real-time streaming via WebSocket
|
||||
- Multiple languages
|
||||
- Punctuation
|
||||
|
||||
For Flux models, use DeepgramFluxProvider instead.
|
||||
"""
|
||||
|
||||
API_URL = "https://api.deepgram.com/v1/listen"
|
||||
WS_URL = "wss://api.deepgram.com/v1/listen"
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY")
|
||||
|
|
@ -37,113 +50,151 @@ class DeepgramProvider(STTProvider):
|
|||
audio_path: Path,
|
||||
diarize: bool = False,
|
||||
keyterms: list[str] | None = None,
|
||||
model: str = "nova-3",
|
||||
model: str = "nova-3-general",
|
||||
language: str = "en",
|
||||
sample_rate: int = 8000,
|
||||
punctuate: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> TranscriptionResult:
|
||||
"""Transcribe audio using Deepgram API.
|
||||
"""Transcribe audio using Deepgram Nova WebSocket streaming.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
diarize: Enable speaker diarization
|
||||
keyterms: List of keywords to boost recognition
|
||||
model: Deepgram model (nova-3, nova-2, etc.)
|
||||
model: Deepgram Nova model (nova-3, nova-2, etc.)
|
||||
language: Language code
|
||||
sample_rate: Audio sample rate for streaming
|
||||
punctuate: Add punctuation
|
||||
**kwargs: Additional Deepgram parameters
|
||||
|
||||
Returns:
|
||||
TranscriptionResult with transcript and speaker info
|
||||
"""
|
||||
# Build query params
|
||||
params: dict[str, Any] = {
|
||||
"model": model,
|
||||
"language": language,
|
||||
"punctuate": str(punctuate).lower(),
|
||||
"encoding": "linear16",
|
||||
"sample_rate": sample_rate,
|
||||
"channels": 1,
|
||||
"interim_results": "true",
|
||||
"smart_format": "true",
|
||||
"profanity_filter": "true",
|
||||
"vad_events": "true"
|
||||
}
|
||||
|
||||
if diarize:
|
||||
params["diarize"] = "true"
|
||||
|
||||
# Add keyterms (Deepgram uses repeated keyterm params)
|
||||
# Build URL with params
|
||||
url_parts = [f"{k}={v}" for k, v in params.items()]
|
||||
|
||||
# Add keyterms (repeated params)
|
||||
if keyterms:
|
||||
params["keyterm"] = keyterms
|
||||
for term in keyterms:
|
||||
url_parts.append(urlencode({"keyterm": term}))
|
||||
|
||||
# Add any extra kwargs
|
||||
params.update(kwargs)
|
||||
# Add extra kwargs
|
||||
for k, v in kwargs.items():
|
||||
url_parts.append(f"{k}={v}")
|
||||
|
||||
# Read audio file
|
||||
audio_data = audio_path.read_bytes()
|
||||
ws_url = f"{self.WS_URL}?{'&'.join(url_parts)}"
|
||||
logger.debug(f"Deepgram WebSocket URL: {ws_url}")
|
||||
|
||||
# Determine content type
|
||||
suffix = audio_path.suffix.lower()
|
||||
content_types = {
|
||||
".wav": "audio/wav",
|
||||
".mp3": "audio/mpeg",
|
||||
".m4a": "audio/mp4",
|
||||
".flac": "audio/flac",
|
||||
".ogg": "audio/ogg",
|
||||
".webm": "audio/webm",
|
||||
}
|
||||
content_type = content_types.get(suffix, "audio/wav")
|
||||
# Setup audio streamer
|
||||
audio_config = AudioConfig(sample_rate=sample_rate)
|
||||
streamer = AudioStreamer(audio_config)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Token {self.api_key}",
|
||||
"Content-Type": content_type,
|
||||
}
|
||||
# Collect results
|
||||
all_words: list[dict[str, Any]] = []
|
||||
final_transcript = ""
|
||||
duration = 0.0
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
response = await client.post(
|
||||
self.API_URL,
|
||||
params=params,
|
||||
headers=headers,
|
||||
content=audio_data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
try:
|
||||
async with websocket_connect(
|
||||
ws_url,
|
||||
additional_headers={"Authorization": f"Token {self.api_key}"},
|
||||
) as ws:
|
||||
# Create tasks for sending and receiving
|
||||
send_complete = asyncio.Event()
|
||||
|
||||
return self._parse_response(data, params)
|
||||
async def send_audio():
|
||||
"""Send audio chunks to Deepgram."""
|
||||
chunk_no = 0
|
||||
async for chunk in streamer.stream_file(audio_path):
|
||||
logger.debug(f"[deepgram] Sent audio chunk {chunk_no}")
|
||||
await ws.send(chunk)
|
||||
chunk_no += 1
|
||||
# Send close message
|
||||
logger.debug(f"[deepgram] Sending CloseStream after {chunk_no} chunks")
|
||||
await ws.send(json.dumps({"type": "CloseStream"}))
|
||||
send_complete.set()
|
||||
|
||||
def _parse_response(
|
||||
self, data: dict[str, Any], params: dict[str, Any]
|
||||
async def receive_transcripts():
|
||||
"""Receive and collect transcription results."""
|
||||
nonlocal all_words, final_transcript, duration
|
||||
|
||||
async for message in ws:
|
||||
if isinstance(message, str):
|
||||
data = json.loads(message)
|
||||
msg_type = data.get("type")
|
||||
logger.debug(f"[deepgram] Received {msg_type}: {data}")
|
||||
|
||||
if msg_type == "Results":
|
||||
# Nova-style response
|
||||
channel = data.get("channel", {})
|
||||
alternatives = channel.get("alternatives", [])
|
||||
if alternatives:
|
||||
alt = alternatives[0]
|
||||
words = alt.get("words", [])
|
||||
all_words.extend(words)
|
||||
|
||||
# Check if final
|
||||
if data.get("is_final"):
|
||||
final_transcript += alt.get("transcript", "") + " "
|
||||
duration = max(
|
||||
duration, data.get("duration", 0) + data.get("start", 0)
|
||||
)
|
||||
|
||||
elif msg_type == "Metadata":
|
||||
# Get duration from metadata
|
||||
duration = data.get("duration", duration)
|
||||
|
||||
elif msg_type == "Error":
|
||||
raise Exception(f"Deepgram error: {data}")
|
||||
|
||||
# Run send and receive concurrently
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
receive_task = asyncio.create_task(receive_transcripts())
|
||||
|
||||
# Wait for send to complete, then wait a bit for final results
|
||||
await send_task
|
||||
try:
|
||||
await asyncio.wait_for(receive_task, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal - websocket closes after final results
|
||||
except Exception as e:
|
||||
logger.debug(e)
|
||||
|
||||
return self._parse_results(
|
||||
all_words, final_transcript.strip(), duration, params, keyterms
|
||||
)
|
||||
|
||||
def _parse_results(
|
||||
self,
|
||||
raw_words: list[dict[str, Any]],
|
||||
transcript: str,
|
||||
duration: float,
|
||||
params: dict[str, Any],
|
||||
keyterms: list[str] | None,
|
||||
) -> TranscriptionResult:
|
||||
"""Parse Deepgram API response."""
|
||||
results = data.get("results", {})
|
||||
channels = results.get("channels", [])
|
||||
|
||||
if not channels:
|
||||
return TranscriptionResult(
|
||||
provider=self.name,
|
||||
transcript="",
|
||||
words=[],
|
||||
speakers=[],
|
||||
duration=0.0,
|
||||
raw_response=data,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# Get first channel, first alternative
|
||||
channel = channels[0]
|
||||
alternatives = channel.get("alternatives", [])
|
||||
if not alternatives:
|
||||
return TranscriptionResult(
|
||||
provider=self.name,
|
||||
transcript="",
|
||||
words=[],
|
||||
speakers=[],
|
||||
duration=0.0,
|
||||
raw_response=data,
|
||||
params=params,
|
||||
)
|
||||
|
||||
alt = alternatives[0]
|
||||
transcript = alt.get("transcript", "")
|
||||
|
||||
# Parse words with speaker info
|
||||
"""Parse collected results into TranscriptionResult."""
|
||||
words = []
|
||||
speakers_set: set[str] = set()
|
||||
|
||||
for w in alt.get("words", []):
|
||||
for w in raw_words:
|
||||
speaker = str(w.get("speaker", "")) if "speaker" in w else None
|
||||
if speaker:
|
||||
speakers_set.add(speaker)
|
||||
|
|
@ -159,9 +210,9 @@ class DeepgramProvider(STTProvider):
|
|||
)
|
||||
)
|
||||
|
||||
# Get duration from metadata
|
||||
metadata = results.get("metadata", {})
|
||||
duration = metadata.get("duration", 0.0)
|
||||
stored_params = dict(params)
|
||||
if keyterms:
|
||||
stored_params["keyterms"] = keyterms
|
||||
|
||||
return TranscriptionResult(
|
||||
provider=self.name,
|
||||
|
|
@ -169,6 +220,6 @@ class DeepgramProvider(STTProvider):
|
|||
words=words,
|
||||
speakers=sorted(speakers_set),
|
||||
duration=duration,
|
||||
raw_response=data,
|
||||
params=params,
|
||||
raw_response={"words": raw_words},
|
||||
params=stored_params,
|
||||
)
|
||||
|
|
|
|||
285
evals/stt/providers/local_smart_turn_provider.py
Normal file
285
evals/stt/providers/local_smart_turn_provider.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""Local Smart Turn provider for benchmarking end-of-turn detection.
|
||||
|
||||
Uses the pipecat smart-turn-v3 ONNX model for local ML-based turn detection.
|
||||
This is NOT an STT provider - it only detects when a speaker has finished talking.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from ..audio_streamer import AudioConfig, AudioStreamer
|
||||
from .base import STTProvider, TranscriptionResult, Word
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
from transformers import WhisperFeatureExtractor
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"onnxruntime and transformers required: pip install onnxruntime transformers"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class TurnEvent:
    """A single turn-detection result produced for a point in the audio."""

    # Time in the audio at which the turn was detected.
    timestamp: float
    # Confidence reported by the model.
    probability: float
    # 1 = turn complete, 0 = turn incomplete.
    prediction: int
    # Inference time in milliseconds.
    inference_time_ms: float
|
||||
|
||||
|
||||
class LocalSmartTurnProvider(STTProvider):
    """Local Smart Turn provider for end-of-turn detection benchmarking.

    Uses the smart-turn-v3 ONNX model to detect when speakers finish talking.
    This is useful for comparing turn detection accuracy against cloud services
    like Deepgram Flux's built-in turn detection.

    NOTE: This provider does NOT produce transcripts - only turn detection events.
    """

    # Smart turn model requires 16kHz audio
    REQUIRED_SAMPLE_RATE = 16000
    # Model analyzes 8 seconds of audio
    WINDOW_SECONDS = 8

    def __init__(
        self,
        model_path: str | None = None,
        cpu_count: int = 1,
    ):
        """Initialize the local smart turn provider.

        The model itself is loaded lazily on first use, so construction is cheap.

        Args:
            model_path: Path to ONNX model file. If None, uses bundled model.
            cpu_count: Number of CPUs for inference (default: 1)
        """
        self.model_path = model_path
        self.cpu_count = cpu_count
        self._session = None
        self._feature_extractor = None

    @staticmethod
    def _bundled_model_path() -> str:
        """Return the filesystem path of the ONNX model bundled with pipecat."""
        model_name = "smart-turn-v3.1-cpu.onnx"
        package_path = "pipecat.audio.turn.smart_turn.data"

        # Prefer the backport when it is installed; otherwise use the stdlib.
        # Only a missing module triggers the fallback, so real lookup errors
        # (e.g. pipecat not installed) surface instead of being swallowed.
        try:
            import importlib_resources as impresources
        except ImportError:
            from importlib import resources as impresources

        return str(impresources.files(package_path).joinpath(model_name))

    def _load_model(self) -> None:
        """Lazily create the ONNX session and feature extractor (idempotent)."""
        if self._session is not None:
            return

        model_path = self.model_path or self._bundled_model_path()
        logger.info(f"[local-smart-turn] Loading model from {model_path}")

        # Configure ONNX runtime
        so = ort.SessionOptions()
        so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        so.inter_op_num_threads = 1
        so.intra_op_num_threads = self.cpu_count
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # The extractor's chunk length must match the analysis window.
        self._feature_extractor = WhisperFeatureExtractor(chunk_length=self.WINDOW_SECONDS)
        # Assigned last: a failure above leaves _session as None so a later
        # call retries the load.
        self._session = ort.InferenceSession(model_path, sess_options=so)

        logger.info("[local-smart-turn] Model loaded")

    @property
    def name(self) -> str:
        """Provider identifier used in results and log prefixes."""
        return "local-smart-turn"

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Predict end-of-turn using the ONNX model.

        Args:
            audio_array: Audio samples as float32 numpy array (16kHz)

        Returns:
            Dict with keys ``prediction`` (0/1), ``probability`` and
            ``inference_time_ms`` (timing covers the ONNX run only, not
            feature extraction).
        """
        # Safe to call repeatedly; guarantees session/extractor exist even when
        # this method is invoked directly rather than through transcribe().
        self._load_model()

        # Truncate to last 8 seconds or pad to 8 seconds
        max_samples = self.WINDOW_SECONDS * self.REQUIRED_SAMPLE_RATE
        if len(audio_array) > max_samples:
            audio_array = audio_array[-max_samples:]
        elif len(audio_array) < max_samples:
            # Left-pad with silence so the most recent audio stays at the end
            # of the window.
            padding = max_samples - len(audio_array)
            audio_array = np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)

        # Process using Whisper's feature extractor
        inputs = self._feature_extractor(
            audio_array,
            sampling_rate=self.REQUIRED_SAMPLE_RATE,
            return_tensors="np",
            padding="max_length",
            max_length=max_samples,
            truncation=True,
            do_normalize=True,
        )

        # Extract features for ONNX: drop then re-add the batch axis so the
        # input always has exactly one leading batch dimension.
        input_features = inputs.input_features.squeeze(0).astype(np.float32)
        input_features = np.expand_dims(input_features, axis=0)

        # Run inference
        start_time = time.perf_counter()
        outputs = self._session.run(None, {"input_features": input_features})
        inference_time = (time.perf_counter() - start_time) * 1000

        # Extract probability (model returns sigmoid probabilities)
        probability = outputs[0][0].item()
        prediction = 1 if probability > 0.5 else 0

        return {
            "prediction": prediction,
            "probability": probability,
            "inference_time_ms": inference_time,
        }

    async def transcribe(
        self,
        audio_path: Path,
        diarize: bool = False,  # Ignored - not applicable
        keyterms: list[str] | None = None,  # Ignored - not applicable
        sample_rate: int = 16000,  # Must be 16kHz for smart turn
        analysis_interval_ms: int = 500,  # How often to check for turn completion
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Analyze audio for turn detection events.

        NOTE: This does NOT produce transcripts. It detects when speakers
        finish talking using ML-based turn detection.

        The model is run on a sliding window ending at every multiple of
        ``analysis_interval_ms``, including the very end of the audio when its
        length is an exact multiple of the interval.

        Args:
            audio_path: Path to audio file
            diarize: Ignored (not applicable for turn detection)
            keyterms: Ignored (not applicable for turn detection)
            sample_rate: Must be 16000 Hz for smart turn model; any other
                value is overridden with a warning
            analysis_interval_ms: How often to run turn detection (ms)
            **kwargs: Additional parameters (ignored)

        Returns:
            TranscriptionResult with turn detection events in raw_response.
            Completed turns are also exposed as placeholder ``Word`` entries.

        Raises:
            ValueError: If ``analysis_interval_ms`` does not cover at least
                one audio sample (zero or negative interval).
        """
        if sample_rate != self.REQUIRED_SAMPLE_RATE:
            logger.warning(
                f"[local-smart-turn] Sample rate must be {self.REQUIRED_SAMPLE_RATE}Hz, "
                f"overriding {sample_rate}Hz"
            )
            sample_rate = self.REQUIRED_SAMPLE_RATE

        # Fail fast, before loading the model: a zero step would crash range()
        # with an opaque message and a negative one would analyze nothing.
        samples_per_interval = int(sample_rate * analysis_interval_ms / 1000)
        if samples_per_interval < 1:
            raise ValueError(
                f"analysis_interval_ms must be a positive duration, got {analysis_interval_ms}"
            )

        # Load model if not already loaded
        self._load_model()

        # Setup audio streamer at 16kHz
        audio_config = AudioConfig(sample_rate=sample_rate)
        streamer = AudioStreamer(audio_config)

        # Get audio duration
        duration = streamer.get_duration(audio_path)
        logger.info(f"[local-smart-turn] Processing {audio_path} ({duration:.2f}s)")

        # Collect all audio first (smart turn needs to analyze segments)
        pcm_data = streamer.convert_to_pcm16(audio_path)

        # Convert to float32 for model (int16 PCM scaled to [-1.0, 1.0))
        audio_int16 = np.frombuffer(pcm_data, dtype=np.int16)
        audio_float32 = audio_int16.astype(np.float32) / 32768.0

        # Analyze at intervals
        turn_events: list[TurnEvent] = []
        window_samples = self.WINDOW_SECONDS * sample_rate
        total_samples = len(audio_float32)

        # "+ 1" keeps the final analysis point when the audio length is an
        # exact multiple of the interval.
        analysis_points = range(samples_per_interval, total_samples + 1, samples_per_interval)
        for chunk_no, end_sample in enumerate(analysis_points):
            # Get window of audio ending at current position
            start_sample = max(0, end_sample - window_samples)
            audio_window = audio_float32[start_sample:end_sample]

            current_time = end_sample / sample_rate
            logger.debug(f"[local-smart-turn] Analyzing chunk {chunk_no} at {current_time:.2f}s")

            result = self._predict_endpoint(audio_window)
            event = TurnEvent(
                timestamp=current_time,
                probability=result["probability"],
                prediction=result["prediction"],
                inference_time_ms=result["inference_time_ms"],
            )
            turn_events.append(event)

            if event.prediction == 1:
                # Attribute access avoids nesting quotes inside the f-string,
                # which is a syntax error before Python 3.12.
                logger.info(
                    f"[local-smart-turn] Turn complete at {current_time:.2f}s "
                    f"(prob={event.probability:.3f})"
                    f"(inf time ms={event.inference_time_ms})"
                )

        # Create result
        # Convert turn events to word-like format for compatibility
        words = [
            Word(
                word=f"[END_OF_TURN prob={event.probability:.2f}]",
                # Nominal 100 ms span, clamped so it never starts before 0.
                start=max(0.0, event.timestamp - 0.1),
                end=event.timestamp,
                confidence=event.probability,
                speaker=None,
                speaker_confidence=None,
            )
            for event in turn_events
            if event.prediction == 1
        ]

        # Count completed turns (one placeholder word per completed turn)
        completed_turns = len(words)

        params = {
            "sample_rate": sample_rate,
            "analysis_interval_ms": analysis_interval_ms,
            "window_seconds": self.WINDOW_SECONDS,
        }

        return TranscriptionResult(
            provider=self.name,
            transcript=f"[Turn detection only - {completed_turns} turns detected]",
            words=words,
            speakers=[],  # Not applicable
            duration=duration,
            raw_response={
                "turn_events": [
                    {
                        "timestamp": e.timestamp,
                        "probability": e.probability,
                        "prediction": e.prediction,
                        "inference_time_ms": e.inference_time_ms,
                    }
                    for e in turn_events
                ],
                "completed_turns": completed_turns,
                "total_analyses": len(turn_events),
                "avg_inference_time_ms": (
                    sum(e.inference_time_ms for e in turn_events) / len(turn_events)
                    if turn_events
                    else 0
                ),
            },
            params=params,
        )
|
||||
|
|
@ -1,38 +1,41 @@
|
|||
"""Speechmatics STT provider."""
|
||||
"""Speechmatics STT provider with WebSocket streaming."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from ..audio_streamer import AudioConfig, AudioStreamer
|
||||
from .base import STTProvider, TranscriptionResult, Word
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ImportError:
|
||||
raise ImportError("websockets required: pip install websockets")
|
||||
|
||||
|
||||
class SpeechmaticsProvider(STTProvider):
|
||||
"""Speechmatics Speech-to-Text provider.
|
||||
"""Speechmatics Speech-to-Text provider with WebSocket streaming.
|
||||
|
||||
API Docs: https://docs.speechmatics.com/
|
||||
|
||||
Supports:
|
||||
- Speaker diarization via `diarization: "speaker"` config
|
||||
- Speaker sensitivity tuning
|
||||
- Real-time streaming via WebSocket
|
||||
"""
|
||||
|
||||
# EU and US endpoints available
|
||||
API_URL = "https://asr.api.speechmatics.com/v2/jobs"
|
||||
|
||||
def __init__(self, api_key: str | None = None, region: str = "eu1"):
|
||||
def __init__(self, api_key: str | None = None, region: str = "eu2"):
|
||||
self.api_key = api_key or os.getenv("SPEECHMATICS_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Speechmatics API key required. Set SPEECHMATICS_API_KEY env var or pass api_key."
|
||||
)
|
||||
# Set region-specific endpoint
|
||||
if region == "eu1":
|
||||
self.api_url = "https://eu1.asr.api.speechmatics.com/v2/jobs"
|
||||
else:
|
||||
self.api_url = "https://asr.api.speechmatics.com/v2/jobs"
|
||||
self.ws_url = f"wss://{region}.rt.speechmatics.com/v2"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -45,27 +48,32 @@ class SpeechmaticsProvider(STTProvider):
|
|||
keyterms: list[str] | None = None,
|
||||
language: str = "en",
|
||||
operating_point: str = "enhanced",
|
||||
sample_rate: int = 8000,
|
||||
speaker_sensitivity: float | None = None,
|
||||
max_speakers: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> TranscriptionResult:
|
||||
"""Transcribe audio using Speechmatics API.
|
||||
"""Transcribe audio using Speechmatics WebSocket streaming.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
diarize: Enable speaker diarization
|
||||
keyterms: Not directly supported by Speechmatics (ignored)
|
||||
keyterms: Additional vocabulary (limited support)
|
||||
language: Language code
|
||||
operating_point: "standard" or "enhanced"
|
||||
sample_rate: Audio sample rate for streaming
|
||||
speaker_sensitivity: 0.0-1.0, higher = more speakers detected
|
||||
max_speakers: Maximum number of speakers to detect
|
||||
**kwargs: Additional config parameters
|
||||
|
||||
Returns:
|
||||
TranscriptionResult with transcript and speaker info
|
||||
"""
|
||||
# Build transcription config
|
||||
# Build transcription config for StartRecognition message
|
||||
transcription_config: dict[str, Any] = {
|
||||
"language": language,
|
||||
"operating_point": operating_point,
|
||||
"enable_partials": False,
|
||||
}
|
||||
|
||||
if diarize:
|
||||
|
|
@ -74,13 +82,20 @@ class SpeechmaticsProvider(STTProvider):
|
|||
transcription_config["speaker_diarization_config"] = {
|
||||
"speaker_sensitivity": speaker_sensitivity
|
||||
}
|
||||
if max_speakers is not None:
|
||||
if "speaker_diarization_config" not in transcription_config:
|
||||
transcription_config["speaker_diarization_config"] = {}
|
||||
transcription_config["speaker_diarization_config"]["max_speakers"] = max_speakers
|
||||
|
||||
# Add any extra config
|
||||
transcription_config.update(kwargs)
|
||||
# Add additional vocabulary if provided
|
||||
if keyterms:
|
||||
transcription_config["additional_vocab"] = [{"content": term} for term in keyterms]
|
||||
|
||||
config = {
|
||||
"type": "transcription",
|
||||
"transcription_config": transcription_config,
|
||||
# Audio format config
|
||||
audio_format = {
|
||||
"type": "raw",
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate": sample_rate,
|
||||
}
|
||||
|
||||
# Store params for result
|
||||
|
|
@ -88,79 +103,104 @@ class SpeechmaticsProvider(STTProvider):
|
|||
"diarize": diarize,
|
||||
"language": language,
|
||||
"operating_point": operating_point,
|
||||
"sample_rate": sample_rate,
|
||||
"speaker_sensitivity": speaker_sensitivity,
|
||||
"max_speakers": max_speakers,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
# Setup audio streamer
|
||||
audio_config = AudioConfig(sample_rate=sample_rate)
|
||||
streamer = AudioStreamer(audio_config)
|
||||
|
||||
# Create job with multipart form
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
# Submit job
|
||||
with open(audio_path, "rb") as f:
|
||||
files = {
|
||||
"data_file": (audio_path.name, f, "audio/mpeg"),
|
||||
"config": (None, str(config).replace("'", '"'), "application/json"),
|
||||
}
|
||||
response = await client.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
files=files,
|
||||
)
|
||||
response.raise_for_status()
|
||||
job_data = response.json()
|
||||
# Collect results
|
||||
all_results: list[dict[str, Any]] = []
|
||||
recognition_started = asyncio.Event()
|
||||
transcription_complete = asyncio.Event()
|
||||
|
||||
job_id = job_data.get("id")
|
||||
if not job_id:
|
||||
raise ValueError(f"No job ID in response: {job_data}")
|
||||
async with websocket_connect(
|
||||
self.ws_url,
|
||||
additional_headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
) as ws:
|
||||
# Send StartRecognition message
|
||||
start_msg = {
|
||||
"message": "StartRecognition",
|
||||
"transcription_config": transcription_config,
|
||||
"audio_format": audio_format,
|
||||
}
|
||||
await ws.send(json.dumps(start_msg))
|
||||
|
||||
# Poll for completion
|
||||
result_data = await self._wait_for_job(client, job_id, headers)
|
||||
async def send_audio():
|
||||
"""Send audio chunks after recognition starts."""
|
||||
await recognition_started.wait()
|
||||
|
||||
return self._parse_response(result_data, params)
|
||||
chunk_no = 0
|
||||
async for chunk in streamer.stream_file(audio_path):
|
||||
logger.debug(f"[speechmatics] Sent audio chunk {chunk_no}")
|
||||
await ws.send(chunk)
|
||||
chunk_no += 1
|
||||
|
||||
async def _wait_for_job(
|
||||
self, client: httpx.AsyncClient, job_id: str, headers: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
"""Poll job status until complete."""
|
||||
import asyncio
|
||||
# Signal end of audio with last sequence number
|
||||
logger.debug(f"[speechmatics] Sending EndOfStream after {chunk_no} chunks")
|
||||
await ws.send(json.dumps({"message": "EndOfStream", "last_seq_no": chunk_no}))
|
||||
|
||||
job_url = f"{self.api_url}/{job_id}"
|
||||
transcript_url = f"{job_url}/transcript?format=json-v2"
|
||||
async def receive_messages():
|
||||
"""Receive and process messages."""
|
||||
nonlocal all_results
|
||||
|
||||
max_attempts = 120 # 10 minutes with 5s intervals
|
||||
for _ in range(max_attempts):
|
||||
# Check job status
|
||||
status_response = await client.get(job_url, headers=headers)
|
||||
status_response.raise_for_status()
|
||||
status_data = status_response.json()
|
||||
async for message in ws:
|
||||
if isinstance(message, str):
|
||||
data = json.loads(message)
|
||||
msg_type = data.get("message")
|
||||
logger.debug(f"[speechmatics] Received {msg_type}: {data}")
|
||||
|
||||
job_status = status_data.get("job", {}).get("status")
|
||||
if msg_type == "RecognitionStarted":
|
||||
logger.info("[speechmatics] Connected")
|
||||
recognition_started.set()
|
||||
|
||||
if job_status == "done":
|
||||
# Get transcript
|
||||
transcript_response = await client.get(transcript_url, headers=headers)
|
||||
transcript_response.raise_for_status()
|
||||
return transcript_response.json()
|
||||
elif job_status == "rejected":
|
||||
raise ValueError(f"Job rejected: {status_data}")
|
||||
elif job_status == "deleted":
|
||||
raise ValueError(f"Job deleted: {status_data}")
|
||||
elif msg_type == "AddTranscript":
|
||||
# Final transcript segment
|
||||
results = data.get("results", [])
|
||||
all_results.extend(results)
|
||||
|
||||
await asyncio.sleep(5)
|
||||
elif msg_type == "EndOfTranscript":
|
||||
transcription_complete.set()
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Job {job_id} did not complete in time")
|
||||
elif msg_type == "Error":
|
||||
raise Exception(f"Speechmatics error: {data}")
|
||||
|
||||
def _parse_response(
|
||||
self, data: dict[str, Any], params: dict[str, Any]
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(f"[speechmatics] Warning: {data.get('reason')}")
|
||||
|
||||
# Run send and receive concurrently
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
receive_task = asyncio.create_task(receive_messages())
|
||||
|
||||
# Wait for completion
|
||||
await send_task
|
||||
try:
|
||||
await asyncio.wait_for(transcription_complete.wait(), timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
receive_task.cancel()
|
||||
try:
|
||||
await receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return self._parse_results(all_results, params)
|
||||
|
||||
def _parse_results(
|
||||
self,
|
||||
results: list[dict[str, Any]],
|
||||
params: dict[str, Any],
|
||||
) -> TranscriptionResult:
|
||||
"""Parse Speechmatics API response."""
|
||||
results = data.get("results", [])
|
||||
|
||||
"""Parse Speechmatics results."""
|
||||
words = []
|
||||
speakers_set: set[str] = set()
|
||||
transcript_parts = []
|
||||
duration = 0.0
|
||||
|
||||
for item in results:
|
||||
item_type = item.get("type")
|
||||
|
|
@ -176,27 +216,25 @@ class SpeechmaticsProvider(STTProvider):
|
|||
if speaker:
|
||||
speakers_set.add(speaker)
|
||||
|
||||
end_time = item.get("end_time", 0.0)
|
||||
duration = max(duration, end_time)
|
||||
|
||||
if item_type == "word":
|
||||
words.append(
|
||||
Word(
|
||||
word=content,
|
||||
start=item.get("start_time", 0.0),
|
||||
end=item.get("end_time", 0.0),
|
||||
end=end_time,
|
||||
confidence=alt.get("confidence", 0.0),
|
||||
speaker=speaker,
|
||||
speaker_confidence=None, # Not provided by Speechmatics
|
||||
speaker_confidence=None,
|
||||
)
|
||||
)
|
||||
transcript_parts.append(content)
|
||||
elif item_type == "punctuation":
|
||||
# Append punctuation to last word in transcript
|
||||
if transcript_parts:
|
||||
transcript_parts[-1] += content
|
||||
|
||||
# Get metadata
|
||||
metadata = data.get("metadata", {})
|
||||
duration = metadata.get("duration", 0.0)
|
||||
|
||||
transcript = " ".join(transcript_parts)
|
||||
|
||||
return TranscriptionResult(
|
||||
|
|
@ -205,6 +243,6 @@ class SpeechmaticsProvider(STTProvider):
|
|||
words=words,
|
||||
speakers=sorted(speakers_set),
|
||||
duration=duration,
|
||||
raw_response=data,
|
||||
raw_response={"results": results},
|
||||
params=params,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue