mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add Rime TTS
This commit is contained in:
parent
78e4abc796
commit
e255b33813
9 changed files with 79 additions and 8 deletions
|
|
@ -308,7 +308,7 @@ async def reactivate_api_key(
|
|||
|
||||
|
||||
# Voice Configuration Endpoints
|
||||
TTSProvider = Literal["elevenlabs", "deepgram", "sarvam", "cartesia", "dograh"]
|
||||
TTSProvider = Literal["elevenlabs", "deepgram", "sarvam", "cartesia", "dograh", "rime"]
|
||||
|
||||
|
||||
class VoiceInfo(BaseModel):
|
||||
|
|
@ -329,12 +329,16 @@ class VoicesResponse(BaseModel):
|
|||
@router.get("/configurations/voices/{provider}")
|
||||
async def get_voices(
|
||||
provider: TTSProvider,
|
||||
model: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> VoicesResponse:
|
||||
"""Get available voices for a TTS provider."""
|
||||
try:
|
||||
result = await mps_service_key_client.get_voices(
|
||||
provider=provider,
|
||||
model=model,
|
||||
language=language,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key,
|
||||
ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key,
|
||||
ServiceProviders.GLADIA.value: self._check_gladia_api_key,
|
||||
ServiceProviders.RIME.value: self._check_rime_api_key,
|
||||
}
|
||||
|
||||
async def validate(
|
||||
|
|
@ -225,3 +226,6 @@ class UserConfigurationValidator:
|
|||
|
||||
def _check_gladia_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
def _check_rime_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class ServiceProviders(str, Enum):
|
|||
SPEACHES = "speaches"
|
||||
ASSEMBLYAI = "assemblyai"
|
||||
GLADIA = "gladia"
|
||||
RIME = "rime"
|
||||
OPENAI_REALTIME = "openai_realtime"
|
||||
GOOGLE_REALTIME = "google_realtime"
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.SPEACHES,
|
||||
ServiceProviders.ASSEMBLYAI,
|
||||
ServiceProviders.GLADIA,
|
||||
ServiceProviders.RIME,
|
||||
ServiceProviders.OPENAI_REALTIME,
|
||||
ServiceProviders.GOOGLE_REALTIME,
|
||||
# ServiceProviders.SARVAM,
|
||||
|
|
@ -588,6 +590,25 @@ class CambTTSConfiguration(BaseTTSConfiguration):
|
|||
language: str = Field(default="en-us", description="BCP-47 language code")
|
||||
|
||||
|
||||
RIME_TTS_MODELS = ["arcana", "mistv3", "mistv2", "mist"]
|
||||
|
||||
|
||||
@register_tts
|
||||
class RimeTTSConfiguration(BaseTTSConfiguration):
|
||||
provider: Literal[ServiceProviders.RIME] = ServiceProviders.RIME
|
||||
model: str = Field(
|
||||
default="arcana",
|
||||
json_schema_extra={"examples": RIME_TTS_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="celeste",
|
||||
description="Rime voice ID",
|
||||
)
|
||||
speed: float = Field(
|
||||
default=1.0, ge=0.5, le=2.0, description="Speech speed multiplier"
|
||||
)
|
||||
|
||||
|
||||
SPEACHES_TTS_MODELS = ["hexgrad/Kokoro-82M"]
|
||||
|
||||
|
||||
|
|
@ -625,6 +646,7 @@ TTSConfig = Annotated[
|
|||
DograhTTSService,
|
||||
SarvamTTSConfiguration,
|
||||
CambTTSConfiguration,
|
||||
RimeTTSConfiguration,
|
||||
SpeachesTTSConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
|
|
|
|||
|
|
@ -442,6 +442,8 @@ class MPSServiceKeyClient:
|
|||
async def get_voices(
|
||||
self,
|
||||
provider: str,
|
||||
model: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
organization_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
@ -449,7 +451,9 @@ class MPSServiceKeyClient:
|
|||
Get available voices for a TTS provider from MPS.
|
||||
|
||||
Args:
|
||||
provider: TTS provider name (elevenlabs, deepgram, sarvam, cartesia)
|
||||
provider: TTS provider name (elevenlabs, deepgram, sarvam, cartesia, rime)
|
||||
model: Optional model ID to filter voices (e.g., "arcana", "mistv2")
|
||||
language: Optional language code to filter voices (e.g., "eng", "en")
|
||||
organization_id: Organization ID (for authenticated mode)
|
||||
created_by: User provider ID (for OSS mode)
|
||||
|
||||
|
|
@ -460,9 +464,15 @@ class MPSServiceKeyClient:
|
|||
HTTPException: If the API call fails
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
params = {}
|
||||
if model:
|
||||
params["model"] = model
|
||||
if language:
|
||||
params["language"] = language
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/voice-proxy/{provider}/voices",
|
||||
headers=self._get_headers(organization_id, created_by),
|
||||
params=params,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from pipecat.services.openai.stt import (
|
|||
)
|
||||
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
|
||||
|
|
@ -323,6 +324,21 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.RIME.value:
|
||||
speed = getattr(user_config.tts, "speed", None)
|
||||
settings_kwargs = {
|
||||
"voice": user_config.tts.voice,
|
||||
"model": user_config.tts.model,
|
||||
}
|
||||
if speed and speed != 1.0:
|
||||
settings_kwargs["speedAlpha"] = speed
|
||||
return RimeTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
settings=RimeTTSSettings(**settings_kwargs),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
|
||||
# Map Sarvam language code to pipecat Language enum for TTS
|
||||
language_mapping = {
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@ def test_create_speaches_stt_service_uses_http_base_url():
|
|||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.SpeachesSTTService") as mock_service:
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SpeachesSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
|
|
|
|||
|
|
@ -2943,9 +2943,12 @@ export type GetVoicesApiV1UserConfigurationsVoicesProviderGetData = {
|
|||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path: {
|
||||
provider: 'elevenlabs' | 'deepgram' | 'sarvam' | 'cartesia' | 'dograh';
|
||||
provider: 'elevenlabs' | 'deepgram' | 'sarvam' | 'cartesia' | 'dograh' | 'rime';
|
||||
};
|
||||
query?: {
|
||||
model?: string | null;
|
||||
language?: string | null;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/user/configurations/voices/{provider}';
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -502,6 +502,7 @@ export default function ServiceConfiguration() {
|
|||
onChange={(voiceId) => {
|
||||
setValue(`${service}_${field}`, voiceId, { shouldDirty: true });
|
||||
}}
|
||||
model={watch("tts_model") as string || undefined}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,13 +13,15 @@ import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover
|
|||
import { cn } from "@/lib/utils";
|
||||
|
||||
// Providers that have MPS voice endpoints
|
||||
type TTSProviderWithVoices = "elevenlabs" | "deepgram" | "sarvam" | "cartesia" | "dograh";
|
||||
const MPS_VOICE_PROVIDERS: TTSProviderWithVoices[] = ["elevenlabs", "deepgram", "sarvam", "cartesia", "dograh"];
|
||||
type TTSProviderWithVoices = "elevenlabs" | "deepgram" | "sarvam" | "cartesia" | "dograh" | "rime";
|
||||
const MPS_VOICE_PROVIDERS: TTSProviderWithVoices[] = ["elevenlabs", "deepgram", "sarvam", "cartesia", "dograh", "rime"];
|
||||
|
||||
interface VoiceSelectorProps {
|
||||
provider: string;
|
||||
value: string;
|
||||
onChange: (voiceId: string) => void;
|
||||
model?: string;
|
||||
language?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
|
|
@ -27,6 +29,8 @@ export const VoiceSelector: React.FC<VoiceSelectorProps> = ({
|
|||
provider,
|
||||
value,
|
||||
onChange,
|
||||
model,
|
||||
language,
|
||||
className,
|
||||
}) => {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
|
@ -52,6 +56,7 @@ export const VoiceSelector: React.FC<VoiceSelectorProps> = ({
|
|||
sarvam: "sarvam",
|
||||
cartesia: "cartesia",
|
||||
dograh: "dograh",
|
||||
rime: "rime",
|
||||
};
|
||||
return providerMap[providerName.toLowerCase()] || null;
|
||||
}, []);
|
||||
|
|
@ -67,8 +72,12 @@ export const VoiceSelector: React.FC<VoiceSelectorProps> = ({
|
|||
setError(null);
|
||||
|
||||
try {
|
||||
const query: { model?: string; language?: string } = {};
|
||||
if (model) query.model = model;
|
||||
if (language) query.language = language;
|
||||
const response = await getVoicesApiV1UserConfigurationsVoicesProviderGet({
|
||||
path: { provider: providerKey },
|
||||
query: Object.keys(query).length > 0 ? query : undefined,
|
||||
});
|
||||
|
||||
if (response.data?.voices) {
|
||||
|
|
@ -81,7 +90,7 @@ export const VoiceSelector: React.FC<VoiceSelectorProps> = ({
|
|||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [provider, getProviderKey]);
|
||||
}, [provider, model, language, getProviderKey]);
|
||||
|
||||
useEffect(() => {
|
||||
if (provider) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue