feat: add Azure AI multi-provider support (TTS, STT, Embeddings, Realtime)

Enables Azure AI services across all model layers so users with Azure
credits can consolidate billing on a single provider.

- Voice (TTS): AzureSpeechTTSConfiguration via azure_speech provider
- Transcriber (STT): AzureSpeechSTTConfiguration via azure_speech provider
- Embedding: AzureOpenAIEmbeddingsConfiguration via azure provider
- Realtime: AzureRealtimeLLMConfiguration via azure_realtime provider

New files:
- api/services/pipecat/realtime/azure_realtime.py
- api/services/gen_ai/embedding/azure_openai_service.py
- api/tests/test_azure_speech_service_factory.py

The UI picks up all four providers automatically from the schema —
no frontend changes required.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vishal Dhateria 2026-05-29 20:48:42 +05:30
parent e695436fb3
commit dbbf362315
12 changed files with 883 additions and 28 deletions

View file

@ -0,0 +1,242 @@
"""Dograh subclass of pipecat's Azure OpenAI Realtime LLM service.
Layers Dograh engine integration quirks (mute gating, TTSSpeakFrame greeting
trigger, LLMMessagesAppendFrame handling, deferred tool calls) onto pipecat's
AzureRealtimeLLMService, mirroring what DograhOpenAIRealtimeLLMService does
for the standard OpenAI Realtime endpoint.
"""
import json
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
TranscriptionFrame,
TTSSpeakFrame,
UserMuteStartedFrame,
UserMuteStoppedFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.azure.realtime.llm import AzureRealtimeLLMService
from pipecat.services.llm_service import FunctionCallFromLLM
from pipecat.services.openai.realtime import events
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
class DograhAzureRealtimeLLMService(AzureRealtimeLLMService):
"""Azure OpenAI Realtime with Dograh engine integration quirks.
Extends AzureRealtimeLLMService with the same Dograh-specific behaviours
added to DograhOpenAIRealtimeLLMService:
- User-mute audio gating
- TTSSpeakFrame as initial-response trigger
- One-off LLMMessagesAppendFrame handling
- Deferred tool calls until bot finishes speaking
- finalized=True on TranscriptionFrame for consistency
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._user_is_muted: bool = False
self._handled_initial_context: bool = False
self._bot_is_speaking: bool = False
self._deferred_function_calls: list[FunctionCallFromLLM] = []
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, UserMuteStartedFrame):
self._user_is_muted = True
await self.push_frame(frame, direction)
return
if isinstance(frame, UserMuteStoppedFrame):
self._user_is_muted = False
await self.push_frame(frame, direction)
return
if isinstance(frame, TTSSpeakFrame):
if not self._handled_initial_context:
await self._handle_context(self._context)
else:
logger.warning(
f"{self}: TTSSpeakFrame after initial context already handled — "
"Azure Realtime owns audio generation, ignoring"
)
return
if isinstance(frame, LLMMessagesAppendFrame):
await self._handle_messages_append(frame)
return
if isinstance(frame, BotStartedSpeakingFrame):
self._bot_is_speaking = True
elif isinstance(frame, BotStoppedSpeakingFrame):
self._bot_is_speaking = False
await self._run_pending_function_calls()
await super().process_frame(frame, direction)
async def _handle_messages_append(self, frame: LLMMessagesAppendFrame):
if self._disconnecting:
return
if not self._api_session_ready:
if frame.run_llm:
logger.debug(
f"{self}: LLMMessagesAppendFrame received before session ready; "
"deferring response until the session is initialized"
)
self._run_llm_when_api_session_ready = True
return
appended_any = False
for message in frame.messages:
item = self._message_to_conversation_item(message)
if item is None:
continue
evt = events.ConversationItemCreateEvent(item=item)
self._messages_added_manually[evt.item.id] = True
await self.send_client_event(evt)
appended_any = True
if frame.run_llm and appended_any:
await self._send_manual_response_create()
async def _handle_context(self, context: LLMContext):
if not self._handled_initial_context:
if context is None:
logger.warning(
f"{self}: received initial context trigger before context was set"
)
return
self._handled_initial_context = True
self._context = context
await self._create_response()
else:
self._context = context
await self._process_completed_function_calls(send_new_results=True)
async def _send_user_audio(self, frame):
if self._user_is_muted:
return
await super()._send_user_audio(frame)
def _message_to_conversation_item(
self, message: dict[str, Any]
) -> events.ConversationItem | None:
if not isinstance(message, dict):
logger.warning(
f"{self}: skipping unsupported appended message payload {message!r}"
)
return None
role = message.get("role")
if role not in {"user", "system", "developer"}:
logger.warning(
f"{self}: skipping unsupported appended message role {role!r}"
)
return None
text = self._extract_text_content(message.get("content"))
if not text:
logger.warning(
f"{self}: skipping appended message with unsupported content {message!r}"
)
return None
item_role = "system" if role in {"system", "developer"} else "user"
return events.ConversationItem(
type="message",
role=item_role,
content=[events.ItemContent(type="input_text", text=text)],
)
@staticmethod
def _extract_text_content(content: Any) -> str | None:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for part in content:
if not isinstance(part, dict):
return None
if part.get("type") != "text":
return None
text = part.get("text")
if not isinstance(text, str):
return None
parts.append(text)
return "\n".join(parts) if parts else None
return None
async def _send_manual_response_create(self):
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
await self.start_ttfb_metrics()
await self.send_client_event(
events.ResponseCreateEvent(
response=events.ResponseProperties(
output_modalities=self._get_enabled_modalities()
)
)
)
async def _run_pending_function_calls(self):
if not self._deferred_function_calls:
return
function_calls = self._deferred_function_calls
self._deferred_function_calls = []
logger.debug(
f"{self}: executing {len(function_calls)} deferred function call(s) "
"after bot turn ended"
)
await self.run_function_calls(function_calls)
async def _handle_evt_function_call_arguments_done(self, evt):
try:
args = json.loads(evt.arguments)
function_call_item = self._pending_function_calls.get(evt.call_id)
if function_call_item:
del self._pending_function_calls[evt.call_id]
function_calls = [
FunctionCallFromLLM(
context=self._context,
tool_call_id=evt.call_id,
function_name=function_call_item.name,
arguments=args,
)
]
if self._bot_is_speaking:
self._deferred_function_calls.extend(function_calls)
logger.debug(
f"{self}: deferring function call {function_call_item.name} "
"until bot stops speaking"
)
else:
await self.run_function_calls(function_calls)
logger.debug(f"Processed function call: {function_call_item.name}")
else:
logger.warning(
f"No tracked function call found for call_id: {evt.call_id}"
)
except Exception as e:
logger.error(f"Failed to process function call arguments: {e}")
async def handle_evt_input_audio_transcription_completed(self, evt):
await self._call_event_handler(
"on_conversation_item_updated", evt.item_id, None
)
await self.broadcast_frame(
TranscriptionFrame,
text=evt.transcript,
user_id="",
timestamp=time_now_iso8601(),
result=evt,
finalized=True,
)
await self._handle_user_transcription(evt.transcript, True, Language.EN)

View file

@ -504,10 +504,16 @@ async def _run_pipeline(
embeddings_api_key = None
embeddings_model = None
embeddings_base_url = None
embeddings_provider = None
embeddings_endpoint = None
embeddings_api_version = None
if user_config and user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
# Check if the workflow has any active recordings so the engine can
# include recording response mode instructions in all node prompts.
@ -532,6 +538,9 @@ async def _run_pipeline(
embeddings_api_key=embeddings_api_key,
embeddings_model=embeddings_model,
embeddings_base_url=embeddings_base_url,
embeddings_provider=embeddings_provider,
embeddings_endpoint=embeddings_endpoint,
embeddings_api_version=embeddings_api_version,
has_recordings=has_recordings,
context_compaction_enabled=context_compaction_enabled,
)

View file

@ -11,6 +11,8 @@ from api.utils.url_security import validate_user_configured_service_url
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
from pipecat.services.azure.stt import AzureSTTService, AzureSTTSettings
from pipecat.services.azure.tts import AzureTTSService, AzureTTSSettings
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.cartesia.tts import (
CartesiaTTSService,
@ -246,6 +248,22 @@ def create_stt_service(
),
sample_rate=audio_config.transport_in_sample_rate,
)
elif user_config.stt.provider == ServiceProviders.AZURE_SPEECH.value:
from pipecat.transcriptions.language import Language as PipecatLanguage
language_code = getattr(user_config.stt, "language", None) or "en-US"
region = getattr(user_config.stt, "region", None) or "eastus"
# Try to map BCP-47 string to pipecat Language enum; fall back to string
try:
pipecat_language = PipecatLanguage(language_code)
except ValueError:
pipecat_language = PipecatLanguage.EN_US
return AzureSTTService(
api_key=user_config.stt.api_key,
region=region,
settings=AzureSTTSettings(language=pipecat_language),
sample_rate=audio_config.transport_in_sample_rate,
)
else:
raise HTTPException(
status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}"
@ -492,6 +510,27 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.AZURE_SPEECH.value:
region = getattr(user_config.tts, "region", None) or "eastus"
voice = getattr(user_config.tts, "voice", None) or "en-US-AriaNeural"
language = getattr(user_config.tts, "language", None) or "en-US"
speed = getattr(user_config.tts, "speed", None) or 1.0
# Map speed multiplier (0.52.0) to Azure SSML rate string (e.g. "1.25")
rate = str(speed) if speed != 1.0 else None
settings_kwargs: dict = {
"voice": voice,
"language": language,
}
if rate:
settings_kwargs["rate"] = rate
return AzureTTSService(
api_key=user_config.tts.api_key,
region=region,
settings=AzureTTSSettings(**settings_kwargs),
text_filters=[xml_function_tag_filter],
skip_aggregator_types=["recording_router", "recording"],
silence_time_s=1.0,
)
else:
raise HTTPException(
status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}"
@ -724,6 +763,44 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
location=location,
settings=DograhGeminiLiveVertexLLMService.Settings(**settings_kwargs),
)
elif provider == ServiceProviders.AZURE_REALTIME.value:
from api.services.pipecat.realtime.azure_realtime import (
DograhAzureRealtimeLLMService,
)
from pipecat.services.openai.realtime.events import (
AudioConfiguration,
AudioInput,
AudioOutput,
InputAudioTranscription,
SessionProperties,
)
endpoint = getattr(realtime_config, "endpoint", None) or ""
api_version = getattr(realtime_config, "api_version", None) or "2025-04-01-preview"
# Construct the Azure Realtime WebSocket URL
# https://<resource>.openai.azure.com/openai/realtime?api-version=<ver>&deployment=<model>
base_host = endpoint.rstrip("/").replace("https://", "").replace("http://", "")
wss_url = (
f"wss://{base_host}/openai/realtime"
f"?api-version={api_version}&deployment={model}"
)
return DograhAzureRealtimeLLMService(
api_key=api_key,
base_url=wss_url,
settings=DograhAzureRealtimeLLMService.Settings(
model=model,
session_properties=SessionProperties(
audio=AudioConfiguration(
input=AudioInput(
transcription=InputAudioTranscription(),
),
output=AudioOutput(
voice=voice or "alloy",
),
),
),
),
)
else:
raise HTTPException(
status_code=400, detail=f"Invalid realtime LLM provider {provider}"