mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
feat: add cartesia ink 2 in STT models
This commit is contained in:
parent
557de72b9c
commit
327ec561d5
10 changed files with 309 additions and 17 deletions
|
|
@ -9,6 +9,12 @@ from .azure import (
|
|||
AZURE_SPEECH_TTS_LANGUAGES,
|
||||
AZURE_SPEECH_TTS_VOICES,
|
||||
)
|
||||
from .cartesia import (
|
||||
CARTESIA_INK_2_STT_LANGUAGES,
|
||||
CARTESIA_INK_WHISPER_STT_LANGUAGES,
|
||||
CARTESIA_STT_LANGUAGES,
|
||||
CARTESIA_STT_MODELS,
|
||||
)
|
||||
from .deepgram import (
|
||||
DEEPGRAM_FLUX_MODELS,
|
||||
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS,
|
||||
|
|
@ -59,6 +65,10 @@ __all__ = [
|
|||
"AZURE_SPEECH_STT_LANGUAGES",
|
||||
"AZURE_SPEECH_TTS_LANGUAGES",
|
||||
"AZURE_SPEECH_TTS_VOICES",
|
||||
"CARTESIA_INK_2_STT_LANGUAGES",
|
||||
"CARTESIA_INK_WHISPER_STT_LANGUAGES",
|
||||
"CARTESIA_STT_LANGUAGES",
|
||||
"CARTESIA_STT_MODELS",
|
||||
"DEEPGRAM_FLUX_MODELS",
|
||||
"DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGES",
|
||||
"DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS",
|
||||
|
|
|
|||
105
api/services/configuration/options/cartesia.py
Normal file
105
api/services/configuration/options/cartesia.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
# Option tables for the Cartesia speech-to-text provider: the selectable model
# identifiers and the language codes offered for each model.

# Selectable Cartesia STT model identifiers.
CARTESIA_STT_MODELS = ["ink-2", "ink-whisper"]

# Languages offered for the "ink-2" model (English only).
CARTESIA_INK_2_STT_LANGUAGES = ("en",)

# Languages offered for the "ink-whisper" model. Most entries are two-letter
# codes, but a few are not plain ISO 639-1 ("haw", "yue", and "jw").
# NOTE(review): the codes are kept exactly as listed — confirm against the
# provider's accepted values before normalising any of them.
CARTESIA_INK_WHISPER_STT_LANGUAGES = (
    "en",
    "zh",
    "de",
    "es",
    "ru",
    "ko",
    "fr",
    "ja",
    "pt",
    "tr",
    "pl",
    "ca",
    "nl",
    "ar",
    "sv",
    "it",
    "id",
    "hi",
    "fi",
    "vi",
    "he",
    "uk",
    "el",
    "ms",
    "cs",
    "ro",
    "da",
    "hu",
    "ta",
    "no",
    "th",
    "ur",
    "hr",
    "bg",
    "lt",
    "la",
    "mi",
    "ml",
    "cy",
    "sk",
    "te",
    "fa",
    "lv",
    "bn",
    "sr",
    "az",
    "sl",
    "kn",
    "et",
    "mk",
    "br",
    "eu",
    "is",
    "hy",
    "ne",
    "mn",
    "bs",
    "kk",
    "sq",
    "sw",
    "gl",
    "mr",
    "pa",
    "si",
    "km",
    "sn",
    "yo",
    "so",
    "af",
    "oc",
    "ka",
    "be",
    "tg",
    "sd",
    "gu",
    "am",
    "yi",
    "lo",
    "uz",
    "fo",
    "ht",
    "ps",
    "tk",
    "nn",
    "mt",
    "sa",
    "lb",
    "my",
    "bo",
    "tl",
    "mg",
    "as",
    "tt",
    "haw",
    "ln",
    "ha",
    "ba",
    "jw",
    "su",
    "yue",
)

# Provider-wide language list: the very same tuple object as the ink-whisper
# list, which is a superset of the ink-2 list.
CARTESIA_STT_LANGUAGES = CARTESIA_INK_WHISPER_STT_LANGUAGES
|
||||
|
|
@ -14,6 +14,10 @@ from api.services.configuration.options import (
|
|||
AZURE_SPEECH_STT_LANGUAGES,
|
||||
AZURE_SPEECH_TTS_LANGUAGES,
|
||||
AZURE_SPEECH_TTS_VOICES,
|
||||
CARTESIA_INK_2_STT_LANGUAGES,
|
||||
CARTESIA_INK_WHISPER_STT_LANGUAGES,
|
||||
CARTESIA_STT_LANGUAGES,
|
||||
CARTESIA_STT_MODELS,
|
||||
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS,
|
||||
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGES,
|
||||
DEEPGRAM_LANGUAGES,
|
||||
|
|
@ -1323,9 +1327,6 @@ class DeepgramSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
CARTESIA_STT_MODELS = ["ink-whisper"]
|
||||
|
||||
|
||||
@register_stt
|
||||
class CartesiaSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
|
||||
|
|
@ -1335,6 +1336,17 @@ class CartesiaSTTConfiguration(BaseSTTConfiguration):
|
|||
description="Cartesia STT model.",
|
||||
json_schema_extra={"examples": CARTESIA_STT_MODELS},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en",
|
||||
description="ISO 639-1 language code. ink-2 currently supports English only.",
|
||||
json_schema_extra={
|
||||
"examples": CARTESIA_STT_LANGUAGES,
|
||||
"model_options": {
|
||||
"ink-2": CARTESIA_INK_2_STT_LANGUAGES,
|
||||
"ink-whisper": CARTESIA_INK_WHISPER_STT_LANGUAGES,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
OPENAI_STT_MODELS = ["gpt-4o-transcribe"]
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from api.services.pipecat.service_factory import (
|
|||
create_realtime_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
stt_uses_flux_turns,
|
||||
stt_uses_external_turns,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import (
|
||||
ensure_tracing,
|
||||
|
|
@ -93,6 +93,19 @@ from pipecat.utils.run_context import set_current_org_id, set_current_run_id
|
|||
# Setup tracing if enabled
|
||||
ensure_tracing()
|
||||
|
||||
# Fallback stop timeout when turn detection is not driven by the STT service.
# NOTE(review): presumably seconds — confirm against the aggregator params.
DEFAULT_USER_TURN_STOP_TIMEOUT = 5.0
# Fallback stop timeout when the STT service emits its own turn boundaries.
EXTERNAL_TURN_USER_STOP_TIMEOUT = 30.0


def _resolve_user_turn_stop_timeout(
    run_configs: dict, *, uses_external_turns: bool
) -> float:
    """Pick the user-turn stop timeout for a pipeline run.

    Args:
        run_configs: Per-run configuration mapping. A non-``None``
            ``"user_turn_stop_timeout"`` entry takes precedence and is
            coerced with ``float()`` (so numeric strings are accepted).
        uses_external_turns: Whether the STT service supplies its own turn
            boundaries; selects which fallback constant applies.

    Returns:
        The explicit override if one is set, otherwise
        ``EXTERNAL_TURN_USER_STOP_TIMEOUT`` or
        ``DEFAULT_USER_TURN_STOP_TIMEOUT``.

    Raises:
        ValueError: If the override is a string that is not a valid float.
        TypeError: If the override is of a type ``float()`` cannot convert.
    """
    # A key that is present but null means "not configured": fall through to
    # the defaults instead of failing on float(None).
    override = run_configs.get("user_turn_stop_timeout")
    if override is not None:
        return float(override)
    if uses_external_turns:
        return EXTERNAL_TURN_USER_STOP_TIMEOUT
    return DEFAULT_USER_TURN_STOP_TIMEOUT
|
||||
|
||||
|
||||
def _create_realtime_user_turn_config(provider: str):
|
||||
"""Return user turn strategies and optional local VAD for realtime providers."""
|
||||
|
|
@ -620,16 +633,18 @@ async def _run_pipeline(
|
|||
|
||||
# Configure turn strategies based on STT provider, model, and workflow configuration
|
||||
if is_realtime:
|
||||
uses_external_turns = False
|
||||
# Realtime services still need user-turn tracking even when the model
|
||||
# itself owns speech generation and interruption behavior.
|
||||
user_turn_strategies, user_vad_analyzer = _create_realtime_user_turn_config(
|
||||
user_config.realtime.provider
|
||||
)
|
||||
else:
|
||||
# Deepgram Flux and supported Dograh managed Flux languages emit their
|
||||
# own turn boundaries, so the aggregator follows those external signals.
|
||||
# Other models use configurable turn detection.
|
||||
if stt_uses_flux_turns(user_config):
|
||||
# Some STT services emit their own turn boundaries, so the aggregator
|
||||
# follows those external signals. Other models use configurable turn
|
||||
# detection.
|
||||
uses_external_turns = stt_uses_external_turns(user_config)
|
||||
if uses_external_turns:
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
|
|
@ -661,9 +676,15 @@ async def _run_pipeline(
|
|||
stop=[SpeechTimeoutUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
user_turn_stop_timeout = _resolve_user_turn_stop_timeout(
|
||||
run_configs,
|
||||
uses_external_turns=uses_external_turns,
|
||||
)
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
user_mute_strategies=user_mute_strategies,
|
||||
user_turn_stop_timeout=user_turn_stop_timeout,
|
||||
user_idle_timeout=max_user_idle_timeout,
|
||||
vad_analyzer=user_vad_analyzer,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,12 +21,13 @@ from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
|||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.azure.stt import AzureSTTService, AzureSTTSettings
|
||||
from pipecat.services.azure.tts import AzureTTSService, AzureTTSSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService, CartesiaSTTSettings
|
||||
from pipecat.services.cartesia.tts import (
|
||||
CartesiaTTSService,
|
||||
CartesiaTTSSettings,
|
||||
GenerationConfig,
|
||||
)
|
||||
from pipecat.services.cartesia.turns.stt import CartesiaTurnsSTTService
|
||||
from pipecat.services.deepgram.flux.stt import (
|
||||
DeepgramFluxSTTService,
|
||||
DeepgramFluxSTTSettings,
|
||||
|
|
@ -106,11 +107,13 @@ def dograh_stt_uses_flux_language(language: str | None) -> bool:
|
|||
return language in DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS
|
||||
|
||||
|
||||
def stt_uses_external_turns(user_config) -> bool:
    """Return whether the configured STT service supplies its own turn boundaries.

    The decision depends on the STT provider:

    * Deepgram: true when the selected model is one of ``DEEPGRAM_FLUX_MODELS``.
    * Dograh: true when the configured language (``None`` if the attribute is
      absent) passes ``dograh_stt_uses_flux_language``.
    * Cartesia: true only for the ``"ink-2"`` model.

    Any other provider returns ``False``.

    Args:
        user_config: Object exposing an ``stt`` section with ``provider`` and
            ``model`` attributes and, optionally, ``language``.
    """
    stt = user_config.stt
    provider = stt.provider
    if provider == ServiceProviders.DEEPGRAM.value:
        return stt.model in DEEPGRAM_FLUX_MODELS
    if provider == ServiceProviders.DOGRAH.value:
        # Not every STT configuration carries a language field.
        return dograh_stt_uses_flux_language(getattr(stt, "language", None))
    if provider == ServiceProviders.CARTESIA.value:
        return stt.model == "ink-2"
    return False
|
||||
|
||||
|
||||
|
|
@ -214,8 +217,20 @@ def create_stt_service(
|
|||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
if user_config.stt.model == "ink-2":
|
||||
return CartesiaTurnsSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
should_interrupt=False, # Let UserAggregator emit interruption frames.
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
|
||||
language = getattr(user_config.stt, "language", None) or "en"
|
||||
return CartesiaSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=CartesiaSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.DOGRAH.value:
|
||||
|
|
|
|||
94
api/tests/test_cartesia_stt_service_factory.py
Normal file
94
api/tests/test_cartesia_stt_service_factory.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.options import (
|
||||
CARTESIA_INK_2_STT_LANGUAGES,
|
||||
CARTESIA_INK_WHISPER_STT_LANGUAGES,
|
||||
CARTESIA_STT_MODELS,
|
||||
)
|
||||
from api.services.configuration.registry import (
|
||||
CartesiaSTTConfiguration,
|
||||
ServiceProviders,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_stt_service,
|
||||
stt_uses_external_turns,
|
||||
)
|
||||
|
||||
|
||||
def _audio_config() -> AudioConfig:
    """Build an ``AudioConfig`` whose input and output transport rates are both 16000."""
    sample_rate = 16000
    return AudioConfig(
        transport_in_sample_rate=sample_rate,
        transport_out_sample_rate=sample_rate,
    )
|
||||
|
||||
|
||||
def _cartesia_config(model: str, language: str = "en") -> SimpleNamespace:
    """Return a minimal stand-in user config whose ``stt`` section targets Cartesia.

    Args:
        model: Cartesia STT model identifier to place on the config.
        language: Language code to place on the config.
    """
    stt_section = SimpleNamespace(
        provider=ServiceProviders.CARTESIA.value,
        api_key="test-key",
        model=model,
        language=language,
    )
    return SimpleNamespace(stt=stt_section)
|
||||
|
||||
|
||||
def test_cartesia_stt_configuration_exposes_ink_2_and_ink_whisper_languages():
    """Check Cartesia STT defaults, option constants, and per-model schema languages."""
    schema = CartesiaSTTConfiguration.model_json_schema()
    language_schema = schema["properties"]["language"]
    config = CartesiaSTTConfiguration(api_key="test-key")

    # Defaults on a freshly built configuration.
    assert config.provider == ServiceProviders.CARTESIA
    assert config.model == "ink-whisper"
    assert config.language == "en"

    # Option constants.
    assert CARTESIA_STT_MODELS == ["ink-2", "ink-whisper"]
    assert CARTESIA_INK_2_STT_LANGUAGES == ("en",)
    assert "es" in CARTESIA_INK_WHISPER_STT_LANGUAGES

    # Per-model language options as exposed through the JSON schema.
    model_options = language_schema["model_options"]
    assert model_options["ink-2"] == ["en"]
    assert "es" in model_options["ink-whisper"]
|
||||
|
||||
|
||||
def test_cartesia_ink_2_uses_external_turns_and_turns_service():
    """ink-2 must report external turns and be built with the turns STT service."""
    user_config = _cartesia_config("ink-2")

    assert stt_uses_external_turns(user_config)

    with (
        patch(
            "api.services.pipecat.service_factory.CartesiaTurnsSTTService"
        ) as turns_cls,
        patch("api.services.pipecat.service_factory.CartesiaSTTService") as plain_cls,
    ):
        create_stt_service(user_config, _audio_config())

    # Only the turns-aware service may be constructed for ink-2.
    turns_cls.assert_called_once()
    plain_cls.assert_not_called()

    call_kwargs = turns_cls.call_args.kwargs
    assert call_kwargs["api_key"] == "test-key"
    assert call_kwargs["sample_rate"] == 16000
    assert call_kwargs["should_interrupt"] is False
|
||||
|
||||
|
||||
def test_cartesia_ink_whisper_uses_manual_stt_service_with_model_and_language():
    """ink-whisper must use the plain STT service with model and language in settings."""
    user_config = _cartesia_config("ink-whisper", language="es")

    assert not stt_uses_external_turns(user_config)

    with (
        patch(
            "api.services.pipecat.service_factory.CartesiaTurnsSTTService"
        ) as turns_cls,
        patch("api.services.pipecat.service_factory.CartesiaSTTService") as plain_cls,
    ):
        create_stt_service(user_config, _audio_config())

    # Only the plain service may be constructed for ink-whisper.
    turns_cls.assert_not_called()
    plain_cls.assert_called_once()

    call_kwargs = plain_cls.call_args.kwargs
    assert call_kwargs["api_key"] == "test-key"
    assert call_kwargs["sample_rate"] == 16000

    settings = call_kwargs["settings"]
    assert settings.model == "ink-whisper"
    assert settings.language == "es"
|
||||
|
|
@ -9,7 +9,7 @@ from api.services.pipecat.audio_config import AudioConfig
|
|||
from api.services.pipecat.service_factory import (
|
||||
create_stt_service,
|
||||
dograh_stt_uses_flux_language,
|
||||
stt_uses_flux_turns,
|
||||
stt_uses_external_turns,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -38,10 +38,10 @@ def test_dograh_flux_language_predicate_matches_multilingual_support():
|
|||
assert not dograh_stt_uses_flux_language("ar")
|
||||
|
||||
|
||||
def test_stt_uses_external_turns_only_for_dograh_flux_supported_languages():
    """Dograh STT reports external turns for "multi" and "es" but not for "ar"."""
    assert stt_uses_external_turns(_dograh_config("multi"))
    assert stt_uses_external_turns(_dograh_config("es"))
    assert not stt_uses_external_turns(_dograh_config("ar"))
|
||||
|
||||
|
||||
def test_create_dograh_multi_uses_flux_service_without_language_hint():
|
||||
|
|
|
|||
|
|
@ -11,7 +11,12 @@ from pipecat.turns.user_stop import (
|
|||
)
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.run_pipeline import _create_realtime_user_turn_config
|
||||
from api.services.pipecat.run_pipeline import (
|
||||
DEFAULT_USER_TURN_STOP_TIMEOUT,
|
||||
EXTERNAL_TURN_USER_STOP_TIMEOUT,
|
||||
_create_realtime_user_turn_config,
|
||||
_resolve_user_turn_stop_timeout,
|
||||
)
|
||||
|
||||
|
||||
def test_gemini_realtime_uses_local_vad_without_local_interruptions():
|
||||
|
|
@ -72,3 +77,27 @@ def test_unknown_realtime_providers_keep_local_vad():
|
|||
assert isinstance(strategies.start[0], VADUserTurnStartStrategy)
|
||||
assert len(strategies.stop) == 1
|
||||
assert isinstance(strategies.stop[0], SpeechTimeoutUserTurnStopStrategy)
|
||||
|
||||
|
||||
def test_external_turn_stt_uses_longer_stop_timeout():
    """With no override, external-turn STT resolves to the external-turn timeout."""
    resolved = _resolve_user_turn_stop_timeout({}, uses_external_turns=True)
    assert resolved == EXTERNAL_TURN_USER_STOP_TIMEOUT
|
||||
|
||||
|
||||
def test_standard_stt_keeps_default_stop_timeout():
    """With no override, non-external-turn STT resolves to the default timeout."""
    resolved = _resolve_user_turn_stop_timeout({}, uses_external_turns=False)
    assert resolved == DEFAULT_USER_TURN_STOP_TIMEOUT
|
||||
|
||||
|
||||
def test_workflow_config_can_override_user_turn_stop_timeout():
    """An explicit run-config value wins over the fallback and is coerced to float."""
    run_configs = {"user_turn_stop_timeout": "12.5"}
    resolved = _resolve_user_turn_stop_timeout(
        run_configs,
        uses_external_turns=True,
    )
    assert resolved == 12.5
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -1042,6 +1042,12 @@ export type CartesiaSttConfiguration = {
|
|||
* Cartesia STT model.
|
||||
*/
|
||||
model?: string;
|
||||
/**
|
||||
* Language
|
||||
*
|
||||
* ISO 639-1 language code. ink-2 currently supports English only.
|
||||
*/
|
||||
language?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue