fix: add XML function-tag text filter for TTS and log LLM function calls

This commit is contained in:
Sabiha Khan 2025-12-03 16:55:42 +05:30
parent 570168424a
commit 003039ca56
3 changed files with 21 additions and 3 deletions

View file

@ -17,6 +17,7 @@ from pipecat.services.groq.llm import GroqLLMService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import OpenAITTSService
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
if TYPE_CHECKING:
from api.services.pipecat.audio_config import AudioConfig
@ -65,13 +66,19 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
user_config: User configuration containing TTS settings
transport_type: Type of transport (e.g., 'stasis', 'twilio', 'webrtc')
"""
# Create function call filter to prevent TTS from speaking function call tags
xml_function_tag_filter = XMLFunctionTagFilter()
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
return DeepgramTTSService(
api_key=user_config.tts.api_key, voice=user_config.tts.voice.value
api_key=user_config.tts.api_key,
voice=user_config.tts.voice.value,
text_filters=[xml_function_tag_filter]
)
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
return OpenAITTSService(
api_key=user_config.tts.api_key, model=user_config.tts.model.value
api_key=user_config.tts.api_key,
model=user_config.tts.model.value,
text_filters=[xml_function_tag_filter]
)
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
voice_id = user_config.tts.voice.split(" - ")[1]
@ -83,6 +90,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
params=ElevenLabsTTSService.InputParams(
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
),
text_filters=[xml_function_tag_filter]
)
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
# Convert HTTP URL to WebSocket URL for TTS
@ -93,6 +101,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
api_key=user_config.tts.api_key,
model=user_config.tts.model.value,
voice=user_config.tts.voice.value,
text_filters=[xml_function_tag_filter]
)
else:
raise HTTPException(

View file

@ -181,6 +181,9 @@ class PipecatEngine:
async def _create_transition_func(self, name: str, transition_to_node: str):
async def transition_func(function_call_params: FunctionCallParams) -> None:
"""Inner function that handles the node change tool calls"""
logger.info(f"LLM Function Call EXECUTED: {name}")
logger.info(f"Function: {name} -> transitioning to node: {transition_to_node}")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
async def on_context_updated() -> None:
@ -240,6 +243,8 @@ class PipecatEngine:
# Register calculator function
async def calculate_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
expr = function_call_params.arguments.get("expression", "")
result = safe_calculator(expr)
@ -255,6 +260,8 @@ class PipecatEngine:
async def get_current_time_func(
function_call_params: FunctionCallParams,
) -> None:
logger.info(f"LLM Function Call EXECUTED: get_current_time")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
timezone = function_call_params.arguments.get("timezone", "UTC")
result = get_current_time(timezone)
@ -267,6 +274,8 @@ class PipecatEngine:
)
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: convert_time")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
result = convert_time(
function_call_params.arguments.get("source_timezone"),

@ -1 +1 @@
Subproject commit 30c6d1edb93c144a52adc6a3aa1aa618b1ee85fc
Subproject commit 6e6cf412ad1d3251c3f4b0c06a85ae9a66b5719e