From 7ba95c0fbe16cbe7610329a13ed1fc835712be8b Mon Sep 17 00:00:00 2001 From: Vishal Dhateria Date: Tue, 2 Jun 2026 12:50:00 +0530 Subject: [PATCH] feat: add Azure AI multi-provider support (TTS, STT, Embeddings, Realtime) (#381) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add Azure AI multi-provider support (TTS, STT, Embeddings, Realtime) Enables Azure AI services across all model layers so users with Azure credits can consolidate billing on a single provider. - Voice (TTS): AzureSpeechTTSConfiguration via azure_speech provider - Transcriber (STT): AzureSpeechSTTConfiguration via azure_speech provider - Embedding: AzureOpenAIEmbeddingsConfiguration via azure provider - Realtime: AzureRealtimeLLMConfiguration via azure_realtime provider New files: - api/services/pipecat/realtime/azure_realtime.py - api/services/gen_ai/embedding/azure_openai_service.py - api/tests/test_azure_speech_service_factory.py The UI picks up all four providers automatically from the schema — no frontend changes required. Co-Authored-By: Claude Sonnet 4.6 * fix: add validation for URL and params --------- Co-authored-by: Vishal Dhateria Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Abhishek Kumar --- api/routes/knowledge_base.py | 41 ++- api/services/configuration/check_validity.py | 8 + .../configuration/options/__init__.py | 20 ++ api/services/configuration/options/azure.py | 125 +++++++++ api/services/configuration/registry.py | 156 ++++++++++- api/services/gen_ai/__init__.py | 4 + api/services/gen_ai/embedding/__init__.py | 6 + .../gen_ai/embedding/azure_openai_service.py | 131 ++++++++++ .../pipecat/realtime/azure_realtime.py | 242 ++++++++++++++++++ api/services/pipecat/run_pipeline.py | 9 + api/services/pipecat/service_factory.py | 91 +++++++ api/services/workflow/pipecat_engine.py | 9 + api/services/workflow/tools/knowledge_base.py | 46 +++- api/tasks/knowledge_base_processing.py | 41 ++- .../test_azure_speech_service_factory.py | 182 +++++++++++++ 15 files changed, 1082 insertions(+), 29 deletions(-) create mode 100644 api/services/configuration/options/azure.py create mode 100644 api/services/gen_ai/embedding/azure_openai_service.py create mode 100644 api/services/pipecat/realtime/azure_realtime.py create mode 100644 api/tests/test_azure_speech_service_factory.py diff --git a/api/routes/knowledge_base.py b/api/routes/knowledge_base.py index 95f64b8..5bf4b0a 100644 --- a/api/routes/knowledge_base.py +++ b/api/routes/knowledge_base.py @@ -369,26 +369,47 @@ async def search_chunks( try: # Import here to avoid circular dependency - from api.services.gen_ai import OpenAIEmbeddingService + from api.services.configuration.registry import ServiceProviders + from api.services.gen_ai import ( + AzureOpenAIEmbeddingService, + OpenAIEmbeddingService, + ) # Try to get user's embeddings configuration user_config = await db_client.get_user_configurations(user.id) embeddings_api_key = None embeddings_model = None + embeddings_provider = None + embeddings_endpoint = None + embeddings_api_version = None if user_config.embeddings: embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model + embeddings_provider = getattr(user_config.embeddings, "provider", None) + embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + embeddings_api_version = getattr( + user_config.embeddings, "api_version", None + ) - # Initialize embedding service with user config or fallback to env - embedding_service = OpenAIEmbeddingService( - db_client=db_client, - api_key=embeddings_api_key, - model_id=embeddings_model or "text-embedding-3-small", - base_url=getattr(user_config.embeddings, "base_url", None) - if user_config.embeddings - else None, - ) + # Initialize embedding service based on provider + if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint: + embedding_service = AzureOpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + endpoint=embeddings_endpoint, + model_id=embeddings_model or "text-embedding-3-small", + api_version=embeddings_api_version or "2024-02-15-preview", + ) + else: + embedding_service = OpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + model_id=embeddings_model or "text-embedding-3-small", + base_url=getattr(user_config.embeddings, "base_url", None) + if user_config.embeddings + else None, + ) # Perform search results = await embedding_service.search_similar_chunks( diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index 3a76147..0e4da86 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -41,6 +41,7 @@ class UserConfigurationValidator: ServiceProviders.ELEVENLABS.value: self._validate_elevenlabs_api_key, ServiceProviders.GOOGLE.value: self._check_google_api_key, ServiceProviders.AZURE.value: self._check_azure_api_key, + ServiceProviders.AZURE_SPEECH.value: self._check_azure_speech_api_key, ServiceProviders.CARTESIA.value: self._check_cartesia_api_key, ServiceProviders.DOGRAH.value: self._check_dograh_api_key, ServiceProviders.SARVAM.value: self._check_sarvam_api_key, @@ -54,6 +55,7 @@ class UserConfigurationValidator: ServiceProviders.ULTRAVOX_REALTIME.value: self._check_ultravox_realtime_api_key, ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key, ServiceProviders.GOOGLE_VERTEX_REALTIME.value: self._check_google_vertex_realtime_api_key, + ServiceProviders.AZURE_REALTIME.value: self._check_azure_realtime_api_key, ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key, ServiceProviders.GLADIA.value: self._check_gladia_api_key, ServiceProviders.RIME.value: self._check_rime_api_key, @@ -313,6 +315,12 @@ class UserConfigurationValidator: def _check_azure_api_key(self, model: str, api_key: str) -> bool: return True + def _check_azure_speech_api_key(self, model: str, api_key: str) -> bool: + return True + + def _check_azure_realtime_api_key(self, model: str, api_key: str) -> bool: + return True + def _check_cartesia_api_key(self, model: str, api_key: str) -> bool: return True diff --git a/api/services/configuration/options/__init__.py b/api/services/configuration/options/__init__.py index acc088c..1e3294a 100644 --- a/api/services/configuration/options/__init__.py +++ b/api/services/configuration/options/__init__.py @@ -1,3 +1,14 @@ +from .azure import ( + AZURE_EMBEDDING_MODELS, + AZURE_MODELS, + AZURE_REALTIME_API_VERSIONS, + AZURE_REALTIME_MODELS, + AZURE_REALTIME_VOICES, + AZURE_SPEECH_REGIONS, + AZURE_SPEECH_STT_LANGUAGES, + AZURE_SPEECH_TTS_LANGUAGES, + AZURE_SPEECH_TTS_VOICES, +) from .deepgram import DEEPGRAM_LANGUAGES, DEEPGRAM_STT_MODELS from .gladia import GLADIA_STT_LANGUAGES, GLADIA_STT_MODELS from .google import ( @@ -27,6 +38,15 @@ from .sarvam import ( from .speechmatics import SPEECHMATICS_STT_LANGUAGES __all__ = [ + "AZURE_EMBEDDING_MODELS", + "AZURE_MODELS", + "AZURE_REALTIME_API_VERSIONS", + "AZURE_REALTIME_MODELS", + "AZURE_REALTIME_VOICES", + "AZURE_SPEECH_REGIONS", + "AZURE_SPEECH_STT_LANGUAGES", + "AZURE_SPEECH_TTS_LANGUAGES", + "AZURE_SPEECH_TTS_VOICES", "DEEPGRAM_LANGUAGES", "DEEPGRAM_STT_MODELS", "GLADIA_STT_LANGUAGES", diff --git a/api/services/configuration/options/azure.py b/api/services/configuration/options/azure.py new file mode 100644 index 0000000..d80282b --- /dev/null +++ b/api/services/configuration/options/azure.py @@ -0,0 +1,125 @@ +AZURE_MODELS = ["gpt-4.1-mini"] + +AZURE_REALTIME_MODELS = ["gpt-4o-realtime-preview"] +AZURE_REALTIME_VOICES = [ + "alloy", + "ash", + "ballad", + "coral", + "echo", + "sage", + "shimmer", + "verse", +] +AZURE_REALTIME_API_VERSIONS = [ + "2025-04-01-preview", + "2024-10-01-preview", + "2024-12-17", +] + +AZURE_SPEECH_REGIONS = [ + "eastus", + "eastus2", + "westus", + "westus2", + "westus3", + "centralus", + "northcentralus", + "southcentralus", + "westcentralus", + "westeurope", + "northeurope", + "uksouth", + "ukwest", + "francecentral", + "switzerlandnorth", + "germanywestcentral", + "norwayeast", + "australiaeast", + "eastasia", + "southeastasia", + "japaneast", + "japanwest", + "koreacentral", + "centralindia", + "southindia", + "brazilsouth", +] + +AZURE_SPEECH_TTS_LANGUAGES = [ + "en-US", + "en-GB", + "en-AU", + "en-CA", + "en-IN", + "es-ES", + "es-MX", + "fr-FR", + "fr-CA", + "de-DE", + "it-IT", + "ja-JP", + "ko-KR", + "zh-CN", + "zh-HK", + "zh-TW", + "pt-BR", + "pt-PT", + "ru-RU", + "ar-SA", + "nl-NL", + "pl-PL", + "sv-SE", + "hi-IN", +] + +AZURE_SPEECH_TTS_VOICES = [ + "en-US-AriaNeural", + "en-US-GuyNeural", + "en-US-JennyNeural", + "en-US-DavisNeural", + "en-US-AmberNeural", + "en-US-AnaNeural", + "en-US-AshleyNeural", + "en-US-BrandonNeural", + "en-US-ChristopherNeural", + "en-US-ElizabethNeural", + "en-US-EricNeural", + "en-US-JacobNeural", + "en-US-MichelleNeural", + "en-US-MonicaNeural", + "en-US-NancyNeural", + "en-US-RogerNeural", + "en-US-SaraNeural", + "en-US-SteffanNeural", + "en-US-TonyNeural", +] + +AZURE_SPEECH_STT_LANGUAGES = [ + "en-US", + "en-GB", + "en-AU", + "en-CA", + "en-IN", + "es-ES", + "es-MX", + "fr-FR", + "fr-CA", + "de-DE", + "it-IT", + "ja-JP", + "ko-KR", + "zh-CN", + "pt-BR", + "pt-PT", + "ru-RU", + "ar-SA", + "nl-NL", + "pl-PL", + "hi-IN", +] + +AZURE_EMBEDDING_MODELS = [ + "text-embedding-3-small", + "text-embedding-ada-002", +] diff --git a/api/services/configuration/registry.py b/api/services/configuration/registry.py index a497e3f..f05c5f7 100644 --- a/api/services/configuration/registry.py +++ b/api/services/configuration/registry.py @@ -5,6 +5,15 @@ from typing import Annotated, Dict, Literal, Type, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator from api.services.configuration.options import ( + AZURE_EMBEDDING_MODELS, + AZURE_MODELS, + AZURE_REALTIME_API_VERSIONS, + AZURE_REALTIME_MODELS, + AZURE_REALTIME_VOICES, + AZURE_SPEECH_REGIONS, + AZURE_SPEECH_STT_LANGUAGES, + AZURE_SPEECH_TTS_LANGUAGES, + AZURE_SPEECH_TTS_VOICES, DEEPGRAM_LANGUAGES, DEEPGRAM_STT_MODELS, GLADIA_STT_LANGUAGES, @@ -52,6 +61,7 @@ class ServiceProviders(str, Enum): ELEVENLABS = "elevenlabs" GOOGLE = "google" AZURE = "azure" + AZURE_SPEECH = "azure_speech" DOGRAH = "dograh" SARVAM = "sarvam" SPEECHMATICS = "speechmatics" @@ -68,6 +78,7 @@ class ServiceProviders(str, Enum): ULTRAVOX_REALTIME = "ultravox_realtime" GOOGLE_REALTIME = "google_realtime" GOOGLE_VERTEX_REALTIME = "google_vertex_realtime" + AZURE_REALTIME = "azure_realtime" class BaseServiceConfiguration(BaseModel): @@ -79,6 +90,7 @@ class BaseServiceConfiguration(BaseModel): ServiceProviders.ELEVENLABS, ServiceProviders.GOOGLE, ServiceProviders.AZURE, + ServiceProviders.AZURE_SPEECH, ServiceProviders.DOGRAH, ServiceProviders.AWS_BEDROCK, ServiceProviders.SPEACHES, @@ -92,6 +104,7 @@ class BaseServiceConfiguration(BaseModel): ServiceProviders.ULTRAVOX_REALTIME, ServiceProviders.GOOGLE_REALTIME, ServiceProviders.GOOGLE_VERTEX_REALTIME, + ServiceProviders.AZURE_REALTIME, ServiceProviders.SARVAM, ] api_key: str | list[str] @@ -242,6 +255,16 @@ SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config( ), provider_docs_url="https://github.com/speaches-ai/speaches", ) +AZURE_SPEECH_PROVIDER_MODEL_CONFIG = provider_model_config( + "Azure Speech Services", + description="Azure Cognitive Services Speech — TTS and STT via the Azure Speech SDK.", + provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/speech-service/", +) +AZURE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config( + "Azure OpenAI Realtime", + description="Azure OpenAI Realtime API — low-latency speech-to-speech conversations.", + provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/realtime-audio-quickstart", +) OPENAI_MODELS = [ "gpt-4.1", @@ -272,7 +295,6 @@ OPENROUTER_MODELS = [ "meta-llama/llama-3.3-70b-instruct", "deepseek/deepseek-chat-v3-0324", ] -AZURE_MODELS = ["gpt-4.1-mini"] DOGRAH_LLM_MODELS = ["default", "accurate", "fast", "lite", "zen"] AWS_BEDROCK_MODELS = [ "us.amazon.nova-pro-v1:0", @@ -666,12 +688,45 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration): ) +@register_service(ServiceType.REALTIME) +class AzureRealtimeLLMConfiguration(BaseLLMConfiguration): + model_config = AZURE_REALTIME_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE_REALTIME] = ServiceProviders.AZURE_REALTIME + model: str = Field( + default="gpt-4o-realtime-preview", + description="Azure OpenAI realtime deployment name.", + json_schema_extra={ + "examples": AZURE_REALTIME_MODELS, + "allow_custom_input": True, + }, + ) + endpoint: str = Field( + description="Azure OpenAI resource endpoint (e.g. https://.openai.azure.com).", + ) + voice: str = Field( + default="alloy", + description="Voice the model speaks in.", + json_schema_extra={ + "examples": AZURE_REALTIME_VOICES, + "allow_custom_input": True, + }, + ) + api_version: str = Field( + default="2025-04-01-preview", + description="Azure OpenAI API version.", + json_schema_extra={ + "examples": AZURE_REALTIME_API_VERSIONS, + }, + ) + + REALTIME_PROVIDERS = { ServiceProviders.OPENAI_REALTIME.value, ServiceProviders.GROK_REALTIME.value, ServiceProviders.ULTRAVOX_REALTIME.value, ServiceProviders.GOOGLE_REALTIME.value, ServiceProviders.GOOGLE_VERTEX_REALTIME.value, + ServiceProviders.AZURE_REALTIME.value, } @@ -699,6 +754,7 @@ RealtimeConfig = Annotated[ UltravoxRealtimeLLMConfiguration, GoogleRealtimeLLMConfiguration, GoogleVertexRealtimeLLMConfiguration, + AzureRealtimeLLMConfiguration, ], Field(discriminator="provider"), ] @@ -1024,6 +1080,46 @@ class MiniMaxTTSConfiguration(BaseTTSConfiguration): ) +@register_tts +class AzureSpeechTTSConfiguration(BaseTTSConfiguration): + model_config = AZURE_SPEECH_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE_SPEECH] = ServiceProviders.AZURE_SPEECH + model: str = Field( + default="neural", + description="Azure Speech synthesis engine (neural voices only).", + json_schema_extra={"examples": ["neural"]}, + ) + region: str = Field( + default="eastus", + description="Azure region for Speech Services (e.g. 'eastus', 'westeurope').", + json_schema_extra={ + "examples": AZURE_SPEECH_REGIONS, + }, + ) + voice: str = Field( + default="en-US-AriaNeural", + description="Azure Neural voice name (e.g. 'en-US-AriaNeural').", + json_schema_extra={ + "examples": AZURE_SPEECH_TTS_VOICES, + "allow_custom_input": True, + }, + ) + language: str = Field( + default="en-US", + description="BCP-47 language code for synthesis.", + json_schema_extra={ + "examples": AZURE_SPEECH_TTS_LANGUAGES, + "allow_custom_input": True, + }, + ) + speed: float = Field( + default=1.0, + ge=0.5, + le=2.0, + description="Speech speed multiplier (0.5 to 2.0).", + ) + + TTSConfig = Annotated[ Union[ DeepgramTTSConfiguration, @@ -1037,6 +1133,7 @@ TTSConfig = Annotated[ RimeTTSConfiguration, SpeachesTTSConfiguration, MiniMaxTTSConfiguration, + AzureSpeechTTSConfiguration, ], Field(discriminator="provider"), ] @@ -1273,6 +1370,32 @@ class GladiaSTTConfiguration(BaseSTTConfiguration): ) +@register_stt +class AzureSpeechSTTConfiguration(BaseSTTConfiguration): + model_config = AZURE_SPEECH_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE_SPEECH] = ServiceProviders.AZURE_SPEECH + model: str = Field( + default="latest_long", + description="Azure Speech recognition model (use 'latest_long' for continuous recognition).", + json_schema_extra={"examples": ["latest_long", "latest_short"]}, + ) + region: str = Field( + default="eastus", + description="Azure region for Speech Services (e.g. 'eastus', 'westeurope').", + json_schema_extra={ + "examples": AZURE_SPEECH_REGIONS, + }, + ) + language: str = Field( + default="en-US", + description="BCP-47 language code for recognition.", + json_schema_extra={ + "examples": AZURE_SPEECH_STT_LANGUAGES, + "allow_custom_input": True, + }, + ) + + STTConfig = Annotated[ Union[ DeepgramSTTConfiguration, @@ -1285,6 +1408,7 @@ STTConfig = Annotated[ SpeachesSTTConfiguration, AssemblyAISTTConfiguration, GladiaSTTConfiguration, + AzureSpeechSTTConfiguration, ], Field(discriminator="provider"), ] @@ -1324,8 +1448,36 @@ class OpenRouterEmbeddingsConfiguration(BaseEmbeddingsConfiguration): ) +@register_embeddings +class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration): + model_config = AZURE_OPENAI_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE] = ServiceProviders.AZURE + model: str = Field( + default="text-embedding-3-small", + description=( + "Azure OpenAI embedding deployment name. The deployment must return " + "1536-dimensional embeddings." + ), + json_schema_extra={ + "examples": AZURE_EMBEDDING_MODELS, + "allow_custom_input": True, + }, + ) + endpoint: str = Field( + description="Azure OpenAI resource endpoint (e.g. https://.openai.azure.com).", + ) + api_version: str = Field( + default="2024-02-15-preview", + description="Azure OpenAI API version for embeddings.", + ) + + EmbeddingsConfig = Annotated[ - Union[OpenAIEmbeddingsConfiguration, OpenRouterEmbeddingsConfiguration], + Union[ + OpenAIEmbeddingsConfiguration, + OpenRouterEmbeddingsConfiguration, + AzureOpenAIEmbeddingsConfiguration, + ], Field(discriminator="provider"), ] diff --git a/api/services/gen_ai/__init__.py b/api/services/gen_ai/__init__.py index 4d5b8fe..ec9ba17 100644 --- a/api/services/gen_ai/__init__.py +++ b/api/services/gen_ai/__init__.py @@ -1,6 +1,8 @@ """Generative AI services for embeddings and document processing.""" from .embedding import ( + AzureEmbeddingAPIKeyNotConfiguredError, + AzureOpenAIEmbeddingService, BaseEmbeddingService, EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService, @@ -8,6 +10,8 @@ from .embedding import ( from .json_parser import parse_llm_json __all__ = [ + "AzureEmbeddingAPIKeyNotConfiguredError", + "AzureOpenAIEmbeddingService", "BaseEmbeddingService", "EmbeddingAPIKeyNotConfiguredError", "OpenAIEmbeddingService", diff --git a/api/services/gen_ai/embedding/__init__.py b/api/services/gen_ai/embedding/__init__.py index f6a4f18..40a04bd 100644 --- a/api/services/gen_ai/embedding/__init__.py +++ b/api/services/gen_ai/embedding/__init__.py @@ -1,9 +1,15 @@ """Embedding services for document processing and retrieval.""" +from .azure_openai_service import ( + AzureEmbeddingAPIKeyNotConfiguredError, + AzureOpenAIEmbeddingService, +) from .base import BaseEmbeddingService from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService __all__ = [ + "AzureEmbeddingAPIKeyNotConfiguredError", + "AzureOpenAIEmbeddingService", "BaseEmbeddingService", "EmbeddingAPIKeyNotConfiguredError", "OpenAIEmbeddingService", diff --git a/api/services/gen_ai/embedding/azure_openai_service.py b/api/services/gen_ai/embedding/azure_openai_service.py new file mode 100644 index 0000000..dddb785 --- /dev/null +++ b/api/services/gen_ai/embedding/azure_openai_service.py @@ -0,0 +1,131 @@ +"""Azure OpenAI embedding service. + +Uses the Azure OpenAI REST API for text embeddings, compatible with +1536-dimensional embedding deployments such as text-embedding-3-small and +text-embedding-ada-002. +""" + +from typing import Any, Dict, List, Optional + +from loguru import logger +from openai import AsyncAzureOpenAI + +from api.db.db_client import DBClient +from api.utils.url_security import validate_user_configured_service_url + +from .base import BaseEmbeddingService + +DEFAULT_MODEL_ID = "text-embedding-3-small" +EMBEDDING_DIMENSION = 1536 + + +class AzureEmbeddingAPIKeyNotConfiguredError(Exception): + """Raised when Azure OpenAI credentials are not configured for embeddings.""" + + def __init__(self): + super().__init__( + "Azure OpenAI endpoint or API key not configured. Please set your " + "endpoint and API key in Model Configurations > Embedding to use " + "document processing." + ) + + +class AzureOpenAIEmbeddingService(BaseEmbeddingService): + """Embedding service using Azure OpenAI text-embedding deployments.""" + + def __init__( + self, + db_client: DBClient, + api_key: Optional[str] = None, + endpoint: Optional[str] = None, + model_id: str = DEFAULT_MODEL_ID, + api_version: str = "2024-02-15-preview", + ): + """Initialize the Azure OpenAI embedding service. + + Args: + db_client: Database client for vector similarity search. + api_key: Azure OpenAI API key. + endpoint: Azure OpenAI resource endpoint (e.g. https://.openai.azure.com). + model_id: Deployment name, used as both the deployment and model identifier. + api_version: Azure OpenAI API version. + """ + self.db = db_client + self.model_id = model_id + + self._configured = bool(api_key and endpoint) + if self._configured: + validate_user_configured_service_url(endpoint, field_name="endpoint") + self.client = AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=endpoint, + api_version=api_version, + ) + logger.info( + f"Azure OpenAI embedding service initialized with deployment: {model_id}" + ) + else: + self.client = None + logger.warning( + "Azure OpenAI embedding service initialized without credentials. " + "Operations will fail until endpoint and API key are configured." + ) + + def get_model_id(self) -> str: + return self.model_id + + def get_embedding_dimension(self) -> int: + return EMBEDDING_DIMENSION + + def _ensure_configured(self): + if not self._configured or self.client is None: + raise AzureEmbeddingAPIKeyNotConfiguredError() + + async def embed_texts(self, texts: List[str]) -> List[List[float]]: + """Embed a batch of texts using Azure OpenAI API.""" + self._ensure_configured() + try: + response = await self.client.embeddings.create( + input=texts, + model=self.model_id, + ) + embeddings = [item.embedding for item in response.data] + self._validate_embedding_dimensions(embeddings) + return embeddings + except Exception as e: + logger.error(f"Error generating Azure OpenAI embeddings: {e}") + raise + + def _validate_embedding_dimensions(self, embeddings: List[List[float]]) -> None: + for embedding in embeddings: + if len(embedding) != EMBEDDING_DIMENSION: + raise ValueError( + "Azure OpenAI embedding deployment " + f"{self.model_id!r} returned {len(embedding)} dimensions; " + "Dograh knowledge base storage currently supports " + f"{EMBEDDING_DIMENSION}-dimensional embeddings." + ) + + async def embed_query(self, query: str) -> List[float]: + """Embed a single query text using Azure OpenAI API.""" + self._ensure_configured() + embeddings = await self.embed_texts([query]) + return embeddings[0] + + async def search_similar_chunks( + self, + query: str, + organization_id: int, + limit: int = 5, + document_uuids: Optional[List[str]] = None, + ) -> List[Dict[str, Any]]: + """Search for similar chunks using vector similarity.""" + self._ensure_configured() + query_embedding = await self.embed_query(query) + return await self.db.search_similar_chunks( + query_embedding=query_embedding, + organization_id=organization_id, + limit=limit, + document_uuids=document_uuids, + embedding_model=self.model_id, + ) diff --git a/api/services/pipecat/realtime/azure_realtime.py b/api/services/pipecat/realtime/azure_realtime.py new file mode 100644 index 0000000..0cf025b --- /dev/null +++ b/api/services/pipecat/realtime/azure_realtime.py @@ -0,0 +1,242 @@ +"""Dograh subclass of pipecat's Azure OpenAI Realtime LLM service. + +Layers Dograh engine integration quirks (mute gating, TTSSpeakFrame greeting +trigger, LLMMessagesAppendFrame handling, deferred tool calls) onto pipecat's +AzureRealtimeLLMService, mirroring what DograhOpenAIRealtimeLLMService does +for the standard OpenAI Realtime endpoint. +""" + +import json +from typing import Any + +from loguru import logger + +from pipecat.frames.frames import ( + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + Frame, + LLMFullResponseStartFrame, + LLMMessagesAppendFrame, + TranscriptionFrame, + TTSSpeakFrame, + UserMuteStartedFrame, + UserMuteStoppedFrame, +) +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.azure.realtime.llm import AzureRealtimeLLMService +from pipecat.services.llm_service import FunctionCallFromLLM +from pipecat.services.openai.realtime import events +from pipecat.transcriptions.language import Language +from pipecat.utils.time import time_now_iso8601 + + +class DograhAzureRealtimeLLMService(AzureRealtimeLLMService): + """Azure OpenAI Realtime with Dograh engine integration quirks. + + Extends AzureRealtimeLLMService with the same Dograh-specific behaviours + added to DograhOpenAIRealtimeLLMService: + - User-mute audio gating + - TTSSpeakFrame as initial-response trigger + - One-off LLMMessagesAppendFrame handling + - Deferred tool calls until bot finishes speaking + - finalized=True on TranscriptionFrame for consistency + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._user_is_muted: bool = False + self._handled_initial_context: bool = False + self._bot_is_speaking: bool = False + self._deferred_function_calls: list[FunctionCallFromLLM] = [] + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, UserMuteStartedFrame): + self._user_is_muted = True + await self.push_frame(frame, direction) + return + if isinstance(frame, UserMuteStoppedFrame): + self._user_is_muted = False + await self.push_frame(frame, direction) + return + if isinstance(frame, TTSSpeakFrame): + if not self._handled_initial_context: + await self._handle_context(self._context) + else: + logger.warning( + f"{self}: TTSSpeakFrame after initial context already handled — " + "Azure Realtime owns audio generation, ignoring" + ) + return + if isinstance(frame, LLMMessagesAppendFrame): + await self._handle_messages_append(frame) + return + if isinstance(frame, BotStartedSpeakingFrame): + self._bot_is_speaking = True + elif isinstance(frame, BotStoppedSpeakingFrame): + self._bot_is_speaking = False + await self._run_pending_function_calls() + await super().process_frame(frame, direction) + + async def _handle_messages_append(self, frame: LLMMessagesAppendFrame): + if self._disconnecting: + return + + if not self._api_session_ready: + if frame.run_llm: + logger.debug( + f"{self}: LLMMessagesAppendFrame received before session ready; " + "deferring response until the session is initialized" + ) + self._run_llm_when_api_session_ready = True + return + + appended_any = False + for message in frame.messages: + item = self._message_to_conversation_item(message) + if item is None: + continue + evt = events.ConversationItemCreateEvent(item=item) + self._messages_added_manually[evt.item.id] = True + await self.send_client_event(evt) + appended_any = True + + if frame.run_llm and appended_any: + await self._send_manual_response_create() + + async def _handle_context(self, context: LLMContext): + if not self._handled_initial_context: + if context is None: + logger.warning( + f"{self}: received initial context trigger before context was set" + ) + return + self._handled_initial_context = True + self._context = context + await self._create_response() + else: + self._context = context + await self._process_completed_function_calls(send_new_results=True) + + async def _send_user_audio(self, frame): + if self._user_is_muted: + return + await super()._send_user_audio(frame) + + def _message_to_conversation_item( + self, message: dict[str, Any] + ) -> events.ConversationItem | None: + if not isinstance(message, dict): + logger.warning( + f"{self}: skipping unsupported appended message payload {message!r}" + ) + return None + + role = message.get("role") + if role not in {"user", "system", "developer"}: + logger.warning( + f"{self}: skipping unsupported appended message role {role!r}" + ) + return None + + text = self._extract_text_content(message.get("content")) + if not text: + logger.warning( + f"{self}: skipping appended message with unsupported content {message!r}" + ) + return None + + item_role = "system" if role in {"system", "developer"} else "user" + return events.ConversationItem( + type="message", + role=item_role, + content=[events.ItemContent(type="input_text", text=text)], + ) + + @staticmethod + def _extract_text_content(content: Any) -> str | None: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if not isinstance(part, dict): + return None + if part.get("type") != "text": + return None + text = part.get("text") + if not isinstance(text, str): + return None + parts.append(text) + return "\n".join(parts) if parts else None + return None + + async def _send_manual_response_create(self): + await self.push_frame(LLMFullResponseStartFrame()) + await self.start_processing_metrics() + await self.start_ttfb_metrics() + await self.send_client_event( + events.ResponseCreateEvent( + response=events.ResponseProperties( + output_modalities=self._get_enabled_modalities() + ) + ) + ) + + async def _run_pending_function_calls(self): + if not self._deferred_function_calls: + return + function_calls = self._deferred_function_calls + self._deferred_function_calls = [] + logger.debug( + f"{self}: executing {len(function_calls)} deferred function call(s) " + "after bot turn ended" + ) + await self.run_function_calls(function_calls) + + async def _handle_evt_function_call_arguments_done(self, evt): + try: + args = json.loads(evt.arguments) + + function_call_item = self._pending_function_calls.get(evt.call_id) + if function_call_item: + del self._pending_function_calls[evt.call_id] + + function_calls = [ + FunctionCallFromLLM( + context=self._context, + tool_call_id=evt.call_id, + function_name=function_call_item.name, + arguments=args, + ) + ] + + if self._bot_is_speaking: + self._deferred_function_calls.extend(function_calls) + logger.debug( + f"{self}: deferring function call {function_call_item.name} " + "until bot stops speaking" + ) + else: + await self.run_function_calls(function_calls) + logger.debug(f"Processed function call: {function_call_item.name}") + else: + logger.warning( + f"No tracked function call found for call_id: {evt.call_id}" + ) + except Exception as e: + logger.error(f"Failed to process function call arguments: {e}") + + async def handle_evt_input_audio_transcription_completed(self, evt): + await self._call_event_handler( + "on_conversation_item_updated", evt.item_id, None + ) + await self.broadcast_frame( + TranscriptionFrame, + text=evt.transcript, + user_id="", + timestamp=time_now_iso8601(), + result=evt, + finalized=True, + ) + await self._handle_user_transcription(evt.transcript, True, Language.EN) diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index be71b86..7ce41d8 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -504,10 +504,16 @@ async def _run_pipeline( embeddings_api_key = None embeddings_model = None embeddings_base_url = None + embeddings_provider = None + embeddings_endpoint = None + embeddings_api_version = None if user_config and user_config.embeddings: embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model + embeddings_provider = getattr(user_config.embeddings, "provider", None) embeddings_base_url = getattr(user_config.embeddings, "base_url", None) + embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + embeddings_api_version = getattr(user_config.embeddings, "api_version", None) # Check if the workflow has any active recordings so the engine can # include recording response mode instructions in all node prompts. @@ -532,6 +538,9 @@ async def _run_pipeline( embeddings_api_key=embeddings_api_key, embeddings_model=embeddings_model, embeddings_base_url=embeddings_base_url, + embeddings_provider=embeddings_provider, + embeddings_endpoint=embeddings_endpoint, + embeddings_api_version=embeddings_api_version, has_recordings=has_recordings, context_compaction_enabled=context_compaction_enabled, ) diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index 217a593..8ed96e4 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING +from urllib.parse import urlencode, urlparse, urlunparse import aiohttp from fastapi import HTTPException @@ -11,6 +12,8 @@ from api.utils.url_security import validate_user_configured_service_url from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings +from pipecat.services.azure.stt import AzureSTTService, AzureSTTSettings +from pipecat.services.azure.tts import AzureTTSService, AzureTTSSettings from pipecat.services.cartesia.stt import CartesiaSTTService from pipecat.services.cartesia.tts import ( CartesiaTTSService, @@ -262,6 +265,21 @@ def create_stt_service( ), sample_rate=audio_config.transport_in_sample_rate, ) + elif user_config.stt.provider == ServiceProviders.AZURE_SPEECH.value: + from pipecat.transcriptions.language import Language as PipecatLanguage + + language_code = getattr(user_config.stt, "language", None) or "en-US" + region = getattr(user_config.stt, "region", None) or "eastus" + try: + pipecat_language = PipecatLanguage(language_code) + except ValueError: + pipecat_language = language_code + return AzureSTTService( + api_key=user_config.stt.api_key, + region=region, + settings=AzureSTTSettings(language=pipecat_language), + sample_rate=audio_config.transport_in_sample_rate, + ) else: raise HTTPException( status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}" @@ -514,6 +532,27 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): skip_aggregator_types=["recording_router", "recording"], silence_time_s=1.0, ) + elif user_config.tts.provider == ServiceProviders.AZURE_SPEECH.value: + region = getattr(user_config.tts, "region", None) or "eastus" + voice = getattr(user_config.tts, "voice", None) or "en-US-AriaNeural" + language = getattr(user_config.tts, "language", None) or "en-US" + speed = getattr(user_config.tts, "speed", None) or 1.0 + # Map speed multiplier (0.5–2.0) to Azure SSML rate string (e.g. "1.25") + rate = str(speed) if speed != 1.0 else None + settings_kwargs: dict = { + "voice": voice, + "language": language, + } + if rate: + settings_kwargs["rate"] = rate + return AzureTTSService( + api_key=user_config.tts.api_key, + region=region, + settings=AzureTTSSettings(**settings_kwargs), + text_filters=[xml_function_tag_filter], + skip_aggregator_types=["recording_router", "recording"], + silence_time_s=1.0, + ) else: raise HTTPException( status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}" @@ -754,6 +793,58 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"): location=location, settings=DograhGeminiLiveVertexLLMService.Settings(**settings_kwargs), ) + elif provider == ServiceProviders.AZURE_REALTIME.value: + from api.services.pipecat.realtime.azure_realtime import ( + DograhAzureRealtimeLLMService, + ) + from pipecat.services.openai.realtime.events import ( + AudioConfiguration, + AudioInput, + AudioOutput, + InputAudioTranscription, + SessionProperties, + ) + + endpoint = getattr(realtime_config, "endpoint", None) or "" + if not endpoint: + raise HTTPException( + status_code=400, + detail="Azure Realtime requires an endpoint.", + ) + _validate_runtime_service_url(endpoint, "endpoint") + api_version = ( + getattr(realtime_config, "api_version", None) or "2025-04-01-preview" + ) + # Construct the Azure Realtime WebSocket URL + # https://.openai.azure.com/openai/realtime?api-version=&deployment= + parsed_endpoint = urlparse(endpoint) + wss_url = urlunparse( + ( + "wss", + parsed_endpoint.netloc, + "/openai/realtime", + "", + urlencode({"api-version": api_version, "deployment": model}), + "", + ) + ) + return DograhAzureRealtimeLLMService( + api_key=api_key, + base_url=wss_url, + settings=DograhAzureRealtimeLLMService.Settings( + model=model, + session_properties=SessionProperties( + audio=AudioConfiguration( + input=AudioInput( + transcription=InputAudioTranscription(), + ), + output=AudioOutput( + voice=voice or "alloy", + ), + ), + ), + ), + ) else: raise HTTPException( status_code=400, detail=f"Invalid realtime LLM provider {provider}" diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index d72c3f4..cea1d21 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -73,6 +73,9 @@ class PipecatEngine: embeddings_api_key: Optional[str] = None, embeddings_model: Optional[str] = None, embeddings_base_url: Optional[str] = None, + embeddings_provider: Optional[str] = None, + embeddings_endpoint: Optional[str] = None, + embeddings_api_version: Optional[str] = None, has_recordings: bool = False, context_compaction_enabled: bool = False, ): @@ -126,6 +129,9 @@ class PipecatEngine: self._embeddings_api_key: Optional[str] = embeddings_api_key self._embeddings_model: Optional[str] = embeddings_model self._embeddings_base_url: Optional[str] = embeddings_base_url + self._embeddings_provider: Optional[str] = embeddings_provider + self._embeddings_endpoint: Optional[str] = embeddings_endpoint + self._embeddings_api_version: Optional[str] = embeddings_api_version # Audio configuration (set via set_audio_config from _run_pipeline) self._audio_config = None @@ -373,6 +379,9 @@ class PipecatEngine: embeddings_api_key=self._embeddings_api_key, embeddings_model=self._embeddings_model, embeddings_base_url=self._embeddings_base_url, + embeddings_provider=self._embeddings_provider, + embeddings_endpoint=self._embeddings_endpoint, + embeddings_api_version=self._embeddings_api_version, tracing_context=self._get_otel_context(), ) diff --git a/api/services/workflow/tools/knowledge_base.py b/api/services/workflow/tools/knowledge_base.py index 6821583..6ce8f8c 100644 --- a/api/services/workflow/tools/knowledge_base.py +++ b/api/services/workflow/tools/knowledge_base.py @@ -13,7 +13,8 @@ from loguru import logger from opentelemetry import trace from api.db import db_client -from api.services.gen_ai import OpenAIEmbeddingService +from api.services.configuration.registry import ServiceProviders +from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService from api.services.pipecat.tracing_config import ensure_tracing @@ -25,6 +26,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key: Optional[str] = None, embeddings_model: Optional[str] = None, embeddings_base_url: Optional[str] = None, + embeddings_provider: Optional[str] = None, + embeddings_endpoint: Optional[str] = None, + embeddings_api_version: Optional[str] = None, tracing_context=None, ) -> Dict[str, Any]: """Retrieve relevant information from the knowledge base using vector similarity search. @@ -68,6 +72,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key, embeddings_model, embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) # Create span with parent context @@ -105,6 +112,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key, embeddings_model, embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) # Add result metadata to span @@ -179,6 +189,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key, embeddings_model, embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) else: # Tracing is disabled - perform retrieval without tracing @@ -189,6 +202,10 @@ async def retrieve_from_knowledge_base( limit, embeddings_api_key, embeddings_model, + embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) @@ -200,6 +217,9 @@ async def _perform_retrieval( embeddings_api_key: Optional[str] = None, embeddings_model: Optional[str] = None, embeddings_base_url: Optional[str] = None, + embeddings_provider: Optional[str] = None, + embeddings_endpoint: Optional[str] = None, + embeddings_api_version: Optional[str] = None, ) -> Dict[str, Any]: """Internal function to perform the actual retrieval operation. @@ -240,12 +260,24 @@ async def _perform_retrieval( "Model Configurations > Embedding." ) - embedding_service = OpenAIEmbeddingService( - db_client=db_client, - api_key=embeddings_api_key, - model_id=embeddings_model or "text-embedding-3-small", - base_url=embeddings_base_url, - ) + if ( + embeddings_provider == ServiceProviders.AZURE.value + and embeddings_endpoint + ): + embedding_service = AzureOpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + endpoint=embeddings_endpoint, + model_id=embeddings_model or "text-embedding-3-small", + api_version=embeddings_api_version or "2024-02-15-preview", + ) + else: + embedding_service = OpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + model_id=embeddings_model or "text-embedding-3-small", + base_url=embeddings_base_url, + ) results = await embedding_service.search_similar_chunks( query=query, diff --git a/api/tasks/knowledge_base_processing.py b/api/tasks/knowledge_base_processing.py index 1c891e2..4e94329 100644 --- a/api/tasks/knowledge_base_processing.py +++ b/api/tasks/knowledge_base_processing.py @@ -12,7 +12,8 @@ from loguru import logger from api.db import db_client from api.db.models import KnowledgeBaseChunkModel -from api.services.gen_ai import OpenAIEmbeddingService +from api.services.configuration.registry import ServiceProviders +from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService from api.services.mps_service_key_client import mps_service_key_client from api.services.storage import storage_fs @@ -148,21 +149,32 @@ async def process_knowledge_base_document( ) return - # Chunked mode: fetch user embedding config, embed via OpenAI, persist chunks. + # Chunked mode: fetch user embedding config, embed, and persist chunks. + embeddings_provider = None embeddings_api_key = None embeddings_model = None embeddings_base_url = None + embeddings_endpoint = None + embeddings_api_version = None if document.created_by: user_config = await db_client.get_user_configurations(document.created_by) if user_config.embeddings: + embeddings_provider = getattr(user_config.embeddings, "provider", None) embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model embeddings_base_url = getattr(user_config.embeddings, "base_url", None) - logger.info(f"Using user embeddings config: model={embeddings_model}") + embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + embeddings_api_version = getattr( + user_config.embeddings, "api_version", None + ) + logger.info( + f"Using user embeddings config: provider={embeddings_provider}, " + f"model={embeddings_model}" + ) if not embeddings_api_key: error_message = ( - "OpenAI API key not configured. Please set your API key in " + "API key not configured. Please set your API key in " "Model Configurations > Embedding to process documents." ) logger.warning(f"Document {document_id}: {error_message}") @@ -171,12 +183,21 @@ async def process_knowledge_base_document( ) return - embedding_service = OpenAIEmbeddingService( - db_client=db_client, - api_key=embeddings_api_key, - model_id=embeddings_model or "text-embedding-3-small", - base_url=embeddings_base_url, - ) + if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint: + embedding_service = AzureOpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + endpoint=embeddings_endpoint, + model_id=embeddings_model or "text-embedding-3-small", + api_version=embeddings_api_version or "2024-02-15-preview", + ) + else: + embedding_service = OpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + model_id=embeddings_model or "text-embedding-3-small", + base_url=embeddings_base_url, + ) mps_chunks = mps_response.get("chunks", []) if not mps_chunks: diff --git a/api/tests/test_azure_speech_service_factory.py b/api/tests/test_azure_speech_service_factory.py new file mode 100644 index 0000000..26739a7 --- /dev/null +++ b/api/tests/test_azure_speech_service_factory.py @@ -0,0 +1,182 @@ +"""Tests for Azure Speech TTS/STT service factory dispatch.""" + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +from fastapi import HTTPException + +from api.services.configuration.check_validity import UserConfigurationValidator +from api.services.configuration.registry import ( + AzureRealtimeLLMConfiguration, + AzureSpeechSTTConfiguration, + AzureSpeechTTSConfiguration, + ServiceProviders, +) +from api.services.gen_ai.embedding.azure_openai_service import ( + AzureOpenAIEmbeddingService, +) +from api.services.pipecat.service_factory import ( + create_realtime_llm_service, + create_stt_service, + create_tts_service, +) + + +def _audio_config(): + return SimpleNamespace( + transport_out_sample_rate=24000, + transport_in_sample_rate=16000, + ) + + +def test_create_azure_speech_tts_service(): + user_config = SimpleNamespace( + tts=SimpleNamespace( + provider=ServiceProviders.AZURE_SPEECH.value, + api_key="test-subscription-key", + region="eastus", + voice="en-US-AriaNeural", + language="en-US", + speed=1.0, + model="neural", + ) + ) + + with patch("api.services.pipecat.service_factory.AzureTTSService") as mock_service: + create_tts_service(user_config, _audio_config()) + + assert mock_service.call_count == 1 + kwargs = mock_service.call_args.kwargs + assert kwargs["api_key"] == "test-subscription-key" + assert kwargs["region"] == "eastus" + assert kwargs["settings"].voice == "en-US-AriaNeural" + assert kwargs["settings"].language == "en-US" + + +def test_create_azure_speech_tts_service_with_speed(): + user_config = SimpleNamespace( + tts=SimpleNamespace( + provider=ServiceProviders.AZURE_SPEECH.value, + api_key="test-key", + region="westeurope", + voice="en-GB-SoniaNeural", + language="en-GB", + speed=1.5, + model="neural", + ) + ) + + with patch("api.services.pipecat.service_factory.AzureTTSService") as mock_service: + create_tts_service(user_config, _audio_config()) + + assert mock_service.call_count == 1 + kwargs = mock_service.call_args.kwargs + assert kwargs["region"] == "westeurope" + assert kwargs["settings"].rate == "1.5" + + +def test_create_azure_speech_stt_service(): + user_config = SimpleNamespace( + stt=SimpleNamespace( + provider=ServiceProviders.AZURE_SPEECH.value, + api_key="test-subscription-key", + region="eastus", + language="en-US", + model="latest_long", + ) + ) + + with patch("api.services.pipecat.service_factory.AzureSTTService") as mock_service: + create_stt_service(user_config, _audio_config()) + + assert mock_service.call_count == 1 + kwargs = mock_service.call_args.kwargs + assert kwargs["api_key"] == "test-subscription-key" + assert kwargs["region"] == "eastus" + assert kwargs["sample_rate"] == 16000 + + +def test_create_azure_speech_stt_service_preserves_custom_language(): + user_config = SimpleNamespace( + stt=SimpleNamespace( + provider=ServiceProviders.AZURE_SPEECH.value, + api_key="test-subscription-key", + region="eastus", + language="custom-locale", + model="latest_long", + ) + ) + + with patch("api.services.pipecat.service_factory.AzureSTTService") as mock_service: + create_stt_service(user_config, _audio_config()) + + kwargs = mock_service.call_args.kwargs + assert kwargs["settings"].language == "custom-locale" + + +def test_validator_accepts_azure_speech_services(): + validator = UserConfigurationValidator() + + assert ( + validator._validate_service( + AzureSpeechTTSConfiguration(api_key="test-key"), + "tts", + ) + == [] + ) + assert ( + validator._validate_service( + AzureSpeechSTTConfiguration(api_key="test-key"), + "stt", + ) + == [] + ) + + +def test_validator_accepts_azure_realtime_service(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss") + validator = UserConfigurationValidator() + + assert ( + validator._validate_service( + AzureRealtimeLLMConfiguration( + api_key="test-key", + endpoint="https://example.openai.azure.com", + ), + "realtime", + ) + == [] + ) + + +def test_create_azure_realtime_blocks_private_endpoint_in_saas(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + user_config = SimpleNamespace( + realtime=SimpleNamespace( + provider=ServiceProviders.AZURE_REALTIME.value, + api_key="test-key", + endpoint="http://10.0.0.10", + api_version="2025-04-01-preview", + model="gpt-4o-realtime-preview", + voice="alloy", + ) + ) + + with pytest.raises(HTTPException) as exc_info: + create_realtime_llm_service(user_config, _audio_config()) + + assert exc_info.value.status_code == 400 + assert "public IP" in exc_info.value.detail + + +def test_azure_embedding_service_rejects_wrong_dimension(): + service = AzureOpenAIEmbeddingService( + db_client=SimpleNamespace(), + api_key=None, + endpoint=None, + model_id="text-embedding-3-large", + ) + + with pytest.raises(ValueError, match="1536-dimensional"): + service._validate_embedding_dimensions([[0.0] * 3072])