mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
feat: add cartesia ink 2 in STT models
This commit is contained in:
parent
557de72b9c
commit
327ec561d5
10 changed files with 309 additions and 17 deletions
|
|
@ -46,7 +46,7 @@ from api.services.pipecat.service_factory import (
|
|||
create_realtime_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
stt_uses_flux_turns,
|
||||
stt_uses_external_turns,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import (
|
||||
ensure_tracing,
|
||||
|
|
@ -93,6 +93,19 @@ from pipecat.utils.run_context import set_current_org_id, set_current_run_id
|
|||
# Setup tracing if enabled
|
||||
ensure_tracing()
|
||||
|
||||
# Fallback user-turn stop timeouts used when the run configuration does not
# supply one (presumably seconds — confirm against LLMUserAggregatorParams).
DEFAULT_USER_TURN_STOP_TIMEOUT = 5.0
# Longer fallback applied when the STT service reports turn boundaries itself.
EXTERNAL_TURN_USER_STOP_TIMEOUT = 30.0


def _resolve_user_turn_stop_timeout(
    run_configs: dict, *, uses_external_turns: bool
) -> float:
    """Pick the user-turn stop timeout for a pipeline run.

    Precedence:
      1. A non-``None`` ``"user_turn_stop_timeout"`` entry in ``run_configs``.
      2. ``EXTERNAL_TURN_USER_STOP_TIMEOUT`` when ``uses_external_turns`` is true.
      3. ``DEFAULT_USER_TURN_STOP_TIMEOUT`` otherwise.

    Args:
        run_configs: Per-run configuration mapping; only the
            ``"user_turn_stop_timeout"`` key is read.
        uses_external_turns: Whether the STT service supplies its own turn
            boundaries.

    Returns:
        The timeout as a ``float``.

    Raises:
        ValueError: If the configured override is a string that cannot be
            parsed as a float.
        TypeError: If the configured override is of a type ``float()`` rejects.
    """
    override = run_configs.get("user_turn_stop_timeout")
    # A key that is present but None (e.g. JSON null) means "not configured";
    # calling float(None) would raise TypeError instead of using the defaults.
    if override is not None:
        return float(override)
    if uses_external_turns:
        return EXTERNAL_TURN_USER_STOP_TIMEOUT
    return DEFAULT_USER_TURN_STOP_TIMEOUT
|
||||
|
||||
|
||||
def _create_realtime_user_turn_config(provider: str):
|
||||
"""Return user turn strategies and optional local VAD for realtime providers."""
|
||||
|
|
@ -620,16 +633,18 @@ async def _run_pipeline(
|
|||
|
||||
# Configure turn strategies based on STT provider, model, and workflow configuration
|
||||
if is_realtime:
|
||||
uses_external_turns = False
|
||||
# Realtime services still need user-turn tracking even when the model
|
||||
# itself owns speech generation and interruption behavior.
|
||||
user_turn_strategies, user_vad_analyzer = _create_realtime_user_turn_config(
|
||||
user_config.realtime.provider
|
||||
)
|
||||
else:
|
||||
# Deepgram Flux and supported Dograh managed Flux languages emit their
|
||||
# own turn boundaries, so the aggregator follows those external signals.
|
||||
# Other models use configurable turn detection.
|
||||
if stt_uses_flux_turns(user_config):
|
||||
# Some STT services emit their own turn boundaries, so the aggregator
|
||||
# follows those external signals. Other models use configurable turn
|
||||
# detection.
|
||||
uses_external_turns = stt_uses_external_turns(user_config)
|
||||
if uses_external_turns:
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
|
|
@ -661,9 +676,15 @@ async def _run_pipeline(
|
|||
stop=[SpeechTimeoutUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
user_turn_stop_timeout = _resolve_user_turn_stop_timeout(
|
||||
run_configs,
|
||||
uses_external_turns=uses_external_turns,
|
||||
)
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
user_mute_strategies=user_mute_strategies,
|
||||
user_turn_stop_timeout=user_turn_stop_timeout,
|
||||
user_idle_timeout=max_user_idle_timeout,
|
||||
vad_analyzer=user_vad_analyzer,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,12 +21,13 @@ from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
|||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.azure.stt import AzureSTTService, AzureSTTSettings
|
||||
from pipecat.services.azure.tts import AzureTTSService, AzureTTSSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService, CartesiaSTTSettings
|
||||
from pipecat.services.cartesia.tts import (
|
||||
CartesiaTTSService,
|
||||
CartesiaTTSSettings,
|
||||
GenerationConfig,
|
||||
)
|
||||
from pipecat.services.cartesia.turns.stt import CartesiaTurnsSTTService
|
||||
from pipecat.services.deepgram.flux.stt import (
|
||||
DeepgramFluxSTTService,
|
||||
DeepgramFluxSTTSettings,
|
||||
|
|
@ -106,11 +107,13 @@ def dograh_stt_uses_flux_language(language: str | None) -> bool:
|
|||
return language in DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS
|
||||
|
||||
|
||||
def stt_uses_flux_turns(user_config) -> bool:
|
||||
def stt_uses_external_turns(user_config) -> bool:
    """Report whether the configured STT setup is one that signals its own turns.

    The decision depends on ``user_config.stt.provider``:

    * Deepgram: true when the model is one of ``DEEPGRAM_FLUX_MODELS``.
    * Dograh: true when ``dograh_stt_uses_flux_language`` accepts the
      configured language (``None`` is passed if the attribute is absent).
    * Cartesia: true only for the ``"ink-2"`` model.
    * Any other provider: false.
    """
    stt_config = user_config.stt
    provider = stt_config.provider

    if provider == ServiceProviders.DEEPGRAM.value:
        return stt_config.model in DEEPGRAM_FLUX_MODELS

    if provider == ServiceProviders.DOGRAH.value:
        # The language attribute is optional on the STT config.
        configured_language = getattr(stt_config, "language", None)
        return dograh_stt_uses_flux_language(configured_language)

    if provider == ServiceProviders.CARTESIA.value:
        return stt_config.model == "ink-2"

    return False
|
||||
|
||||
|
||||
|
|
@ -214,8 +217,20 @@ def create_stt_service(
|
|||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
if user_config.stt.model == "ink-2":
|
||||
return CartesiaTurnsSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
should_interrupt=False, # Let UserAggregator emit interruption frames.
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
|
||||
language = getattr(user_config.stt, "language", None) or "en"
|
||||
return CartesiaSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=CartesiaSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.DOGRAH.value:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue