mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
feat: add custom sarvam tts voice (#449)
* feat: add custom sarvam tts voice * chore: refactor registry and add deepgram multi --------- Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
344c8220de
commit
951e73a645
9 changed files with 268 additions and 69 deletions
70
api/tests/test_deepgram_flux_service_factory.py
Normal file
70
api/tests/test_deepgram_flux_service_factory.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from pipecat.services.settings import NOT_GIVEN
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
ServiceProviders,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.service_factory import create_stt_service
|
||||
|
||||
|
||||
def test_deepgram_stt_schema_includes_flux_multilingual_language_options():
|
||||
language_schema = DeepgramSTTConfiguration.model_json_schema()["properties"][
|
||||
"language"
|
||||
]
|
||||
|
||||
assert "flux-general-multi" in language_schema["model_options"]
|
||||
assert "multi" in language_schema["model_options"]["flux-general-multi"]
|
||||
assert "es" in language_schema["model_options"]["flux-general-multi"]
|
||||
|
||||
|
||||
def test_create_deepgram_flux_multi_uses_flux_service_with_language_hint():
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.DEEPGRAM.value,
|
||||
api_key="test-key",
|
||||
model="flux-general-multi",
|
||||
language="es",
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.DeepgramFluxSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].model == "flux-general-multi"
|
||||
assert kwargs["settings"].language_hints == [Language.ES]
|
||||
|
||||
|
||||
def test_create_deepgram_flux_multi_omits_auto_detect_language_hint():
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.DEEPGRAM.value,
|
||||
api_key="test-key",
|
||||
model="flux-general-multi",
|
||||
language="multi",
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.DeepgramFluxSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].model == "flux-general-multi"
|
||||
assert kwargs["settings"].language_hints is NOT_GIVEN
|
||||
|
|
@ -126,6 +126,13 @@ class TestSarvamTTSServiceFactory:
|
|||
assert config.language == "hi-IN"
|
||||
assert config.speed == 1.0
|
||||
|
||||
def test_sarvam_tts_voice_schema_allows_custom_model_specific_options(self):
|
||||
voice_schema = SarvamTTSConfiguration.model_json_schema()["properties"]["voice"]
|
||||
|
||||
assert voice_schema["allow_custom_input"] is True
|
||||
assert "bulbul:v2" in voice_schema["model_options"]
|
||||
assert "bulbul:v3" in voice_schema["model_options"]
|
||||
|
||||
def test_create_sarvam_tts_service_maps_speed_to_pace(self):
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
|
|
@ -152,3 +159,49 @@ class TestSarvamTTSServiceFactory:
|
|||
assert kwargs["settings"].voice == "anushka"
|
||||
assert kwargs["settings"].language == Language.HI
|
||||
assert kwargs["settings"].pace == 1.25
|
||||
|
||||
def test_create_sarvam_tts_service_normalizes_custom_voice_id(self):
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
api_key="test-key",
|
||||
model="bulbul:v2",
|
||||
voice=" Rehan ",
|
||||
language="hi-IN",
|
||||
speed=1.0,
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamTTSService"
|
||||
) as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].voice == "rehan"
|
||||
|
||||
def test_create_sarvam_tts_service_defaults_blank_voice_id(self):
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
api_key="test-key",
|
||||
model="bulbul:v2",
|
||||
voice=" ",
|
||||
language="hi-IN",
|
||||
speed=1.0,
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamTTSService"
|
||||
) as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].voice == "anushka"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue