feat: add dictionary support for STT boosting in voice agents (#136)

* feat: add dictionary support for voice agents

Also fixes #132

* chore: add keyterms in evals
This commit is contained in:
Abhishek 2026-01-29 11:20:07 +05:30 committed by GitHub
parent e3a1e0bf07
commit db75d90535
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 9666 additions and 53 deletions

View file

@ -76,8 +76,22 @@ class LoopTalkPipelineBuilder:
pipeline_sample_rate=16000,
)
# Extract keyterms from workflow configurations
keyterms = None
if (
workflow.workflow_configurations
and "dictionary" in workflow.workflow_configurations
):
dictionary = workflow.workflow_configurations["dictionary"]
if dictionary and isinstance(dictionary, str):
keyterms = [
term.strip() for term in dictionary.split(",") if term.strip()
]
if keyterms:
logger.info(f"Using {len(keyterms)} keyterms for STT: {keyterms}")
# Create services
stt = create_stt_service(user_config)
stt = create_stt_service(user_config, keyterms=keyterms)
llm = create_llm_service(user_config)
tts = create_tts_service(user_config, audio_config)

View file

@ -443,12 +443,7 @@ async def _run_pipeline(
# Get user configuration
user_config = await db_client.get_user_configurations(user_id)
# Create services based on user configuration
stt = create_stt_service(user_config)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
# Get workflow first so we can create engine before pipeline components
# Get workflow first so we can extract configurations before creating services
workflow = await db_client.get_workflow(workflow_id, user_id)
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
@ -456,6 +451,7 @@ async def _run_pipeline(
# Extract configurations from workflow configurations
max_call_duration_seconds = 300 # Default 5 minutes
max_user_idle_timeout = 10.0 # Default 10 seconds
keyterms = None # Dictionary words for STT boosting
if workflow.workflow_configurations:
# Use workflow-specific max call duration if provided
@ -470,6 +466,20 @@ async def _run_pipeline(
"max_user_idle_timeout"
]
# Extract dictionary words and convert to keyterms list
if "dictionary" in workflow.workflow_configurations:
dictionary = workflow.workflow_configurations["dictionary"]
if dictionary and isinstance(dictionary, str):
# Split by comma and strip whitespace from each term
keyterms = [
term.strip() for term in dictionary.split(",") if term.strip()
]
# Create services based on user configuration
stt = create_stt_service(user_config, keyterms=keyterms)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
)

View file

@ -29,8 +29,13 @@ if TYPE_CHECKING:
from api.services.pipecat.audio_config import AudioConfig
def create_stt_service(user_config):
"""Create and return appropriate STT service based on user configuration"""
def create_stt_service(user_config, keyterms: list[str] | None = None):
"""Create and return appropriate STT service based on user configuration
Args:
user_config: User configuration containing STT settings
keyterms: Optional list of keyterms for speech recognition boosting (forwarded to providers that accept custom vocabulary, e.g. Deepgram and Speechmatics)
"""
logger.info(
f"Creating STT service: provider={user_config.stt.provider}, model={user_config.stt.model}"
)
@ -44,6 +49,7 @@ def create_stt_service(user_config):
params=DeepgramFluxSTTService.InputParams(
eot_timeout_ms=3000,
eot_threshold=0.7,
keyterm=keyterms or [],
),
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
)
@ -56,6 +62,7 @@ def create_stt_service(user_config):
profanity_filter=False,
endpointing=100,
model=user_config.stt.model,
keyterm=keyterms or [],
)
logger.debug(f"Using DeepGram Model - {user_config.stt.model}")
return DeepgramSTTService(
@ -77,6 +84,7 @@ def create_stt_service(user_config):
api_key=user_config.stt.api_key,
model=user_config.stt.model,
language=language,
keyterms=keyterms,
)
elif user_config.stt.provider == ServiceProviders.SARVAM.value:
# Map Sarvam language code to pipecat Language enum
@ -102,7 +110,10 @@ def create_stt_service(user_config):
params=SarvamSTTService.InputParams(language=pipecat_language),
)
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
from pipecat.services.speechmatics.stt import OperatingPoint
from pipecat.services.speechmatics.stt import (
AdditionalVocabEntry,
OperatingPoint,
)
language = getattr(user_config.stt, "language", None) or "en"
# Map model field to operating point (standard or enhanced)
@ -111,11 +122,16 @@ def create_stt_service(user_config):
if user_config.stt.model == "enhanced"
else OperatingPoint.STANDARD
)
# Convert keyterms to AdditionalVocabEntry objects for Speechmatics
additional_vocab = []
if keyterms:
additional_vocab = [AdditionalVocabEntry(content=term) for term in keyterms]
return SpeechmaticsSTTService(
api_key=user_config.stt.api_key,
params=SpeechmaticsSTTService.InputParams(
language=language,
operating_point=operating_point,
additional_vocab=additional_vocab,
),
)
else: