fix: change type definition from enum to str for consistency

This commit is contained in:
Abhishek Kumar 2025-12-26 16:00:02 +05:30
parent 74b069354b
commit e83f3a36d2
4 changed files with 147 additions and 150 deletions

View file

@@ -30,12 +30,9 @@ def create_stt_service(user_config):
"""Create and return appropriate STT service based on user configuration"""
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
# Use language from user config, defaulting to "multi" for multilingual support
language = getattr(user_config.stt, "language", None)
language_value = (
language.value if hasattr(language, "value") else (language or "multi")
)
language = getattr(user_config.stt, "language", None) or "multi"
live_options = LiveOptions(
language=language_value, profanity_filter=False, endpointing=100
language=language, profanity_filter=False, endpointing=100
)
return DeepgramSTTService(
live_options=live_options,
@@ -45,7 +42,7 @@ def create_stt_service(user_config):
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
return OpenAISTTService(
api_key=user_config.stt.api_key,
model=user_config.stt.model.value,
model=user_config.stt.model,
audio_passthrough=False, # Disable passthrough since audio is buffered separately
)
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
@@ -58,7 +55,7 @@ def create_stt_service(user_config):
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
model=user_config.stt.model.value,
model=user_config.stt.model,
audio_passthrough=False, # Disable passthrough since audio is buffered separately
)
elif user_config.stt.provider == ServiceProviders.SARVAM.value:
@@ -78,12 +75,10 @@ def create_stt_service(user_config):
"as-IN": Language.AS_IN,
}
language = getattr(user_config.stt, "language", None)
language_value = language.value if hasattr(language, "value") else language
pipecat_language = language_mapping.get(language_value, Language.HI_IN)
pipecat_language = language_mapping.get(language, Language.HI_IN)
return SarvamSTTService(
api_key=user_config.stt.api_key,
model=user_config.stt.model.value,
model=user_config.stt.model,
params=SarvamSTTService.InputParams(language=pipecat_language),
audio_passthrough=False,
)
@@ -105,13 +100,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
return DeepgramTTSService(
api_key=user_config.tts.api_key,
voice=user_config.tts.voice.value,
voice=user_config.tts.voice,
text_filters=[xml_function_tag_filter],
)
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
return OpenAITTSService(
api_key=user_config.tts.api_key,
model=user_config.tts.model.value,
model=user_config.tts.model,
text_filters=[xml_function_tag_filter],
)
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
@@ -120,12 +115,11 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
voice_id = user_config.tts.voice.split(" - ")[1]
except IndexError:
voice_id = user_config.tts.voice
return ElevenLabsTTSService(
reconnect_on_error=False,
api_key=user_config.tts.api_key,
voice_id=voice_id,
model=user_config.tts.model.value,
model=user_config.tts.model,
params=ElevenLabsTTSService.InputParams(
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
),
@@ -134,12 +128,11 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
# Convert HTTP URL to WebSocket URL for TTS
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
# Handle both enum and string values for model and voice
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
model=user_config.tts.model.value,
voice=user_config.tts.voice.value,
model=user_config.tts.model,
voice=user_config.tts.voice,
text_filters=[xml_function_tag_filter],
)
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
@@ -158,16 +151,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
"te-IN": Language.TE,
}
language = getattr(user_config.tts, "language", None)
language_value = language.value if hasattr(language, "value") else language
pipecat_language = language_mapping.get(language_value, Language.HI)
voice = getattr(user_config.tts, "voice", None)
voice_value = voice.value if hasattr(voice, "value") else (voice or "anushka")
pipecat_language = language_mapping.get(language, Language.HI)
voice = getattr(user_config.tts, "voice", None) or "anushka"
return SarvamTTSService(
api_key=user_config.tts.api_key,
model=user_config.tts.model.value,
voice_id=voice_value,
model=user_config.tts.model,
voice_id=voice,
params=SarvamTTSService.InputParams(language=pipecat_language),
text_filters=[xml_function_tag_filter],
)
@@ -179,17 +169,12 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_llm_service(user_config):
"""Create and return appropriate LLM service based on user configuration"""
# Handle both enum and string values for model
model_value = (
user_config.llm.model.value
if hasattr(user_config.llm.model, "value")
else user_config.llm.model
)
model = user_config.llm.model
if user_config.llm.provider == ServiceProviders.OPENAI.value:
if "gpt-5" in model_value:
if "gpt-5" in model:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=model_value,
model=model,
params=OpenAILLMService.InputParams(
reasoning_effort="minimal", verbosity="low"
),
@@ -197,16 +182,16 @@ def create_llm_service(user_config):
else:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=model_value,
model=model,
params=OpenAILLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GROQ.value:
print(
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {model_value}"
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {model}"
)
return GroqLLMService(
api_key=user_config.llm.api_key,
model=model_value,
model=model,
params=OpenAILLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
@@ -214,21 +199,21 @@ def create_llm_service(user_config):
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
return GoogleLLMService(
api_key=user_config.llm.api_key,
model=model_value,
model=model,
params=GoogleLLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.AZURE.value:
return AzureLLMService(
api_key=user_config.llm.api_key,
endpoint=user_config.llm.endpoint,
model=model_value, # Azure uses deployment name as model
model=model, # Azure uses deployment name as model
params=AzureLLMService.InputParams(temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=user_config.llm.api_key,
model=model_value,
model=model,
)
else:
raise HTTPException(status_code=400, detail="Invalid LLM provider")