mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-07-01 08:59:46 +02:00
feat: add custom sarvam tts voice (#449)
* feat: add custom sarvam tts voice * chore: refactor registry and add deepgram multi --------- Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
344c8220de
commit
951e73a645
9 changed files with 268 additions and 69 deletions
|
|
@ -9,7 +9,13 @@ from .azure import (
|
||||||
AZURE_SPEECH_TTS_LANGUAGES,
|
AZURE_SPEECH_TTS_LANGUAGES,
|
||||||
AZURE_SPEECH_TTS_VOICES,
|
AZURE_SPEECH_TTS_VOICES,
|
||||||
)
|
)
|
||||||
from .deepgram import DEEPGRAM_LANGUAGES, DEEPGRAM_STT_MODELS
|
from .deepgram import (
|
||||||
|
DEEPGRAM_FLUX_MODELS,
|
||||||
|
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS,
|
||||||
|
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGES,
|
||||||
|
DEEPGRAM_LANGUAGES,
|
||||||
|
DEEPGRAM_STT_MODELS,
|
||||||
|
)
|
||||||
from .gladia import GLADIA_STT_LANGUAGES, GLADIA_STT_MODELS
|
from .gladia import GLADIA_STT_LANGUAGES, GLADIA_STT_MODELS
|
||||||
from .google import (
|
from .google import (
|
||||||
GOOGLE_MODELS,
|
GOOGLE_MODELS,
|
||||||
|
|
@ -35,6 +41,11 @@ from .sarvam import (
|
||||||
SARVAM_V2_VOICES,
|
SARVAM_V2_VOICES,
|
||||||
SARVAM_V3_VOICES,
|
SARVAM_V3_VOICES,
|
||||||
)
|
)
|
||||||
|
from .smallest import (
|
||||||
|
SMALLEST_TTS_LANGUAGES,
|
||||||
|
SMALLEST_TTS_MODELS,
|
||||||
|
SMALLEST_TTS_VOICES,
|
||||||
|
)
|
||||||
from .speechmatics import SPEECHMATICS_STT_LANGUAGES
|
from .speechmatics import SPEECHMATICS_STT_LANGUAGES
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -47,6 +58,9 @@ __all__ = [
|
||||||
"AZURE_SPEECH_STT_LANGUAGES",
|
"AZURE_SPEECH_STT_LANGUAGES",
|
||||||
"AZURE_SPEECH_TTS_LANGUAGES",
|
"AZURE_SPEECH_TTS_LANGUAGES",
|
||||||
"AZURE_SPEECH_TTS_VOICES",
|
"AZURE_SPEECH_TTS_VOICES",
|
||||||
|
"DEEPGRAM_FLUX_MODELS",
|
||||||
|
"DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGES",
|
||||||
|
"DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS",
|
||||||
"DEEPGRAM_LANGUAGES",
|
"DEEPGRAM_LANGUAGES",
|
||||||
"DEEPGRAM_STT_MODELS",
|
"DEEPGRAM_STT_MODELS",
|
||||||
"GLADIA_STT_LANGUAGES",
|
"GLADIA_STT_LANGUAGES",
|
||||||
|
|
@ -71,5 +85,8 @@ __all__ = [
|
||||||
"SARVAM_TTS_MODELS",
|
"SARVAM_TTS_MODELS",
|
||||||
"SARVAM_V2_VOICES",
|
"SARVAM_V2_VOICES",
|
||||||
"SARVAM_V3_VOICES",
|
"SARVAM_V3_VOICES",
|
||||||
|
"SMALLEST_TTS_LANGUAGES",
|
||||||
|
"SMALLEST_TTS_MODELS",
|
||||||
|
"SMALLEST_TTS_VOICES",
|
||||||
"SPEECHMATICS_STT_LANGUAGES",
|
"SPEECHMATICS_STT_LANGUAGES",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,21 @@
|
||||||
DEEPGRAM_STT_MODELS = ("nova-3-general", "flux-general-en", "flux-general-multi")
|
DEEPGRAM_FLUX_MODELS = ("flux-general-en", "flux-general-multi")
|
||||||
|
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGES = (
|
||||||
|
"de",
|
||||||
|
"en",
|
||||||
|
"es",
|
||||||
|
"fr",
|
||||||
|
"hi",
|
||||||
|
"it",
|
||||||
|
"ja",
|
||||||
|
"nl",
|
||||||
|
"pt",
|
||||||
|
"ru",
|
||||||
|
)
|
||||||
|
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS = (
|
||||||
|
"multi",
|
||||||
|
*DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGES,
|
||||||
|
)
|
||||||
|
DEEPGRAM_STT_MODELS = ("nova-3-general", *DEEPGRAM_FLUX_MODELS)
|
||||||
DEEPGRAM_LANGUAGES = (
|
DEEPGRAM_LANGUAGES = (
|
||||||
"multi",
|
"multi",
|
||||||
"ar",
|
"ar",
|
||||||
|
|
|
||||||
36
api/services/configuration/options/smallest.py
Normal file
36
api/services/configuration/options/smallest.py
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
SMALLEST_TTS_MODELS = ("lightning_v3.1", "lightning_v3.1_pro")
|
||||||
|
SMALLEST_TTS_VOICES = (
|
||||||
|
"sophia",
|
||||||
|
"avery",
|
||||||
|
"liam",
|
||||||
|
"lucas",
|
||||||
|
"olivia",
|
||||||
|
"ryan",
|
||||||
|
"freya",
|
||||||
|
"william",
|
||||||
|
"devansh",
|
||||||
|
"arjun",
|
||||||
|
"niharika",
|
||||||
|
"maya",
|
||||||
|
"dhruv",
|
||||||
|
"mia",
|
||||||
|
"maithili",
|
||||||
|
)
|
||||||
|
SMALLEST_TTS_LANGUAGES = (
|
||||||
|
"en",
|
||||||
|
"hi",
|
||||||
|
"fr",
|
||||||
|
"de",
|
||||||
|
"es",
|
||||||
|
"it",
|
||||||
|
"nl",
|
||||||
|
"pl",
|
||||||
|
"ru",
|
||||||
|
"ar",
|
||||||
|
"bn",
|
||||||
|
"gu",
|
||||||
|
"he",
|
||||||
|
"kn",
|
||||||
|
"mr",
|
||||||
|
"ta",
|
||||||
|
)
|
||||||
|
|
@ -14,6 +14,7 @@ from api.services.configuration.options import (
|
||||||
AZURE_SPEECH_STT_LANGUAGES,
|
AZURE_SPEECH_STT_LANGUAGES,
|
||||||
AZURE_SPEECH_TTS_LANGUAGES,
|
AZURE_SPEECH_TTS_LANGUAGES,
|
||||||
AZURE_SPEECH_TTS_VOICES,
|
AZURE_SPEECH_TTS_VOICES,
|
||||||
|
DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS,
|
||||||
DEEPGRAM_LANGUAGES,
|
DEEPGRAM_LANGUAGES,
|
||||||
DEEPGRAM_STT_MODELS,
|
DEEPGRAM_STT_MODELS,
|
||||||
GLADIA_STT_LANGUAGES,
|
GLADIA_STT_LANGUAGES,
|
||||||
|
|
@ -38,6 +39,9 @@ from api.services.configuration.options import (
|
||||||
SARVAM_TTS_MODELS,
|
SARVAM_TTS_MODELS,
|
||||||
SARVAM_V2_VOICES,
|
SARVAM_V2_VOICES,
|
||||||
SARVAM_V3_VOICES,
|
SARVAM_V3_VOICES,
|
||||||
|
SMALLEST_TTS_LANGUAGES,
|
||||||
|
SMALLEST_TTS_MODELS,
|
||||||
|
SMALLEST_TTS_VOICES,
|
||||||
SPEECHMATICS_STT_LANGUAGES,
|
SPEECHMATICS_STT_LANGUAGES,
|
||||||
)
|
)
|
||||||
from api.services.configuration.options.google import GOOGLE_VERTEX_MODELS
|
from api.services.configuration.options.google import GOOGLE_VERTEX_MODELS
|
||||||
|
|
@ -987,9 +991,10 @@ class SarvamTTSConfiguration(BaseTTSConfiguration):
|
||||||
)
|
)
|
||||||
voice: str = Field(
|
voice: str = Field(
|
||||||
default="anushka",
|
default="anushka",
|
||||||
description="Sarvam voice name; must match the selected model's voice list.",
|
description="Sarvam voice name or custom voice ID.",
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"examples": SARVAM_V2_VOICES,
|
"examples": SARVAM_V2_VOICES,
|
||||||
|
"allow_custom_input": True,
|
||||||
"model_options": {
|
"model_options": {
|
||||||
"bulbul:v2": SARVAM_V2_VOICES,
|
"bulbul:v2": SARVAM_V2_VOICES,
|
||||||
"bulbul:v3": SARVAM_V3_VOICES,
|
"bulbul:v3": SARVAM_V3_VOICES,
|
||||||
|
|
@ -1172,43 +1177,6 @@ SMALLEST_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||||
provider_docs_url="https://smallest.ai/docs",
|
provider_docs_url="https://smallest.ai/docs",
|
||||||
)
|
)
|
||||||
|
|
||||||
SMALLEST_TTS_MODELS = ["lightning_v3.1", "lightning_v3.1_pro"]
|
|
||||||
SMALLEST_TTS_VOICES = [
|
|
||||||
"sophia",
|
|
||||||
"avery",
|
|
||||||
"liam",
|
|
||||||
"lucas",
|
|
||||||
"olivia",
|
|
||||||
"ryan",
|
|
||||||
"freya",
|
|
||||||
"william",
|
|
||||||
"devansh",
|
|
||||||
"arjun",
|
|
||||||
"niharika",
|
|
||||||
"maya",
|
|
||||||
"dhruv",
|
|
||||||
"mia",
|
|
||||||
"maithili",
|
|
||||||
]
|
|
||||||
SMALLEST_TTS_LANGUAGES = [
|
|
||||||
"en",
|
|
||||||
"hi",
|
|
||||||
"fr",
|
|
||||||
"de",
|
|
||||||
"es",
|
|
||||||
"it",
|
|
||||||
"nl",
|
|
||||||
"pl",
|
|
||||||
"ru",
|
|
||||||
"ar",
|
|
||||||
"bn",
|
|
||||||
"gu",
|
|
||||||
"he",
|
|
||||||
"kn",
|
|
||||||
"mr",
|
|
||||||
"ta",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@register_tts
|
@register_tts
|
||||||
class SmallestAITTSConfiguration(BaseTTSConfiguration):
|
class SmallestAITTSConfiguration(BaseTTSConfiguration):
|
||||||
|
|
@ -1273,12 +1241,16 @@ class DeepgramSTTConfiguration(BaseSTTConfiguration):
|
||||||
)
|
)
|
||||||
language: str = Field(
|
language: str = Field(
|
||||||
default="multi",
|
default="multi",
|
||||||
description="Language code; 'multi' enables auto-detect (Nova-3 only).",
|
description=(
|
||||||
|
"Language code. 'multi' enables Nova-3 auto-detect and omits "
|
||||||
|
"language hints for Flux multilingual auto-detect."
|
||||||
|
),
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"examples": DEEPGRAM_LANGUAGES,
|
"examples": DEEPGRAM_LANGUAGES,
|
||||||
"model_options": {
|
"model_options": {
|
||||||
"nova-3-general": DEEPGRAM_LANGUAGES,
|
"nova-3-general": DEEPGRAM_LANGUAGES,
|
||||||
"flux-general-en": ("en",),
|
"flux-general-en": ("en",),
|
||||||
|
"flux-general-multi": DEEPGRAM_FLUX_MULTILINGUAL_LANGUAGE_OPTIONS,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from loguru import logger
|
||||||
|
|
||||||
from api.db import db_client
|
from api.db import db_client
|
||||||
from api.enums import WorkflowRunMode
|
from api.enums import WorkflowRunMode
|
||||||
|
from api.services.configuration.options import DEEPGRAM_FLUX_MODELS
|
||||||
from api.services.configuration.registry import ServiceProviders
|
from api.services.configuration.registry import ServiceProviders
|
||||||
from api.services.integrations import (
|
from api.services.integrations import (
|
||||||
IntegrationRuntimeContext,
|
IntegrationRuntimeContext,
|
||||||
|
|
@ -626,7 +627,7 @@ async def _run_pipeline(
|
||||||
# Other models use configurable turn detection strategy
|
# Other models use configurable turn detection strategy
|
||||||
is_deepgram_flux = (
|
is_deepgram_flux = (
|
||||||
user_config.stt.provider == ServiceProviders.DEEPGRAM.value
|
user_config.stt.provider == ServiceProviders.DEEPGRAM.value
|
||||||
and user_config.stt.model == "flux-general-en"
|
and user_config.stt.model in DEEPGRAM_FLUX_MODELS
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_deepgram_flux:
|
if is_deepgram_flux:
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from fastapi import HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from api.constants import MPS_API_URL
|
from api.constants import MPS_API_URL
|
||||||
|
from api.services.configuration.options import DEEPGRAM_FLUX_MODELS
|
||||||
from api.services.configuration.registry import ServiceProviders
|
from api.services.configuration.registry import ServiceProviders
|
||||||
from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService
|
from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService
|
||||||
from api.utils.url_security import validate_user_configured_service_url
|
from api.utils.url_security import validate_user_configured_service_url
|
||||||
|
|
@ -78,6 +79,20 @@ if TYPE_CHECKING:
|
||||||
from api.services.pipecat.audio_config import AudioConfig
|
from api.services.pipecat.audio_config import AudioConfig
|
||||||
|
|
||||||
|
|
||||||
|
DEEPGRAM_FLUX_LANGUAGE_HINTS = {
|
||||||
|
"de": Language.DE,
|
||||||
|
"en": Language.EN,
|
||||||
|
"es": Language.ES,
|
||||||
|
"fr": Language.FR,
|
||||||
|
"hi": Language.HI,
|
||||||
|
"it": Language.IT,
|
||||||
|
"ja": Language.JA,
|
||||||
|
"nl": Language.NL,
|
||||||
|
"pt": Language.PT,
|
||||||
|
"ru": Language.RU,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
validate_user_configured_service_url(
|
validate_user_configured_service_url(
|
||||||
|
|
@ -104,17 +119,23 @@ def create_stt_service(
|
||||||
f"Creating STT service: provider={user_config.stt.provider}, model={user_config.stt.model}"
|
f"Creating STT service: provider={user_config.stt.provider}, model={user_config.stt.model}"
|
||||||
)
|
)
|
||||||
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
|
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
|
||||||
# Check if using Flux model (English-only, no language selection)
|
if user_config.stt.model in DEEPGRAM_FLUX_MODELS:
|
||||||
if user_config.stt.model == "flux-general-en":
|
settings_kwargs = {
|
||||||
|
"model": user_config.stt.model,
|
||||||
|
"eot_timeout_ms": 3000,
|
||||||
|
"eot_threshold": 0.7,
|
||||||
|
"eager_eot_threshold": 0.5,
|
||||||
|
"keyterm": keyterms or [],
|
||||||
|
}
|
||||||
|
if user_config.stt.model == "flux-general-multi":
|
||||||
|
language = getattr(user_config.stt, "language", None)
|
||||||
|
language_hint = DEEPGRAM_FLUX_LANGUAGE_HINTS.get(language)
|
||||||
|
if language_hint:
|
||||||
|
settings_kwargs["language_hints"] = [language_hint]
|
||||||
|
|
||||||
return DeepgramFluxSTTService(
|
return DeepgramFluxSTTService(
|
||||||
api_key=user_config.stt.api_key,
|
api_key=user_config.stt.api_key,
|
||||||
settings=DeepgramFluxSTTSettings(
|
settings=DeepgramFluxSTTSettings(**settings_kwargs),
|
||||||
model=user_config.stt.model,
|
|
||||||
eot_timeout_ms=3000,
|
|
||||||
eot_threshold=0.7,
|
|
||||||
eager_eot_threshold=0.5,
|
|
||||||
keyterm=keyterms or [],
|
|
||||||
),
|
|
||||||
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
|
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
|
||||||
sample_rate=audio_config.transport_in_sample_rate,
|
sample_rate=audio_config.transport_in_sample_rate,
|
||||||
)
|
)
|
||||||
|
|
@ -534,7 +555,9 @@ def create_tts_service(
|
||||||
language = getattr(user_config.tts, "language", None)
|
language = getattr(user_config.tts, "language", None)
|
||||||
pipecat_language = language_mapping.get(language, Language.HI)
|
pipecat_language = language_mapping.get(language, Language.HI)
|
||||||
|
|
||||||
voice = getattr(user_config.tts, "voice", None) or "anushka"
|
voice = (
|
||||||
|
getattr(user_config.tts, "voice", None) or ""
|
||||||
|
).strip().lower() or "anushka"
|
||||||
speed = getattr(user_config.tts, "speed", None)
|
speed = getattr(user_config.tts, "speed", None)
|
||||||
settings_kwargs = {
|
settings_kwargs = {
|
||||||
"model": user_config.tts.model,
|
"model": user_config.tts.model,
|
||||||
|
|
|
||||||
70
api/tests/test_deepgram_flux_service_factory.py
Normal file
70
api/tests/test_deepgram_flux_service_factory.py
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from pipecat.services.settings import NOT_GIVEN
|
||||||
|
from pipecat.transcriptions.language import Language
|
||||||
|
|
||||||
|
from api.services.configuration.registry import (
|
||||||
|
DeepgramSTTConfiguration,
|
||||||
|
ServiceProviders,
|
||||||
|
)
|
||||||
|
from api.services.pipecat.audio_config import AudioConfig
|
||||||
|
from api.services.pipecat.service_factory import create_stt_service
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepgram_stt_schema_includes_flux_multilingual_language_options():
|
||||||
|
language_schema = DeepgramSTTConfiguration.model_json_schema()["properties"][
|
||||||
|
"language"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert "flux-general-multi" in language_schema["model_options"]
|
||||||
|
assert "multi" in language_schema["model_options"]["flux-general-multi"]
|
||||||
|
assert "es" in language_schema["model_options"]["flux-general-multi"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_deepgram_flux_multi_uses_flux_service_with_language_hint():
|
||||||
|
user_config = SimpleNamespace(
|
||||||
|
stt=SimpleNamespace(
|
||||||
|
provider=ServiceProviders.DEEPGRAM.value,
|
||||||
|
api_key="test-key",
|
||||||
|
model="flux-general-multi",
|
||||||
|
language="es",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
audio_config = AudioConfig(
|
||||||
|
transport_in_sample_rate=16000,
|
||||||
|
transport_out_sample_rate=16000,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"api.services.pipecat.service_factory.DeepgramFluxSTTService"
|
||||||
|
) as mock_service:
|
||||||
|
create_stt_service(user_config, audio_config)
|
||||||
|
|
||||||
|
kwargs = mock_service.call_args.kwargs
|
||||||
|
assert kwargs["settings"].model == "flux-general-multi"
|
||||||
|
assert kwargs["settings"].language_hints == [Language.ES]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_deepgram_flux_multi_omits_auto_detect_language_hint():
|
||||||
|
user_config = SimpleNamespace(
|
||||||
|
stt=SimpleNamespace(
|
||||||
|
provider=ServiceProviders.DEEPGRAM.value,
|
||||||
|
api_key="test-key",
|
||||||
|
model="flux-general-multi",
|
||||||
|
language="multi",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
audio_config = AudioConfig(
|
||||||
|
transport_in_sample_rate=16000,
|
||||||
|
transport_out_sample_rate=16000,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"api.services.pipecat.service_factory.DeepgramFluxSTTService"
|
||||||
|
) as mock_service:
|
||||||
|
create_stt_service(user_config, audio_config)
|
||||||
|
|
||||||
|
kwargs = mock_service.call_args.kwargs
|
||||||
|
assert kwargs["settings"].model == "flux-general-multi"
|
||||||
|
assert kwargs["settings"].language_hints is NOT_GIVEN
|
||||||
|
|
@ -126,6 +126,13 @@ class TestSarvamTTSServiceFactory:
|
||||||
assert config.language == "hi-IN"
|
assert config.language == "hi-IN"
|
||||||
assert config.speed == 1.0
|
assert config.speed == 1.0
|
||||||
|
|
||||||
|
def test_sarvam_tts_voice_schema_allows_custom_model_specific_options(self):
|
||||||
|
voice_schema = SarvamTTSConfiguration.model_json_schema()["properties"]["voice"]
|
||||||
|
|
||||||
|
assert voice_schema["allow_custom_input"] is True
|
||||||
|
assert "bulbul:v2" in voice_schema["model_options"]
|
||||||
|
assert "bulbul:v3" in voice_schema["model_options"]
|
||||||
|
|
||||||
def test_create_sarvam_tts_service_maps_speed_to_pace(self):
|
def test_create_sarvam_tts_service_maps_speed_to_pace(self):
|
||||||
user_config = SimpleNamespace(
|
user_config = SimpleNamespace(
|
||||||
tts=SimpleNamespace(
|
tts=SimpleNamespace(
|
||||||
|
|
@ -152,3 +159,49 @@ class TestSarvamTTSServiceFactory:
|
||||||
assert kwargs["settings"].voice == "anushka"
|
assert kwargs["settings"].voice == "anushka"
|
||||||
assert kwargs["settings"].language == Language.HI
|
assert kwargs["settings"].language == Language.HI
|
||||||
assert kwargs["settings"].pace == 1.25
|
assert kwargs["settings"].pace == 1.25
|
||||||
|
|
||||||
|
def test_create_sarvam_tts_service_normalizes_custom_voice_id(self):
|
||||||
|
user_config = SimpleNamespace(
|
||||||
|
tts=SimpleNamespace(
|
||||||
|
provider=ServiceProviders.SARVAM.value,
|
||||||
|
api_key="test-key",
|
||||||
|
model="bulbul:v2",
|
||||||
|
voice=" Rehan ",
|
||||||
|
language="hi-IN",
|
||||||
|
speed=1.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
audio_config = AudioConfig(
|
||||||
|
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"api.services.pipecat.service_factory.SarvamTTSService"
|
||||||
|
) as mock_service:
|
||||||
|
create_tts_service(user_config, audio_config)
|
||||||
|
|
||||||
|
kwargs = mock_service.call_args.kwargs
|
||||||
|
assert kwargs["settings"].voice == "rehan"
|
||||||
|
|
||||||
|
def test_create_sarvam_tts_service_defaults_blank_voice_id(self):
|
||||||
|
user_config = SimpleNamespace(
|
||||||
|
tts=SimpleNamespace(
|
||||||
|
provider=ServiceProviders.SARVAM.value,
|
||||||
|
api_key="test-key",
|
||||||
|
model="bulbul:v2",
|
||||||
|
voice=" ",
|
||||||
|
language="hi-IN",
|
||||||
|
speed=1.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
audio_config = AudioConfig(
|
||||||
|
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"api.services.pipecat.service_factory.SarvamTTSService"
|
||||||
|
) as mock_service:
|
||||||
|
create_tts_service(user_config, audio_config)
|
||||||
|
|
||||||
|
kwargs = mock_service.call_args.kwargs
|
||||||
|
assert kwargs["settings"].voice == "anushka"
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,19 @@ function getGlobalSummary(
|
||||||
return model ? `${providerLabel} / ${model}` : providerLabel || provider;
|
return model ? `${providerLabel} / ${model}` : providerLabel || provider;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getSchemaDropdownOptions(
|
||||||
|
schema: SchemaProperty | undefined,
|
||||||
|
modelValue?: string,
|
||||||
|
): string[] | undefined {
|
||||||
|
let dropdownOptions = schema?.enum || schema?.examples;
|
||||||
|
|
||||||
|
if (schema?.model_options && modelValue && schema.model_options[modelValue]) {
|
||||||
|
dropdownOptions = schema.model_options[modelValue];
|
||||||
|
}
|
||||||
|
|
||||||
|
return dropdownOptions;
|
||||||
|
}
|
||||||
|
|
||||||
export function ServiceConfigurationForm({
|
export function ServiceConfigurationForm({
|
||||||
mode,
|
mode,
|
||||||
currentOverrides,
|
currentOverrides,
|
||||||
|
|
@ -344,10 +357,12 @@ export function ServiceConfigurationForm({
|
||||||
? providerSchema.$defs[(schema as SchemaProperty).$ref!.split('/').pop() || '']
|
? providerSchema.$defs[(schema as SchemaProperty).$ref!.split('/').pop() || '']
|
||||||
: schema as SchemaProperty;
|
: schema as SchemaProperty;
|
||||||
|
|
||||||
if (!actualSchema?.allow_custom_input || !actualSchema?.examples) return;
|
if (!actualSchema?.allow_custom_input) return;
|
||||||
|
|
||||||
const savedValue = src?.[field] as string | undefined;
|
const savedValue = src?.[field] as string | undefined;
|
||||||
if (savedValue && !actualSchema.examples.includes(savedValue)) {
|
const modelValue = src?.model as string | undefined;
|
||||||
|
const dropdownOptions = getSchemaDropdownOptions(actualSchema, modelValue);
|
||||||
|
if (savedValue && dropdownOptions && !dropdownOptions.includes(savedValue)) {
|
||||||
detectedCustomInput[`${service}_${field}`] = true;
|
detectedCustomInput[`${service}_${field}`] = true;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
@ -381,10 +396,11 @@ export function ServiceConfigurationForm({
|
||||||
|
|
||||||
const validVoices = modelOptions[ttsModel as string];
|
const validVoices = modelOptions[ttsModel as string];
|
||||||
const currentVoice = getValues("tts_voice") as string;
|
const currentVoice = getValues("tts_voice") as string;
|
||||||
if (validVoices && currentVoice && !validVoices.includes(currentVoice)) {
|
const isCustomVoice = !!isCustomInput.tts_voice;
|
||||||
|
if (validVoices && currentVoice && !validVoices.includes(currentVoice) && !isCustomVoice) {
|
||||||
setValue("tts_voice", validVoices[0], { shouldDirty: true });
|
setValue("tts_voice", validVoices[0], { shouldDirty: true });
|
||||||
}
|
}
|
||||||
}, [ttsModel, serviceProviders.tts, setValue, getValues, schemas]);
|
}, [ttsModel, serviceProviders.tts, setValue, getValues, schemas, isCustomInput.tts_voice]);
|
||||||
|
|
||||||
// Reset language when STT model changes if the provider has model-dependent language options
|
// Reset language when STT model changes if the provider has model-dependent language options
|
||||||
const sttModel = watch("stt_model");
|
const sttModel = watch("stt_model");
|
||||||
|
|
@ -676,10 +692,13 @@ export function ServiceConfigurationForm({
|
||||||
const actualSchema = schema.$ref && providerSchema.$defs
|
const actualSchema = schema.$ref && providerSchema.$defs
|
||||||
? providerSchema.$defs[schema.$ref.split('/').pop() || '']
|
? providerSchema.$defs[schema.$ref.split('/').pop() || '']
|
||||||
: schema;
|
: schema;
|
||||||
|
const dropdownOptions = getSchemaDropdownOptions(
|
||||||
|
actualSchema,
|
||||||
|
watch(`${service}_model`) as string | undefined,
|
||||||
|
);
|
||||||
|
|
||||||
if (service === "tts" && field === "voice" && !actualSchema?.allow_custom_input) {
|
if (service === "tts" && field === "voice" && !actualSchema?.allow_custom_input) {
|
||||||
const hasVoiceOptions = actualSchema?.enum || actualSchema?.examples;
|
if (!dropdownOptions) {
|
||||||
if (!hasVoiceOptions) {
|
|
||||||
return (
|
return (
|
||||||
<VoiceSelector
|
<VoiceSelector
|
||||||
provider={serviceProviders.tts}
|
provider={serviceProviders.tts}
|
||||||
|
|
@ -693,10 +712,10 @@ export function ServiceConfigurationForm({
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (actualSchema?.allow_custom_input && actualSchema?.examples) {
|
if (actualSchema?.allow_custom_input && dropdownOptions && dropdownOptions.length > 0) {
|
||||||
const fieldKey = `${service}_${field}`;
|
const fieldKey = `${service}_${field}`;
|
||||||
const currentValue = watch(fieldKey) as string || "";
|
const currentValue = watch(fieldKey) as string || "";
|
||||||
const options = actualSchema.examples;
|
const options = dropdownOptions;
|
||||||
|
|
||||||
if (isCustomInput[fieldKey]) {
|
if (isCustomInput[fieldKey]) {
|
||||||
return (
|
return (
|
||||||
|
|
@ -764,15 +783,6 @@ export function ServiceConfigurationForm({
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let dropdownOptions = actualSchema?.enum || actualSchema?.examples;
|
|
||||||
|
|
||||||
if (actualSchema?.model_options) {
|
|
||||||
const modelValue = watch(`${service}_model`) as string;
|
|
||||||
if (modelValue && actualSchema.model_options[modelValue]) {
|
|
||||||
dropdownOptions = actualSchema.model_options[modelValue];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dropdownOptions && dropdownOptions.length > 0) {
|
if (dropdownOptions && dropdownOptions.length > 0) {
|
||||||
const getDisplayName = (value: string) => {
|
const getDisplayName = (value: string) => {
|
||||||
if (field === "language") {
|
if (field === "language") {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue