feat: add Rime TTS

This commit is contained in:
Abhishek Kumar 2026-04-07 14:05:47 +05:30
parent 78e4abc796
commit e255b33813
9 changed files with 79 additions and 8 deletions

View file

@ -308,7 +308,7 @@ async def reactivate_api_key(
# Voice Configuration Endpoints
TTSProvider = Literal["elevenlabs", "deepgram", "sarvam", "cartesia", "dograh"]
TTSProvider = Literal["elevenlabs", "deepgram", "sarvam", "cartesia", "dograh", "rime"]
class VoiceInfo(BaseModel):
@ -329,12 +329,16 @@ class VoicesResponse(BaseModel):
@router.get("/configurations/voices/{provider}")
async def get_voices(
provider: TTSProvider,
model: Optional[str] = None,
language: Optional[str] = None,
user: UserModel = Depends(get_user),
) -> VoicesResponse:
"""Get available voices for a TTS provider."""
try:
result = await mps_service_key_client.get_voices(
provider=provider,
model=model,
language=language,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)

View file

@ -51,6 +51,7 @@ class UserConfigurationValidator:
ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key,
ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key,
ServiceProviders.GLADIA.value: self._check_gladia_api_key,
ServiceProviders.RIME.value: self._check_rime_api_key,
}
async def validate(
@ -225,3 +226,6 @@ class UserConfigurationValidator:
def _check_gladia_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_rime_api_key(self, model: str, api_key: str) -> bool:
return True

View file

@ -31,6 +31,7 @@ class ServiceProviders(str, Enum):
SPEACHES = "speaches"
ASSEMBLYAI = "assemblyai"
GLADIA = "gladia"
RIME = "rime"
OPENAI_REALTIME = "openai_realtime"
GOOGLE_REALTIME = "google_realtime"
@ -49,6 +50,7 @@ class BaseServiceConfiguration(BaseModel):
ServiceProviders.SPEACHES,
ServiceProviders.ASSEMBLYAI,
ServiceProviders.GLADIA,
ServiceProviders.RIME,
ServiceProviders.OPENAI_REALTIME,
ServiceProviders.GOOGLE_REALTIME,
# ServiceProviders.SARVAM,
@ -588,6 +590,25 @@ class CambTTSConfiguration(BaseTTSConfiguration):
language: str = Field(default="en-us", description="BCP-47 language code")
RIME_TTS_MODELS = ["arcana", "mistv3", "mistv2", "mist"]
@register_tts
class RimeTTSConfiguration(BaseTTSConfiguration):
provider: Literal[ServiceProviders.RIME] = ServiceProviders.RIME
model: str = Field(
default="arcana",
json_schema_extra={"examples": RIME_TTS_MODELS, "allow_custom_input": True},
)
voice: str = Field(
default="celeste",
description="Rime voice ID",
)
speed: float = Field(
default=1.0, ge=0.5, le=2.0, description="Speech speed multiplier"
)
SPEACHES_TTS_MODELS = ["hexgrad/Kokoro-82M"]
@ -625,6 +646,7 @@ TTSConfig = Annotated[
DograhTTSService,
SarvamTTSConfiguration,
CambTTSConfiguration,
RimeTTSConfiguration,
SpeachesTTSConfiguration,
],
Field(discriminator="provider"),

View file

@ -442,6 +442,8 @@ class MPSServiceKeyClient:
async def get_voices(
self,
provider: str,
model: Optional[str] = None,
language: Optional[str] = None,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> dict:
@ -449,7 +451,9 @@ class MPSServiceKeyClient:
Get available voices for a TTS provider from MPS.
Args:
provider: TTS provider name (elevenlabs, deepgram, sarvam, cartesia)
provider: TTS provider name (elevenlabs, deepgram, sarvam, cartesia, rime)
model: Optional model ID to filter voices (e.g., "arcana", "mistv2")
language: Optional language code to filter voices (e.g., "eng", "en")
organization_id: Organization ID (for authenticated mode)
created_by: User provider ID (for OSS mode)
@ -460,9 +464,15 @@ class MPSServiceKeyClient:
HTTPException: If the API call fails
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
params = {}
if model:
params["model"] = model
if language:
params["language"] = language
response = await client.get(
f"{self.base_url}/api/v1/voice-proxy/{provider}/voices",
headers=self._get_headers(organization_id, created_by),
params=params,
)
if response.status_code == 200:

View file

@ -35,6 +35,7 @@ from pipecat.services.openai.stt import (
)
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
@ -323,6 +324,21 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
skip_aggregator_types=["recording_router"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.RIME.value:
speed = getattr(user_config.tts, "speed", None)
settings_kwargs = {
"voice": user_config.tts.voice,
"model": user_config.tts.model,
}
if speed and speed != 1.0:
settings_kwargs["speedAlpha"] = speed
return RimeTTSService(
api_key=user_config.tts.api_key,
settings=RimeTTSSettings(**settings_kwargs),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
# Map Sarvam language code to pipecat Language enum for TTS
language_mapping = {

View file

@ -17,7 +17,9 @@ def test_create_speaches_stt_service_uses_http_base_url():
)
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
with patch("api.services.pipecat.service_factory.SpeachesSTTService") as mock_service:
with patch(
"api.services.pipecat.service_factory.SpeachesSTTService"
) as mock_service:
create_stt_service(user_config, audio_config)
assert mock_service.call_count == 1