dograh/api/services/pipecat/service_factory.py

324 lines
13 KiB
Python
Raw Permalink Normal View History

2025-09-09 14:37:32 +05:30
from typing import TYPE_CHECKING
from fastapi import HTTPException
from loguru import logger
2025-09-09 14:37:32 +05:30
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
2025-09-09 14:37:32 +05:30
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings
from pipecat.services.deepgram.flux.stt import (
DeepgramFluxSTTService,
DeepgramFluxSTTSettings,
)
from pipecat.services.deepgram.stt import DeepgramSTTService, DeepgramSTTSettings
from pipecat.services.deepgram.tts import DeepgramTTSService, DeepgramTTSSettings
2025-09-09 14:37:32 +05:30
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService, DograhSTTSettings
from pipecat.services.dograh.tts import DograhTTSService, DograhTTSSettings
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, ElevenLabsTTSSettings
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
from pipecat.services.openai.base_llm import OpenAILLMSettings
2025-09-09 14:37:32 +05:30
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService, OpenAISTTSettings
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
from pipecat.services.speechmatics.stt import (
SpeechmaticsSTTService,
SpeechmaticsSTTSettings,
)
from pipecat.transcriptions.language import Language
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
2025-09-09 14:37:32 +05:30
if TYPE_CHECKING:
from api.services.pipecat.audio_config import AudioConfig
2026-02-11 14:15:19 +05:30
def create_stt_service(
    user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
):
    """Build the speech-to-text service selected by ``user_config.stt.provider``.

    Args:
        user_config: User configuration; its ``stt`` section supplies
            ``provider``, ``model``, ``api_key`` and, for Deepgram (non-Flux),
            Dograh, Sarvam and Speechmatics, an optional ``language``.
        audio_config: Audio configuration; ``transport_in_sample_rate`` is
            forwarded as ``sample_rate`` to every provider except OpenAI.
        keyterms: Optional terms used to boost recognition. Forwarded to
            Deepgram (Flux and non-Flux), Dograh, and Speechmatics (as
            additional vocabulary); ignored by OpenAI, Cartesia and Sarvam.

    Returns:
        A configured pipecat STT service instance.

    Raises:
        HTTPException: 400 when the provider is not recognised.
    """
    stt = user_config.stt
    provider = stt.provider
    logger.info(f"Creating STT service: provider={provider}, model={stt.model}")

    if provider == ServiceProviders.DEEPGRAM.value:
        boosted_terms = keyterms or []

        if stt.model == "flux-general-en":
            # Flux is English-only, so no language is passed.
            logger.debug("Using DeepGram Flux Model")
            flux_settings = DeepgramFluxSTTSettings(
                model=stt.model,
                eot_timeout_ms=3000,
                eot_threshold=0.7,
                eager_eot_threshold=0.5,
                keyterm=boosted_terms,
            )
            return DeepgramFluxSTTService(
                api_key=stt.api_key,
                settings=flux_settings,
                # UserAggregator is responsible for emitting InterruptionFrame.
                should_interrupt=False,
                sample_rate=audio_config.transport_in_sample_rate,
            )

        # Non-Flux models: configured language, falling back to "multi".
        lang_code = getattr(stt, "language", None) or "multi"
        logger.debug(f"Using DeepGram Model - {stt.model}")
        nova_settings = DeepgramSTTSettings(
            language=lang_code,
            profanity_filter=False,
            endpointing=100,
            model=stt.model,
            keyterm=boosted_terms,
        )
        return DeepgramSTTService(
            api_key=stt.api_key,
            settings=nova_settings,
            # UserAggregator is responsible for emitting InterruptionFrame.
            should_interrupt=False,
            sample_rate=audio_config.transport_in_sample_rate,
        )

    if provider == ServiceProviders.OPENAI.value:
        return OpenAISTTService(
            api_key=stt.api_key,
            settings=OpenAISTTSettings(model=stt.model),
        )

    if provider == ServiceProviders.CARTESIA.value:
        return CartesiaSTTService(
            api_key=stt.api_key,
            sample_rate=audio_config.transport_in_sample_rate,
        )

    if provider == ServiceProviders.DOGRAH.value:
        # Swap the http(s) scheme of MPS_API_URL for ws(s).
        ws_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
        lang_code = getattr(stt, "language", None) or "multi"
        return DograhSTTService(
            base_url=ws_url,
            api_key=stt.api_key,
            settings=DograhSTTSettings(model=stt.model, language=lang_code),
            keyterms=keyterms,
            sample_rate=audio_config.transport_in_sample_rate,
        )

    if provider == ServiceProviders.SARVAM.value:
        # Sarvam locale code -> pipecat Language; unknown/missing -> Hindi (India).
        sarvam_to_pipecat = {
            "bn-IN": Language.BN_IN,
            "gu-IN": Language.GU_IN,
            "hi-IN": Language.HI_IN,
            "kn-IN": Language.KN_IN,
            "ml-IN": Language.ML_IN,
            "mr-IN": Language.MR_IN,
            "ta-IN": Language.TA_IN,
            "te-IN": Language.TE_IN,
            "pa-IN": Language.PA_IN,
            "od-IN": Language.OR_IN,
            "en-IN": Language.EN_IN,
            "as-IN": Language.AS_IN,
        }
        requested = getattr(stt, "language", None)
        return SarvamSTTService(
            api_key=stt.api_key,
            settings=SarvamSTTSettings(
                model=stt.model,
                language=sarvam_to_pipecat.get(requested, Language.HI_IN),
            ),
            sample_rate=audio_config.transport_in_sample_rate,
        )

    if provider == ServiceProviders.SPEECHMATICS.value:
        from pipecat.services.speechmatics.stt import (
            AdditionalVocabEntry,
            OperatingPoint,
        )

        lang_code = getattr(stt, "language", None) or "en"
        # The model field selects the operating point: "enhanced" or standard.
        if stt.model == "enhanced":
            point = OperatingPoint.ENHANCED
        else:
            point = OperatingPoint.STANDARD
        # Keyterms become Speechmatics additional-vocabulary entries.
        vocab = [AdditionalVocabEntry(content=term) for term in keyterms or []]
        return SpeechmaticsSTTService(
            api_key=stt.api_key,
            settings=SpeechmaticsSTTSettings(
                language=lang_code,
                operating_point=point,
                additional_vocab=vocab,
            ),
            sample_rate=audio_config.transport_in_sample_rate,
        )

    raise HTTPException(
        status_code=400, detail=f"Invalid STT provider {provider}"
    )
def create_tts_service(user_config, audio_config: "AudioConfig"):
    """Create and return the TTS service selected by ``user_config.tts.provider``.

    Every service is built with an ``XMLFunctionTagFilter`` text filter (so
    function-call tags are not spoken) and with ``silence_time_s=1.0``.

    Args:
        user_config: User configuration whose ``tts`` section supplies
            ``provider``, ``model``, ``api_key`` and ``voice``; ElevenLabs and
            Dograh also read ``speed``, and Sarvam optionally reads ``language``.
        audio_config: Audio configuration for the call.
            NOTE(review): not read by any branch below — presumably accepted
            for parity with ``create_stt_service``; confirm whether an output
            sample rate should be forwarded here.

    Returns:
        A configured pipecat TTS service instance.

    Raises:
        HTTPException: 400 when the provider is not recognised.
    """
    tts = user_config.tts
    provider = tts.provider
    logger.info(f"Creating TTS service: provider={provider}, model={tts.model}")

    # Filter out function-call XML tags so the TTS engine never speaks them.
    xml_function_tag_filter = XMLFunctionTagFilter()

    if provider == ServiceProviders.DEEPGRAM.value:
        return DeepgramTTSService(
            api_key=tts.api_key,
            settings=DeepgramTTSSettings(voice=tts.voice),
            text_filters=[xml_function_tag_filter],
            silence_time_s=1.0,
        )
    elif provider == ServiceProviders.OPENAI.value:
        return OpenAITTSService(
            api_key=tts.api_key,
            settings=OpenAITTSSettings(model=tts.model),
            text_filters=[xml_function_tag_filter],
            silence_time_s=1.0,
        )
    elif provider == ServiceProviders.ELEVENLABS.value:
        # Backward compatible with older configurations stored as
        # "Name - voice_id". The id is always the final segment, so split from
        # the right: a display name could itself contain " - ". A plain
        # voice id (no separator) is returned unchanged.
        voice_id = tts.voice.rsplit(" - ", 1)[-1]
        return ElevenLabsTTSService(
            reconnect_on_error=False,
            api_key=tts.api_key,
            settings=ElevenLabsTTSSettings(
                voice=voice_id,
                model=tts.model,
                stability=0.8,
                speed=tts.speed,
                similarity_boost=0.75,
            ),
            text_filters=[xml_function_tag_filter],
            silence_time_s=1.0,
        )
    elif provider == ServiceProviders.CARTESIA.value:
        return CartesiaTTSService(
            api_key=tts.api_key,
            settings=CartesiaTTSSettings(
                voice=tts.voice,
                model=tts.model,
            ),
            text_filters=[xml_function_tag_filter],
            silence_time_s=1.0,
        )
    elif provider == ServiceProviders.DOGRAH.value:
        # Convert HTTP URL to WebSocket URL for TTS.
        base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
        return DograhTTSService(
            base_url=base_url,
            api_key=tts.api_key,
            settings=DograhTTSSettings(
                model=tts.model,
                voice=tts.voice,
                speed=tts.speed,
            ),
            text_filters=[xml_function_tag_filter],
            silence_time_s=1.0,
        )
    elif provider == ServiceProviders.SARVAM.value:
        # Sarvam locale code -> pipecat Language; unknown/missing -> Hindi.
        language_mapping = {
            "bn-IN": Language.BN,
            "en-IN": Language.EN,
            "gu-IN": Language.GU,
            "hi-IN": Language.HI,
            "kn-IN": Language.KN,
            "ml-IN": Language.ML,
            "mr-IN": Language.MR,
            "od-IN": Language.OR,
            "pa-IN": Language.PA,
            "ta-IN": Language.TA,
            "te-IN": Language.TE,
        }
        language = getattr(tts, "language", None)
        pipecat_language = language_mapping.get(language, Language.HI)
        # Fall back to the "anushka" voice when none is configured.
        voice = getattr(tts, "voice", None) or "anushka"
        return SarvamTTSService(
            api_key=tts.api_key,
            settings=SarvamTTSSettings(
                model=tts.model,
                voice=voice,
                language=pipecat_language,
            ),
            text_filters=[xml_function_tag_filter],
            silence_time_s=1.0,
        )
    else:
        raise HTTPException(
            status_code=400, detail=f"Invalid TTS provider {provider}"
        )
def create_llm_service(user_config):
    """Create and return the LLM service selected by ``user_config.llm.provider``.

    Args:
        user_config: User configuration whose ``llm`` section supplies
            ``provider``, ``model`` and ``api_key``; OpenRouter additionally
            reads ``base_url`` and Azure reads ``endpoint``.

    Returns:
        A pipecat LLM service instance for the configured provider.

    Raises:
        HTTPException: 400 when the provider is not supported.
    """
    llm_config = user_config.llm
    model = llm_config.model
    provider = llm_config.provider
    # Never include the API key in logs; provider and model are sufficient.
    logger.info(f"Creating LLM service: provider={provider}, model={model}")

    if provider == ServiceProviders.OPENAI.value:
        if "gpt-5" in model:
            # gpt-5 models get reasoning_effort/verbosity via ``extra``
            # instead of the fixed temperature used for other OpenAI models.
            return OpenAILLMService(
                api_key=llm_config.api_key,
                settings=OpenAILLMSettings(
                    model=model,
                    extra={"reasoning_effort": "minimal", "verbosity": "low"},
                ),
            )
        return OpenAILLMService(
            api_key=llm_config.api_key,
            settings=OpenAILLMSettings(model=model, temperature=0.1),
        )
    elif provider == ServiceProviders.GROQ.value:
        # Previously this branch print()-ed the plaintext API key to stdout;
        # the logger.info call above already records provider and model.
        return GroqLLMService(
            api_key=llm_config.api_key,
            settings=GroqLLMSettings(model=model, temperature=0.1),
        )
    elif provider == ServiceProviders.OPENROUTER.value:
        return OpenRouterLLMService(
            api_key=llm_config.api_key,
            base_url=llm_config.base_url,
            settings=OpenRouterLLMSettings(model=model, temperature=0.1),
        )
    elif provider == ServiceProviders.GOOGLE.value:
        return GoogleLLMService(
            api_key=llm_config.api_key,
            settings=GoogleLLMSettings(model=model, temperature=0.1),
        )
    elif provider == ServiceProviders.AZURE.value:
        return AzureLLMService(
            api_key=llm_config.api_key,
            endpoint=llm_config.endpoint,
            settings=AzureLLMSettings(model=model, temperature=0.1),
        )
    elif provider == ServiceProviders.DOGRAH.value:
        return DograhLLMService(
            base_url=f"{MPS_API_URL}/api/v1/llm",
            api_key=llm_config.api_key,
            settings=OpenAILLMSettings(model=model),
        )
    else:
        # Include the offending provider, consistent with the STT/TTS factories.
        raise HTTPException(
            status_code=400, detail=f"Invalid LLM provider {provider}"
        )