mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
fix: add text filter for tts and logs for filter
This commit is contained in:
parent
570168424a
commit
003039ca56
3 changed files with 21 additions and 3 deletions
|
|
@ -17,6 +17,7 @@ from pipecat.services.groq.llm import GroqLLMService
|
|||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
|
@ -65,13 +66,19 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
user_config: User configuration containing TTS settings
|
||||
transport_type: Type of transport (e.g., 'stasis', 'twilio', 'webrtc')
|
||||
"""
|
||||
# Create function call filter to prevent TTS from speaking function call tags
|
||||
xml_function_tag_filter = XMLFunctionTagFilter()
|
||||
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramTTSService(
|
||||
api_key=user_config.tts.api_key, voice=user_config.tts.voice.value
|
||||
api_key=user_config.tts.api_key,
|
||||
voice=user_config.tts.voice.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key, model=user_config.tts.model.value
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
voice_id = user_config.tts.voice.split(" - ")[1]
|
||||
|
|
@ -83,6 +90,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
params=ElevenLabsTTSService.InputParams(
|
||||
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
|
||||
),
|
||||
text_filters=[xml_function_tag_filter]
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
# Convert HTTP URL to WebSocket URL for TTS
|
||||
|
|
@ -93,6 +101,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model.value,
|
||||
voice=user_config.tts.voice.value,
|
||||
text_filters=[xml_function_tag_filter]
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -181,6 +181,9 @@ class PipecatEngine:
|
|||
async def _create_transition_func(self, name: str, transition_to_node: str):
|
||||
async def transition_func(function_call_params: FunctionCallParams) -> None:
|
||||
"""Inner function that handles the node change tool calls"""
|
||||
logger.info(f"LLM Function Call EXECUTED: {name}")
|
||||
logger.info(f"Function: {name} -> transitioning to node: {transition_to_node}")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
|
||||
async def on_context_updated() -> None:
|
||||
|
|
@ -240,6 +243,8 @@ class PipecatEngine:
|
|||
|
||||
# Register calculator function
|
||||
async def calculate_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
expr = function_call_params.arguments.get("expression", "")
|
||||
result = safe_calculator(expr)
|
||||
|
|
@ -255,6 +260,8 @@ class PipecatEngine:
|
|||
async def get_current_time_func(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: get_current_time")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
timezone = function_call_params.arguments.get("timezone", "UTC")
|
||||
result = get_current_time(timezone)
|
||||
|
|
@ -267,6 +274,8 @@ class PipecatEngine:
|
|||
)
|
||||
|
||||
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: convert_time")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
result = convert_time(
|
||||
function_call_params.arguments.get("source_timezone"),
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 30c6d1edb93c144a52adc6a3aa1aa618b1ee85fc
|
||||
Subproject commit 6e6cf412ad1d3251c3f4b0c06a85ae9a66b5719e
|
||||
Loading…
Add table
Add a link
Reference in a new issue