mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
fix: change type definition from enum to str for consistency
This commit is contained in:
parent
74b069354b
commit
e83f3a36d2
4 changed files with 147 additions and 150 deletions
|
|
@ -30,12 +30,9 @@ def create_stt_service(user_config):
|
|||
"""Create and return appropriate STT service based on user configuration"""
|
||||
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
|
||||
# Use language from user config, defaulting to "multi" for multilingual support
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
language_value = (
|
||||
language.value if hasattr(language, "value") else (language or "multi")
|
||||
)
|
||||
language = getattr(user_config.stt, "language", None) or "multi"
|
||||
live_options = LiveOptions(
|
||||
language=language_value, profanity_filter=False, endpointing=100
|
||||
language=language, profanity_filter=False, endpointing=100
|
||||
)
|
||||
return DeepgramSTTService(
|
||||
live_options=live_options,
|
||||
|
|
@ -45,7 +42,7 @@ def create_stt_service(user_config):
|
|||
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAISTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model.value,
|
||||
model=user_config.stt.model,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
|
|
@ -58,7 +55,7 @@ def create_stt_service(user_config):
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model.value,
|
||||
model=user_config.stt.model,
|
||||
audio_passthrough=False, # Disable passthrough since audio is buffered separately
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SARVAM.value:
|
||||
|
|
@ -78,12 +75,10 @@ def create_stt_service(user_config):
|
|||
"as-IN": Language.AS_IN,
|
||||
}
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
language_value = language.value if hasattr(language, "value") else language
|
||||
pipecat_language = language_mapping.get(language_value, Language.HI_IN)
|
||||
|
||||
pipecat_language = language_mapping.get(language, Language.HI_IN)
|
||||
return SarvamSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model.value,
|
||||
model=user_config.stt.model,
|
||||
params=SarvamSTTService.InputParams(language=pipecat_language),
|
||||
audio_passthrough=False,
|
||||
)
|
||||
|
|
@ -105,13 +100,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
voice=user_config.tts.voice.value,
|
||||
voice=user_config.tts.voice,
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
model=user_config.tts.model,
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
|
|
@ -120,12 +115,11 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
voice_id = user_config.tts.voice.split(" - ")[1]
|
||||
except IndexError:
|
||||
voice_id = user_config.tts.voice
|
||||
|
||||
return ElevenLabsTTSService(
|
||||
reconnect_on_error=False,
|
||||
api_key=user_config.tts.api_key,
|
||||
voice_id=voice_id,
|
||||
model=user_config.tts.model.value,
|
||||
model=user_config.tts.model,
|
||||
params=ElevenLabsTTSService.InputParams(
|
||||
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
|
||||
),
|
||||
|
|
@ -134,12 +128,11 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
# Convert HTTP URL to WebSocket URL for TTS
|
||||
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
|
||||
# Handle both enum and string values for model and voice
|
||||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
voice=user_config.tts.voice.value,
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
|
||||
|
|
@ -158,16 +151,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
"te-IN": Language.TE,
|
||||
}
|
||||
language = getattr(user_config.tts, "language", None)
|
||||
language_value = language.value if hasattr(language, "value") else language
|
||||
pipecat_language = language_mapping.get(language_value, Language.HI)
|
||||
|
||||
voice = getattr(user_config.tts, "voice", None)
|
||||
voice_value = voice.value if hasattr(voice, "value") else (voice or "anushka")
|
||||
pipecat_language = language_mapping.get(language, Language.HI)
|
||||
|
||||
voice = getattr(user_config.tts, "voice", None) or "anushka"
|
||||
return SarvamTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
voice_id=voice_value,
|
||||
model=user_config.tts.model,
|
||||
voice_id=voice,
|
||||
params=SarvamTTSService.InputParams(language=pipecat_language),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
)
|
||||
|
|
@ -179,17 +169,12 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
|
||||
def create_llm_service(user_config):
|
||||
"""Create and return appropriate LLM service based on user configuration"""
|
||||
# Handle both enum and string values for model
|
||||
model_value = (
|
||||
user_config.llm.model.value
|
||||
if hasattr(user_config.llm.model, "value")
|
||||
else user_config.llm.model
|
||||
)
|
||||
model = user_config.llm.model
|
||||
if user_config.llm.provider == ServiceProviders.OPENAI.value:
|
||||
if "gpt-5" in model_value:
|
||||
if "gpt-5" in model:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model_value,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(
|
||||
reasoning_effort="minimal", verbosity="low"
|
||||
),
|
||||
|
|
@ -197,16 +182,16 @@ def create_llm_service(user_config):
|
|||
else:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model_value,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GROQ.value:
|
||||
print(
|
||||
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {model_value}"
|
||||
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {model}"
|
||||
)
|
||||
return GroqLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model_value,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
|
||||
|
|
@ -214,21 +199,21 @@ def create_llm_service(user_config):
|
|||
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
|
||||
return GoogleLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model_value,
|
||||
model=model,
|
||||
params=GoogleLLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.AZURE.value:
|
||||
return AzureLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
endpoint=user_config.llm.endpoint,
|
||||
model=model_value, # Azure uses deployment name as model
|
||||
model=model, # Azure uses deployment name as model
|
||||
params=AzureLLMService.InputParams(temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
|
||||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model_value,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue