mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
fix: send sample rate to STT services
This commit is contained in:
parent
4c936ae57d
commit
7a102026fb
6 changed files with 16 additions and 11 deletions
|
|
@ -91,7 +91,7 @@ class LoopTalkPipelineBuilder:
|
|||
logger.info(f"Using {len(keyterms)} keyterms for STT: {keyterms}")
|
||||
|
||||
# Create services
|
||||
stt = create_stt_service(user_config, keyterms=keyterms)
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
llm = create_llm_service(user_config)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ def create_audio_config(transport_type: str) -> AudioConfig:
|
|||
transport_out_sample_rate=8000,
|
||||
vad_sample_rate=8000, # Use matching VAD rate
|
||||
pipeline_sample_rate=8000, # Keep at 8kHz to avoid resampling
|
||||
buffer_size_seconds=1.0,
|
||||
buffer_size_seconds=5.0,
|
||||
)
|
||||
elif transport_type == WorkflowRunMode.VONAGE.value:
|
||||
# Vonage uses 16kHz Linear PCM
|
||||
|
|
@ -113,7 +113,7 @@ def create_audio_config(transport_type: str) -> AudioConfig:
|
|||
transport_out_sample_rate=16000,
|
||||
vad_sample_rate=16000, # Use matching VAD rate
|
||||
pipeline_sample_rate=16000, # Keep at 16kHz to avoid resampling
|
||||
buffer_size_seconds=1.0,
|
||||
buffer_size_seconds=5.0,
|
||||
)
|
||||
elif transport_type in [
|
||||
WorkflowRunMode.WEBRTC.value,
|
||||
|
|
@ -126,7 +126,7 @@ def create_audio_config(transport_type: str) -> AudioConfig:
|
|||
transport_out_sample_rate=16000, # Transport will resample to 24kHz
|
||||
vad_sample_rate=16000, # VAD native rate
|
||||
pipeline_sample_rate=16000, # Keep pipeline at 16kHz
|
||||
buffer_size_seconds=1.0,
|
||||
buffer_size_seconds=5.0,
|
||||
)
|
||||
else:
|
||||
# Default configuration
|
||||
|
|
@ -138,5 +138,5 @@ def create_audio_config(transport_type: str) -> AudioConfig:
|
|||
transport_out_sample_rate=16000,
|
||||
vad_sample_rate=16000,
|
||||
pipeline_sample_rate=16000,
|
||||
buffer_size_seconds=1.0,
|
||||
buffer_size_seconds=5.0,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -494,7 +494,7 @@ async def _run_pipeline(
|
|||
]
|
||||
|
||||
# Create services based on user configuration
|
||||
stt = create_stt_service(user_config, keyterms=keyterms)
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
|||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
||||
|
||||
def create_stt_service(user_config, keyterms: list[str] | None = None):
|
||||
def create_stt_service(user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None):
|
||||
"""Create and return appropriate STT service based on user configuration
|
||||
|
||||
Args:
|
||||
|
|
@ -53,6 +53,7 @@ def create_stt_service(user_config, keyterms: list[str] | None = None):
|
|||
keyterm=keyterms or [],
|
||||
),
|
||||
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
|
||||
sample_rate=audio_config.transport_in_sample_rate
|
||||
)
|
||||
|
||||
# Other models than flux
|
||||
|
|
@ -63,20 +64,21 @@ def create_stt_service(user_config, keyterms: list[str] | None = None):
|
|||
profanity_filter=False,
|
||||
endpointing=100,
|
||||
model=user_config.stt.model,
|
||||
keyterm=keyterms or [],
|
||||
keyterm=keyterms or []
|
||||
)
|
||||
logger.debug(f"Using DeepGram Model - {user_config.stt.model}")
|
||||
return DeepgramSTTService(
|
||||
live_options=live_options,
|
||||
api_key=user_config.stt.api_key,
|
||||
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
|
||||
sample_rate=audio_config.transport_in_sample_rate
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAISTTService(
|
||||
api_key=user_config.stt.api_key, model=user_config.stt.model
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaSTTService(api_key=user_config.stt.api_key)
|
||||
return CartesiaSTTService(api_key=user_config.stt.api_key, sample_rate=audio_config.transport_in_sample_rate)
|
||||
elif user_config.stt.provider == ServiceProviders.DOGRAH.value:
|
||||
base_url = MPS_API_URL.replace("http://", "ws://").replace("https://", "wss://")
|
||||
language = getattr(user_config.stt, "language", None) or "multi"
|
||||
|
|
@ -86,6 +88,7 @@ def create_stt_service(user_config, keyterms: list[str] | None = None):
|
|||
model=user_config.stt.model,
|
||||
language=language,
|
||||
keyterms=keyterms,
|
||||
sample_rate=audio_config.transport_in_sample_rate
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SARVAM.value:
|
||||
# Map Sarvam language code to pipecat Language enum
|
||||
|
|
@ -109,6 +112,7 @@ def create_stt_service(user_config, keyterms: list[str] | None = None):
|
|||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
params=SarvamSTTService.InputParams(language=pipecat_language),
|
||||
sample_rate=audio_config.transport_in_sample_rate
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
|
||||
from pipecat.services.speechmatics.stt import (
|
||||
|
|
@ -134,6 +138,7 @@ def create_stt_service(user_config, keyterms: list[str] | None = None):
|
|||
operating_point=operating_point,
|
||||
additional_vocab=additional_vocab,
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -368,7 +368,7 @@ def create_stasis_transport(
|
|||
audio_out_enabled=True,
|
||||
audio_out_sample_rate=audio_config.transport_out_sample_rate,
|
||||
audio_in_sample_rate=audio_config.transport_in_sample_rate,
|
||||
audio_out_10ms_chunks=2, # Send 20ms packets
|
||||
# audio_out_10ms_chunks=2, # ToDo: Check if we cant support 40 ms packets?
|
||||
audio_out_mixer=(
|
||||
SoundfileMixer(
|
||||
sound_files={
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 5313e8cd94443f220cc56c10cc2fc2aa98d8b6ba
|
||||
Subproject commit d67983b3b165f945a93e5ce594f47781a96bff9b
|
||||
Loading…
Add table
Add a link
Reference in a new issue