diff --git a/api/routes/knowledge_base.py b/api/routes/knowledge_base.py index 95f64b8..bd47c4c 100644 --- a/api/routes/knowledge_base.py +++ b/api/routes/knowledge_base.py @@ -369,26 +369,42 @@ async def search_chunks( try: # Import here to avoid circular dependency - from api.services.gen_ai import OpenAIEmbeddingService + from api.services.configuration.registry import ServiceProviders + from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService # Try to get user's embeddings configuration user_config = await db_client.get_user_configurations(user.id) embeddings_api_key = None embeddings_model = None + embeddings_provider = None + embeddings_endpoint = None + embeddings_api_version = None if user_config.embeddings: embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model + embeddings_provider = getattr(user_config.embeddings, "provider", None) + embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + embeddings_api_version = getattr(user_config.embeddings, "api_version", None) - # Initialize embedding service with user config or fallback to env - embedding_service = OpenAIEmbeddingService( - db_client=db_client, - api_key=embeddings_api_key, - model_id=embeddings_model or "text-embedding-3-small", - base_url=getattr(user_config.embeddings, "base_url", None) - if user_config.embeddings - else None, - ) + # Initialize embedding service based on provider + if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint: + embedding_service = AzureOpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + endpoint=embeddings_endpoint, + model_id=embeddings_model or "text-embedding-3-small", + api_version=embeddings_api_version or "2024-02-15-preview", + ) + else: + embedding_service = OpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + model_id=embeddings_model or "text-embedding-3-small", + base_url=getattr(user_config.embeddings, "base_url", None) + if user_config.embeddings + else None, + ) # Perform search results = await embedding_service.search_similar_chunks( diff --git a/api/services/configuration/registry.py b/api/services/configuration/registry.py index 498a6fc..c652d9c 100644 --- a/api/services/configuration/registry.py +++ b/api/services/configuration/registry.py @@ -49,6 +49,7 @@ class ServiceProviders(str, Enum): ELEVENLABS = "elevenlabs" GOOGLE = "google" AZURE = "azure" + AZURE_SPEECH = "azure_speech" DOGRAH = "dograh" SARVAM = "sarvam" SPEECHMATICS = "speechmatics" @@ -65,6 +66,7 @@ class ServiceProviders(str, Enum): ULTRAVOX_REALTIME = "ultravox_realtime" GOOGLE_REALTIME = "google_realtime" GOOGLE_VERTEX_REALTIME = "google_vertex_realtime" + AZURE_REALTIME = "azure_realtime" class BaseServiceConfiguration(BaseModel): @@ -76,6 +78,7 @@ class BaseServiceConfiguration(BaseModel): ServiceProviders.ELEVENLABS, ServiceProviders.GOOGLE, ServiceProviders.AZURE, + ServiceProviders.AZURE_SPEECH, ServiceProviders.DOGRAH, ServiceProviders.AWS_BEDROCK, ServiceProviders.SPEACHES, @@ -89,6 +92,7 @@ class BaseServiceConfiguration(BaseModel): ServiceProviders.ULTRAVOX_REALTIME, ServiceProviders.GOOGLE_REALTIME, ServiceProviders.GOOGLE_VERTEX_REALTIME, + ServiceProviders.AZURE_REALTIME, # ServiceProviders.SARVAM, ] api_key: str | list[str] @@ -239,6 +243,16 @@ SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config( ), provider_docs_url="https://github.com/speaches-ai/speaches", ) +AZURE_SPEECH_PROVIDER_MODEL_CONFIG = provider_model_config( + "Azure Speech Services", + description="Azure Cognitive Services Speech — TTS and STT via the Azure Speech SDK.", + provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/speech-service/", +) +AZURE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config( + "Azure OpenAI Realtime", + description="Azure OpenAI Realtime API — low-latency speech-to-speech conversations.", + provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/realtime-audio-quickstart", +) OPENAI_MODELS = [ "gpt-4.1", @@ -640,12 +654,63 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration): ) +AZURE_REALTIME_MODELS = ["gpt-4o-realtime-preview"] +AZURE_REALTIME_VOICES = [ + "alloy", + "ash", + "ballad", + "coral", + "echo", + "sage", + "shimmer", + "verse", +] +AZURE_REALTIME_API_VERSIONS = [ + "2025-04-01-preview", + "2024-10-01-preview", + "2024-12-17", +] + + +@register_service(ServiceType.REALTIME) +class AzureRealtimeLLMConfiguration(BaseLLMConfiguration): + model_config = AZURE_REALTIME_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE_REALTIME] = ServiceProviders.AZURE_REALTIME + model: str = Field( + default="gpt-4o-realtime-preview", + description="Azure OpenAI realtime deployment name.", + json_schema_extra={ + "examples": AZURE_REALTIME_MODELS, + "allow_custom_input": True, + }, + ) + endpoint: str = Field( + description="Azure OpenAI resource endpoint (e.g. https://.openai.azure.com).", + ) + voice: str = Field( + default="alloy", + description="Voice the model speaks in.", + json_schema_extra={ + "examples": AZURE_REALTIME_VOICES, + "allow_custom_input": True, + }, + ) + api_version: str = Field( + default="2025-04-01-preview", + description="Azure OpenAI API version.", + json_schema_extra={ + "examples": AZURE_REALTIME_API_VERSIONS, + }, + ) + + REALTIME_PROVIDERS = { ServiceProviders.OPENAI_REALTIME.value, ServiceProviders.GROK_REALTIME.value, ServiceProviders.ULTRAVOX_REALTIME.value, ServiceProviders.GOOGLE_REALTIME.value, ServiceProviders.GOOGLE_VERTEX_REALTIME.value, + ServiceProviders.AZURE_REALTIME.value, } @@ -672,6 +737,7 @@ RealtimeConfig = Annotated[ UltravoxRealtimeLLMConfiguration, GoogleRealtimeLLMConfiguration, GoogleVertexRealtimeLLMConfiguration, + AzureRealtimeLLMConfiguration, ], Field(discriminator="provider"), ] @@ -993,6 +1059,116 @@ class MiniMaxTTSConfiguration(BaseTTSConfiguration): ) +AZURE_SPEECH_REGIONS = [ + "eastus", + "eastus2", + "westus", + "westus2", + "westus3", + "centralus", + "northcentralus", + "southcentralus", + "westcentralus", + "westeurope", + "northeurope", + "uksouth", + "ukwest", + "francecentral", + "switzerlandnorth", + "germanywestcentral", + "norwayeast", + "australiaeast", + "eastasia", + "southeastasia", + "japaneast", + "japanwest", + "koreacentral", + "centralindia", + "southindia", + "brazilsouth", +] + +AZURE_SPEECH_TTS_LANGUAGES = [ + "en-US", "en-GB", "en-AU", "en-CA", "en-IN", + "es-ES", "es-MX", + "fr-FR", "fr-CA", + "de-DE", + "it-IT", + "ja-JP", + "ko-KR", + "zh-CN", "zh-HK", "zh-TW", + "pt-BR", "pt-PT", + "ru-RU", + "ar-SA", + "nl-NL", + "pl-PL", + "sv-SE", + "hi-IN", +] + +AZURE_SPEECH_TTS_VOICES = [ + "en-US-AriaNeural", + "en-US-GuyNeural", + "en-US-JennyNeural", + "en-US-DavisNeural", + "en-US-AmberNeural", + "en-US-AnaNeural", + "en-US-AshleyNeural", + "en-US-BrandonNeural", + "en-US-ChristopherNeural", + "en-US-ElizabethNeural", + "en-US-EricNeural", + "en-US-JacobNeural", + "en-US-MichelleNeural", + "en-US-MonicaNeural", + "en-US-NancyNeural", + "en-US-RogerNeural", + "en-US-SaraNeural", + "en-US-SteffanNeural", + "en-US-TonyNeural", +] + + +@register_tts +class AzureSpeechTTSConfiguration(BaseTTSConfiguration): + model_config = AZURE_SPEECH_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE_SPEECH] = ServiceProviders.AZURE_SPEECH + model: str = Field( + default="neural", + description="Azure Speech synthesis engine (neural voices only).", + json_schema_extra={"examples": ["neural"]}, + ) + region: str = Field( + default="eastus", + description="Azure region for Speech Services (e.g. 'eastus', 'westeurope').", + json_schema_extra={ + "examples": AZURE_SPEECH_REGIONS, + }, + ) + voice: str = Field( + default="en-US-AriaNeural", + description="Azure Neural voice name (e.g. 'en-US-AriaNeural').", + json_schema_extra={ + "examples": AZURE_SPEECH_TTS_VOICES, + "allow_custom_input": True, + }, + ) + language: str = Field( + default="en-US", + description="BCP-47 language code for synthesis.", + json_schema_extra={ + "examples": AZURE_SPEECH_TTS_LANGUAGES, + "allow_custom_input": True, + }, + ) + speed: float = Field( + default=1.0, + ge=0.5, + le=2.0, + description="Speech speed multiplier (0.5 to 2.0).", + ) + + TTSConfig = Annotated[ Union[ DeepgramTTSConfiguration, @@ -1006,6 +1182,7 @@ TTSConfig = Annotated[ RimeTTSConfiguration, SpeachesTTSConfiguration, MiniMaxTTSConfiguration, + AzureSpeechTTSConfiguration, ], Field(discriminator="provider"), ] @@ -1227,6 +1404,50 @@ class GladiaSTTConfiguration(BaseSTTConfiguration): ) +AZURE_SPEECH_STT_LANGUAGES = [ + "en-US", "en-GB", "en-AU", "en-CA", "en-IN", + "es-ES", "es-MX", + "fr-FR", "fr-CA", + "de-DE", + "it-IT", + "ja-JP", + "ko-KR", + "zh-CN", + "pt-BR", "pt-PT", + "ru-RU", + "ar-SA", + "nl-NL", + "pl-PL", + "hi-IN", +] + + +@register_stt +class AzureSpeechSTTConfiguration(BaseSTTConfiguration): + model_config = AZURE_SPEECH_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE_SPEECH] = ServiceProviders.AZURE_SPEECH + model: str = Field( + default="latest_long", + description="Azure Speech recognition model (use 'latest_long' for continuous recognition).", + json_schema_extra={"examples": ["latest_long", "latest_short"]}, + ) + region: str = Field( + default="eastus", + description="Azure region for Speech Services (e.g. 'eastus', 'westeurope').", + json_schema_extra={ + "examples": AZURE_SPEECH_REGIONS, + }, + ) + language: str = Field( + default="en-US", + description="BCP-47 language code for recognition.", + json_schema_extra={ + "examples": AZURE_SPEECH_STT_LANGUAGES, + "allow_custom_input": True, + }, + ) + + STTConfig = Annotated[ Union[ DeepgramSTTConfiguration, @@ -1239,6 +1460,7 @@ STTConfig = Annotated[ SpeachesSTTConfiguration, AssemblyAISTTConfiguration, GladiaSTTConfiguration, + AzureSpeechSTTConfiguration, ], Field(discriminator="provider"), ] @@ -1278,8 +1500,33 @@ class OpenRouterEmbeddingsConfiguration(BaseEmbeddingsConfiguration): ) +AZURE_EMBEDDING_MODELS = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"] + + +@register_embeddings +class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration): + model_config = AZURE_OPENAI_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.AZURE] = ServiceProviders.AZURE + model: str = Field( + default="text-embedding-3-small", + description="Azure OpenAI embedding deployment name (must match the deployed model).", + json_schema_extra={"examples": AZURE_EMBEDDING_MODELS, "allow_custom_input": True}, + ) + endpoint: str = Field( + description="Azure OpenAI resource endpoint (e.g. https://.openai.azure.com).", + ) + api_version: str = Field( + default="2024-02-15-preview", + description="Azure OpenAI API version for embeddings.", + ) + + EmbeddingsConfig = Annotated[ - Union[OpenAIEmbeddingsConfiguration, OpenRouterEmbeddingsConfiguration], + Union[ + OpenAIEmbeddingsConfiguration, + OpenRouterEmbeddingsConfiguration, + AzureOpenAIEmbeddingsConfiguration, + ], Field(discriminator="provider"), ] diff --git a/api/services/gen_ai/__init__.py b/api/services/gen_ai/__init__.py index 4d5b8fe..ec9ba17 100644 --- a/api/services/gen_ai/__init__.py +++ b/api/services/gen_ai/__init__.py @@ -1,6 +1,8 @@ """Generative AI services for embeddings and document processing.""" from .embedding import ( + AzureEmbeddingAPIKeyNotConfiguredError, + AzureOpenAIEmbeddingService, BaseEmbeddingService, EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService, @@ -8,6 +10,8 @@ from .embedding import ( from .json_parser import parse_llm_json __all__ = [ + "AzureEmbeddingAPIKeyNotConfiguredError", + "AzureOpenAIEmbeddingService", "BaseEmbeddingService", "EmbeddingAPIKeyNotConfiguredError", "OpenAIEmbeddingService", diff --git a/api/services/gen_ai/embedding/__init__.py b/api/services/gen_ai/embedding/__init__.py index f6a4f18..0632ad1 100644 --- a/api/services/gen_ai/embedding/__init__.py +++ b/api/services/gen_ai/embedding/__init__.py @@ -1,9 +1,12 @@ """Embedding services for document processing and retrieval.""" +from .azure_openai_service import AzureEmbeddingAPIKeyNotConfiguredError, AzureOpenAIEmbeddingService from .base import BaseEmbeddingService from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService __all__ = [ + "AzureEmbeddingAPIKeyNotConfiguredError", + "AzureOpenAIEmbeddingService", "BaseEmbeddingService", "EmbeddingAPIKeyNotConfiguredError", "OpenAIEmbeddingService", diff --git a/api/services/gen_ai/embedding/azure_openai_service.py b/api/services/gen_ai/embedding/azure_openai_service.py new file mode 100644 index 0000000..ca2c759 --- /dev/null +++ b/api/services/gen_ai/embedding/azure_openai_service.py @@ -0,0 +1,119 @@ +"""Azure OpenAI embedding service. + +Uses the Azure OpenAI REST API for text embeddings, compatible with +text-embedding-3-small, text-embedding-3-large, and text-embedding-ada-002 +deployments. +""" + +from typing import Any, Dict, List, Optional + +from loguru import logger +from openai import AsyncAzureOpenAI + +from api.db.db_client import DBClient +from api.utils.url_security import validate_user_configured_service_url + +from .base import BaseEmbeddingService + +DEFAULT_MODEL_ID = "text-embedding-3-small" +EMBEDDING_DIMENSION = 1536 + + +class AzureEmbeddingAPIKeyNotConfiguredError(Exception): + """Raised when Azure OpenAI credentials are not configured for embeddings.""" + + def __init__(self): + super().__init__( + "Azure OpenAI endpoint or API key not configured. Please set your " + "endpoint and API key in Model Configurations > Embedding to use " + "document processing." + ) + + +class AzureOpenAIEmbeddingService(BaseEmbeddingService): + """Embedding service using Azure OpenAI text-embedding deployments.""" + + def __init__( + self, + db_client: DBClient, + api_key: Optional[str] = None, + endpoint: Optional[str] = None, + model_id: str = DEFAULT_MODEL_ID, + api_version: str = "2024-02-15-preview", + ): + """Initialize the Azure OpenAI embedding service. + + Args: + db_client: Database client for vector similarity search. + api_key: Azure OpenAI API key. + endpoint: Azure OpenAI resource endpoint (e.g. https://.openai.azure.com). + model_id: Deployment name, used as both the deployment and model identifier. + api_version: Azure OpenAI API version. + """ + self.db = db_client + self.model_id = model_id + + self._configured = bool(api_key and endpoint) + if self._configured: + validate_user_configured_service_url(endpoint, field_name="endpoint") + self.client = AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=endpoint, + api_version=api_version, + ) + logger.info( + f"Azure OpenAI embedding service initialized with deployment: {model_id}" + ) + else: + self.client = None + logger.warning( + "Azure OpenAI embedding service initialized without credentials. " + "Operations will fail until endpoint and API key are configured." + ) + + def get_model_id(self) -> str: + return self.model_id + + def get_embedding_dimension(self) -> int: + return EMBEDDING_DIMENSION + + def _ensure_configured(self): + if not self._configured or self.client is None: + raise AzureEmbeddingAPIKeyNotConfiguredError() + + async def embed_texts(self, texts: List[str]) -> List[List[float]]: + """Embed a batch of texts using Azure OpenAI API.""" + self._ensure_configured() + try: + response = await self.client.embeddings.create( + input=texts, + model=self.model_id, + ) + return [item.embedding for item in response.data] + except Exception as e: + logger.error(f"Error generating Azure OpenAI embeddings: {e}") + raise + + async def embed_query(self, query: str) -> List[float]: + """Embed a single query text using Azure OpenAI API.""" + self._ensure_configured() + embeddings = await self.embed_texts([query]) + return embeddings[0] + + async def search_similar_chunks( + self, + query: str, + organization_id: int, + limit: int = 5, + document_uuids: Optional[List[str]] = None, + ) -> List[Dict[str, Any]]: + """Search for similar chunks using vector similarity.""" + self._ensure_configured() + query_embedding = await self.embed_query(query) + return await self.db.search_similar_chunks( + query_embedding=query_embedding, + organization_id=organization_id, + limit=limit, + document_uuids=document_uuids, + embedding_model=self.model_id, + ) diff --git a/api/services/pipecat/realtime/azure_realtime.py b/api/services/pipecat/realtime/azure_realtime.py new file mode 100644 index 0000000..0cf025b --- /dev/null +++ b/api/services/pipecat/realtime/azure_realtime.py @@ -0,0 +1,242 @@ +"""Dograh subclass of pipecat's Azure OpenAI Realtime LLM service. + +Layers Dograh engine integration quirks (mute gating, TTSSpeakFrame greeting +trigger, LLMMessagesAppendFrame handling, deferred tool calls) onto pipecat's +AzureRealtimeLLMService, mirroring what DograhOpenAIRealtimeLLMService does +for the standard OpenAI Realtime endpoint. +""" + +import json +from typing import Any + +from loguru import logger + +from pipecat.frames.frames import ( + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + Frame, + LLMFullResponseStartFrame, + LLMMessagesAppendFrame, + TranscriptionFrame, + TTSSpeakFrame, + UserMuteStartedFrame, + UserMuteStoppedFrame, +) +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.azure.realtime.llm import AzureRealtimeLLMService +from pipecat.services.llm_service import FunctionCallFromLLM +from pipecat.services.openai.realtime import events +from pipecat.transcriptions.language import Language +from pipecat.utils.time import time_now_iso8601 + + +class DograhAzureRealtimeLLMService(AzureRealtimeLLMService): + """Azure OpenAI Realtime with Dograh engine integration quirks. + + Extends AzureRealtimeLLMService with the same Dograh-specific behaviours + added to DograhOpenAIRealtimeLLMService: + - User-mute audio gating + - TTSSpeakFrame as initial-response trigger + - One-off LLMMessagesAppendFrame handling + - Deferred tool calls until bot finishes speaking + - finalized=True on TranscriptionFrame for consistency + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._user_is_muted: bool = False + self._handled_initial_context: bool = False + self._bot_is_speaking: bool = False + self._deferred_function_calls: list[FunctionCallFromLLM] = [] + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, UserMuteStartedFrame): + self._user_is_muted = True + await self.push_frame(frame, direction) + return + if isinstance(frame, UserMuteStoppedFrame): + self._user_is_muted = False + await self.push_frame(frame, direction) + return + if isinstance(frame, TTSSpeakFrame): + if not self._handled_initial_context: + await self._handle_context(self._context) + else: + logger.warning( + f"{self}: TTSSpeakFrame after initial context already handled — " + "Azure Realtime owns audio generation, ignoring" + ) + return + if isinstance(frame, LLMMessagesAppendFrame): + await self._handle_messages_append(frame) + return + if isinstance(frame, BotStartedSpeakingFrame): + self._bot_is_speaking = True + elif isinstance(frame, BotStoppedSpeakingFrame): + self._bot_is_speaking = False + await self._run_pending_function_calls() + await super().process_frame(frame, direction) + + async def _handle_messages_append(self, frame: LLMMessagesAppendFrame): + if self._disconnecting: + return + + if not self._api_session_ready: + if frame.run_llm: + logger.debug( + f"{self}: LLMMessagesAppendFrame received before session ready; " + "deferring response until the session is initialized" + ) + self._run_llm_when_api_session_ready = True + return + + appended_any = False + for message in frame.messages: + item = self._message_to_conversation_item(message) + if item is None: + continue + evt = events.ConversationItemCreateEvent(item=item) + self._messages_added_manually[evt.item.id] = True + await self.send_client_event(evt) + appended_any = True + + if frame.run_llm and appended_any: + await self._send_manual_response_create() + + async def _handle_context(self, context: LLMContext): + if not self._handled_initial_context: + if context is None: + logger.warning( + f"{self}: received initial context trigger before context was set" + ) + return + self._handled_initial_context = True + self._context = context + await self._create_response() + else: + self._context = context + await self._process_completed_function_calls(send_new_results=True) + + async def _send_user_audio(self, frame): + if self._user_is_muted: + return + await super()._send_user_audio(frame) + + def _message_to_conversation_item( + self, message: dict[str, Any] + ) -> events.ConversationItem | None: + if not isinstance(message, dict): + logger.warning( + f"{self}: skipping unsupported appended message payload {message!r}" + ) + return None + + role = message.get("role") + if role not in {"user", "system", "developer"}: + logger.warning( + f"{self}: skipping unsupported appended message role {role!r}" + ) + return None + + text = self._extract_text_content(message.get("content")) + if not text: + logger.warning( + f"{self}: skipping appended message with unsupported content {message!r}" + ) + return None + + item_role = "system" if role in {"system", "developer"} else "user" + return events.ConversationItem( + type="message", + role=item_role, + content=[events.ItemContent(type="input_text", text=text)], + ) + + @staticmethod + def _extract_text_content(content: Any) -> str | None: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if not isinstance(part, dict): + return None + if part.get("type") != "text": + return None + text = part.get("text") + if not isinstance(text, str): + return None + parts.append(text) + return "\n".join(parts) if parts else None + return None + + async def _send_manual_response_create(self): + await self.push_frame(LLMFullResponseStartFrame()) + await self.start_processing_metrics() + await self.start_ttfb_metrics() + await self.send_client_event( + events.ResponseCreateEvent( + response=events.ResponseProperties( + output_modalities=self._get_enabled_modalities() + ) + ) + ) + + async def _run_pending_function_calls(self): + if not self._deferred_function_calls: + return + function_calls = self._deferred_function_calls + self._deferred_function_calls = [] + logger.debug( + f"{self}: executing {len(function_calls)} deferred function call(s) " + "after bot turn ended" + ) + await self.run_function_calls(function_calls) + + async def _handle_evt_function_call_arguments_done(self, evt): + try: + args = json.loads(evt.arguments) + + function_call_item = self._pending_function_calls.get(evt.call_id) + if function_call_item: + del self._pending_function_calls[evt.call_id] + + function_calls = [ + FunctionCallFromLLM( + context=self._context, + tool_call_id=evt.call_id, + function_name=function_call_item.name, + arguments=args, + ) + ] + + if self._bot_is_speaking: + self._deferred_function_calls.extend(function_calls) + logger.debug( + f"{self}: deferring function call {function_call_item.name} " + "until bot stops speaking" + ) + else: + await self.run_function_calls(function_calls) + logger.debug(f"Processed function call: {function_call_item.name}") + else: + logger.warning( + f"No tracked function call found for call_id: {evt.call_id}" + ) + except Exception as e: + logger.error(f"Failed to process function call arguments: {e}") + + async def handle_evt_input_audio_transcription_completed(self, evt): + await self._call_event_handler( + "on_conversation_item_updated", evt.item_id, None + ) + await self.broadcast_frame( + TranscriptionFrame, + text=evt.transcript, + user_id="", + timestamp=time_now_iso8601(), + result=evt, + finalized=True, + ) + await self._handle_user_transcription(evt.transcript, True, Language.EN) diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 6cae498..79cb2c9 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -504,10 +504,16 @@ async def _run_pipeline( embeddings_api_key = None embeddings_model = None embeddings_base_url = None + embeddings_provider = None + embeddings_endpoint = None + embeddings_api_version = None if user_config and user_config.embeddings: embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model + embeddings_provider = getattr(user_config.embeddings, "provider", None) embeddings_base_url = getattr(user_config.embeddings, "base_url", None) + embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + embeddings_api_version = getattr(user_config.embeddings, "api_version", None) # Check if the workflow has any active recordings so the engine can # include recording response mode instructions in all node prompts. @@ -532,6 +538,9 @@ async def _run_pipeline( embeddings_api_key=embeddings_api_key, embeddings_model=embeddings_model, embeddings_base_url=embeddings_base_url, + embeddings_provider=embeddings_provider, + embeddings_endpoint=embeddings_endpoint, + embeddings_api_version=embeddings_api_version, has_recordings=has_recordings, context_compaction_enabled=context_compaction_enabled, ) diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index 1c796e4..39b087e 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -11,6 +11,8 @@ from api.utils.url_security import validate_user_configured_service_url from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings +from pipecat.services.azure.stt import AzureSTTService, AzureSTTSettings +from pipecat.services.azure.tts import AzureTTSService, AzureTTSSettings from pipecat.services.cartesia.stt import CartesiaSTTService from pipecat.services.cartesia.tts import ( CartesiaTTSService, @@ -246,6 +248,22 @@ def create_stt_service( ), sample_rate=audio_config.transport_in_sample_rate, ) + elif user_config.stt.provider == ServiceProviders.AZURE_SPEECH.value: + from pipecat.transcriptions.language import Language as PipecatLanguage + + language_code = getattr(user_config.stt, "language", None) or "en-US" + region = getattr(user_config.stt, "region", None) or "eastus" + # Try to map BCP-47 string to pipecat Language enum; fall back to string + try: + pipecat_language = PipecatLanguage(language_code) + except ValueError: + pipecat_language = PipecatLanguage.EN_US + return AzureSTTService( + api_key=user_config.stt.api_key, + region=region, + settings=AzureSTTSettings(language=pipecat_language), + sample_rate=audio_config.transport_in_sample_rate, + ) else: raise HTTPException( status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}" @@ -492,6 +510,27 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): skip_aggregator_types=["recording_router", "recording"], silence_time_s=1.0, ) + elif user_config.tts.provider == ServiceProviders.AZURE_SPEECH.value: + region = getattr(user_config.tts, "region", None) or "eastus" + voice = getattr(user_config.tts, "voice", None) or "en-US-AriaNeural" + language = getattr(user_config.tts, "language", None) or "en-US" + speed = getattr(user_config.tts, "speed", None) or 1.0 + # Map speed multiplier (0.5–2.0) to Azure SSML rate string (e.g. "1.25") + rate = str(speed) if speed != 1.0 else None + settings_kwargs: dict = { + "voice": voice, + "language": language, + } + if rate: + settings_kwargs["rate"] = rate + return AzureTTSService( + api_key=user_config.tts.api_key, + region=region, + settings=AzureTTSSettings(**settings_kwargs), + text_filters=[xml_function_tag_filter], + skip_aggregator_types=["recording_router", "recording"], + silence_time_s=1.0, + ) else: raise HTTPException( status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}" @@ -724,6 +763,44 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"): location=location, settings=DograhGeminiLiveVertexLLMService.Settings(**settings_kwargs), ) + elif provider == ServiceProviders.AZURE_REALTIME.value: + from api.services.pipecat.realtime.azure_realtime import ( + DograhAzureRealtimeLLMService, + ) + from pipecat.services.openai.realtime.events import ( + AudioConfiguration, + AudioInput, + AudioOutput, + InputAudioTranscription, + SessionProperties, + ) + + endpoint = getattr(realtime_config, "endpoint", None) or "" + api_version = getattr(realtime_config, "api_version", None) or "2025-04-01-preview" + # Construct the Azure Realtime WebSocket URL + # https://.openai.azure.com/openai/realtime?api-version=&deployment= + base_host = endpoint.rstrip("/").replace("https://", "").replace("http://", "") + wss_url = ( + f"wss://{base_host}/openai/realtime" + f"?api-version={api_version}&deployment={model}" + ) + return DograhAzureRealtimeLLMService( + api_key=api_key, + base_url=wss_url, + settings=DograhAzureRealtimeLLMService.Settings( + model=model, + session_properties=SessionProperties( + audio=AudioConfiguration( + input=AudioInput( + transcription=InputAudioTranscription(), + ), + output=AudioOutput( + voice=voice or "alloy", + ), + ), + ), + ), + ) else: raise HTTPException( status_code=400, detail=f"Invalid realtime LLM provider {provider}" diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index f056725..a36ecb8 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -73,6 +73,9 @@ class PipecatEngine: embeddings_api_key: Optional[str] = None, embeddings_model: Optional[str] = None, embeddings_base_url: Optional[str] = None, + embeddings_provider: Optional[str] = None, + embeddings_endpoint: Optional[str] = None, + embeddings_api_version: Optional[str] = None, has_recordings: bool = False, context_compaction_enabled: bool = False, ): @@ -126,6 +129,9 @@ class PipecatEngine: self._embeddings_api_key: Optional[str] = embeddings_api_key self._embeddings_model: Optional[str] = embeddings_model self._embeddings_base_url: Optional[str] = embeddings_base_url + self._embeddings_provider: Optional[str] = embeddings_provider + self._embeddings_endpoint: Optional[str] = embeddings_endpoint + self._embeddings_api_version: Optional[str] = embeddings_api_version # Audio configuration (set via set_audio_config from _run_pipeline) self._audio_config = None @@ -373,6 +379,9 @@ class PipecatEngine: embeddings_api_key=self._embeddings_api_key, embeddings_model=self._embeddings_model, embeddings_base_url=self._embeddings_base_url, + embeddings_provider=self._embeddings_provider, + embeddings_endpoint=self._embeddings_endpoint, + embeddings_api_version=self._embeddings_api_version, tracing_context=self._get_otel_context(), ) diff --git a/api/services/workflow/tools/knowledge_base.py b/api/services/workflow/tools/knowledge_base.py index 6821583..b3fbda9 100644 --- a/api/services/workflow/tools/knowledge_base.py +++ b/api/services/workflow/tools/knowledge_base.py @@ -13,7 +13,8 @@ from loguru import logger from opentelemetry import trace from api.db import db_client -from api.services.gen_ai import OpenAIEmbeddingService +from api.services.configuration.registry import ServiceProviders +from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService from api.services.pipecat.tracing_config import ensure_tracing @@ -25,6 +26,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key: Optional[str] = None, embeddings_model: Optional[str] = None, embeddings_base_url: Optional[str] = None, + embeddings_provider: Optional[str] = None, + embeddings_endpoint: Optional[str] = None, + embeddings_api_version: Optional[str] = None, tracing_context=None, ) -> Dict[str, Any]: """Retrieve relevant information from the knowledge base using vector similarity search. @@ -68,6 +72,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key, embeddings_model, embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) # Create span with parent context @@ -105,6 +112,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key, embeddings_model, embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) # Add result metadata to span @@ -179,6 +189,9 @@ async def retrieve_from_knowledge_base( embeddings_api_key, embeddings_model, embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) else: # Tracing is disabled - perform retrieval without tracing @@ -189,6 +202,10 @@ async def retrieve_from_knowledge_base( limit, embeddings_api_key, embeddings_model, + embeddings_base_url, + embeddings_provider, + embeddings_endpoint, + embeddings_api_version, ) @@ -200,6 +217,9 @@ async def _perform_retrieval( embeddings_api_key: Optional[str] = None, embeddings_model: Optional[str] = None, embeddings_base_url: Optional[str] = None, + embeddings_provider: Optional[str] = None, + embeddings_endpoint: Optional[str] = None, + embeddings_api_version: Optional[str] = None, ) -> Dict[str, Any]: """Internal function to perform the actual retrieval operation. @@ -240,12 +260,21 @@ async def _perform_retrieval( "Model Configurations > Embedding." ) - embedding_service = OpenAIEmbeddingService( - db_client=db_client, - api_key=embeddings_api_key, - model_id=embeddings_model or "text-embedding-3-small", - base_url=embeddings_base_url, - ) + if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint: + embedding_service = AzureOpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + endpoint=embeddings_endpoint, + model_id=embeddings_model or "text-embedding-3-small", + api_version=embeddings_api_version or "2024-02-15-preview", + ) + else: + embedding_service = OpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + model_id=embeddings_model or "text-embedding-3-small", + base_url=embeddings_base_url, + ) results = await embedding_service.search_similar_chunks( query=query, diff --git a/api/tasks/knowledge_base_processing.py b/api/tasks/knowledge_base_processing.py index 1c891e2..2066f1d 100644 --- a/api/tasks/knowledge_base_processing.py +++ b/api/tasks/knowledge_base_processing.py @@ -12,7 +12,8 @@ from loguru import logger from api.db import db_client from api.db.models import KnowledgeBaseChunkModel -from api.services.gen_ai import OpenAIEmbeddingService +from api.services.configuration.registry import ServiceProviders +from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService from api.services.mps_service_key_client import mps_service_key_client from api.services.storage import storage_fs @@ -148,21 +149,30 @@ async def process_knowledge_base_document( ) return - # Chunked mode: fetch user embedding config, embed via OpenAI, persist chunks. + # Chunked mode: fetch user embedding config, embed, and persist chunks. + embeddings_provider = None embeddings_api_key = None embeddings_model = None embeddings_base_url = None + embeddings_endpoint = None + embeddings_api_version = None if document.created_by: user_config = await db_client.get_user_configurations(document.created_by) if user_config.embeddings: + embeddings_provider = getattr(user_config.embeddings, "provider", None) embeddings_api_key = user_config.embeddings.api_key embeddings_model = user_config.embeddings.model embeddings_base_url = getattr(user_config.embeddings, "base_url", None) - logger.info(f"Using user embeddings config: model={embeddings_model}") + embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + embeddings_api_version = getattr(user_config.embeddings, "api_version", None) + logger.info( + f"Using user embeddings config: provider={embeddings_provider}, " + f"model={embeddings_model}" + ) if not embeddings_api_key: error_message = ( - "OpenAI API key not configured. Please set your API key in " + "API key not configured. Please set your API key in " "Model Configurations > Embedding to process documents." ) logger.warning(f"Document {document_id}: {error_message}") @@ -171,12 +181,21 @@ async def process_knowledge_base_document( ) return - embedding_service = OpenAIEmbeddingService( - db_client=db_client, - api_key=embeddings_api_key, - model_id=embeddings_model or "text-embedding-3-small", - base_url=embeddings_base_url, - ) + if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint: + embedding_service = AzureOpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + endpoint=embeddings_endpoint, + model_id=embeddings_model or "text-embedding-3-small", + api_version=embeddings_api_version or "2024-02-15-preview", + ) + else: + embedding_service = OpenAIEmbeddingService( + db_client=db_client, + api_key=embeddings_api_key, + model_id=embeddings_model or "text-embedding-3-small", + base_url=embeddings_base_url, + ) mps_chunks = mps_response.get("chunks", []) if not mps_chunks: diff --git a/api/tests/test_azure_speech_service_factory.py b/api/tests/test_azure_speech_service_factory.py new file mode 100644 index 0000000..fe958e5 --- /dev/null +++ b/api/tests/test_azure_speech_service_factory.py @@ -0,0 +1,81 @@ +"""Tests for Azure Speech TTS/STT service factory dispatch.""" + +from types import SimpleNamespace +from unittest.mock import patch + +from api.services.configuration.registry import ServiceProviders +from api.services.pipecat.service_factory import create_stt_service, create_tts_service + + +def _audio_config(): + return SimpleNamespace( + transport_out_sample_rate=24000, + transport_in_sample_rate=16000, + ) + + +def test_create_azure_speech_tts_service(): + user_config = SimpleNamespace( + tts=SimpleNamespace( + provider=ServiceProviders.AZURE_SPEECH.value, + api_key="test-subscription-key", + region="eastus", + voice="en-US-AriaNeural", + language="en-US", + speed=1.0, + model="neural", + ) + ) + + with patch("api.services.pipecat.service_factory.AzureTTSService") as mock_service: + create_tts_service(user_config, _audio_config()) + + assert mock_service.call_count == 1 + kwargs = mock_service.call_args.kwargs + assert kwargs["api_key"] == "test-subscription-key" + assert kwargs["region"] == "eastus" + assert kwargs["settings"].voice == "en-US-AriaNeural" + assert kwargs["settings"].language == "en-US" + + +def test_create_azure_speech_tts_service_with_speed(): + user_config = SimpleNamespace( + tts=SimpleNamespace( + provider=ServiceProviders.AZURE_SPEECH.value, + api_key="test-key", + region="westeurope", + voice="en-GB-SoniaNeural", + language="en-GB", + speed=1.5, + model="neural", + ) + ) + + with patch("api.services.pipecat.service_factory.AzureTTSService") as mock_service: + create_tts_service(user_config, _audio_config()) + + assert mock_service.call_count == 1 + kwargs = mock_service.call_args.kwargs + assert kwargs["region"] == "westeurope" + assert kwargs["settings"].rate == "1.5" + + +def test_create_azure_speech_stt_service(): + user_config = SimpleNamespace( + stt=SimpleNamespace( + provider=ServiceProviders.AZURE_SPEECH.value, + api_key="test-subscription-key", + region="eastus", + language="en-US", + model="latest_long", + ) + ) + + with patch("api.services.pipecat.service_factory.AzureSTTService") as mock_service: + create_stt_service(user_config, _audio_config()) + + assert mock_service.call_count == 1 + kwargs = mock_service.call_args.kwargs + assert kwargs["api_key"] == "test-subscription-key" + assert kwargs["region"] == "eastus" + assert kwargs["sample_rate"] == 16000