mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add Azure AI multi-provider support (TTS, STT, Embeddings, Realtime)
Enables Azure AI services across all model layers so users with Azure credits can consolidate billing on a single provider. - Voice (TTS): AzureSpeechTTSConfiguration via azure_speech provider - Transcriber (STT): AzureSpeechSTTConfiguration via azure_speech provider - Embedding: AzureOpenAIEmbeddingsConfiguration via azure provider - Realtime: AzureRealtimeLLMConfiguration via azure_realtime provider New files: - api/services/pipecat/realtime/azure_realtime.py - api/services/gen_ai/embedding/azure_openai_service.py - api/tests/test_azure_speech_service_factory.py The UI picks up all four providers automatically from the schema — no frontend changes required. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e695436fb3
commit
dbbf362315
12 changed files with 883 additions and 28 deletions
|
|
@ -369,26 +369,42 @@ async def search_chunks(
|
|||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
|
||||
# Initialize embedding service with user config or fallback to env
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=getattr(user_config.embeddings, "base_url", None)
|
||||
if user_config.embeddings
|
||||
else None,
|
||||
)
|
||||
# Initialize embedding service based on provider
|
||||
if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint:
|
||||
embedding_service = AzureOpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
endpoint=embeddings_endpoint,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=getattr(user_config.embeddings, "base_url", None)
|
||||
if user_config.embeddings
|
||||
else None,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ class ServiceProviders(str, Enum):
|
|||
ELEVENLABS = "elevenlabs"
|
||||
GOOGLE = "google"
|
||||
AZURE = "azure"
|
||||
AZURE_SPEECH = "azure_speech"
|
||||
DOGRAH = "dograh"
|
||||
SARVAM = "sarvam"
|
||||
SPEECHMATICS = "speechmatics"
|
||||
|
|
@ -65,6 +66,7 @@ class ServiceProviders(str, Enum):
|
|||
ULTRAVOX_REALTIME = "ultravox_realtime"
|
||||
GOOGLE_REALTIME = "google_realtime"
|
||||
GOOGLE_VERTEX_REALTIME = "google_vertex_realtime"
|
||||
AZURE_REALTIME = "azure_realtime"
|
||||
|
||||
|
||||
class BaseServiceConfiguration(BaseModel):
|
||||
|
|
@ -76,6 +78,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.ELEVENLABS,
|
||||
ServiceProviders.GOOGLE,
|
||||
ServiceProviders.AZURE,
|
||||
ServiceProviders.AZURE_SPEECH,
|
||||
ServiceProviders.DOGRAH,
|
||||
ServiceProviders.AWS_BEDROCK,
|
||||
ServiceProviders.SPEACHES,
|
||||
|
|
@ -89,6 +92,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.ULTRAVOX_REALTIME,
|
||||
ServiceProviders.GOOGLE_REALTIME,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME,
|
||||
ServiceProviders.AZURE_REALTIME,
|
||||
# ServiceProviders.SARVAM,
|
||||
]
|
||||
api_key: str | list[str]
|
||||
|
|
@ -239,6 +243,16 @@ SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config(
|
|||
),
|
||||
provider_docs_url="https://github.com/speaches-ai/speaches",
|
||||
)
|
||||
AZURE_SPEECH_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Azure Speech Services",
|
||||
description="Azure Cognitive Services Speech — TTS and STT via the Azure Speech SDK.",
|
||||
provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/speech-service/",
|
||||
)
|
||||
AZURE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Azure OpenAI Realtime",
|
||||
description="Azure OpenAI Realtime API — low-latency speech-to-speech conversations.",
|
||||
provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/realtime-audio-quickstart",
|
||||
)
|
||||
|
||||
OPENAI_MODELS = [
|
||||
"gpt-4.1",
|
||||
|
|
@ -640,12 +654,63 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
AZURE_REALTIME_MODELS = ["gpt-4o-realtime-preview"]
|
||||
AZURE_REALTIME_VOICES = [
|
||||
"alloy",
|
||||
"ash",
|
||||
"ballad",
|
||||
"coral",
|
||||
"echo",
|
||||
"sage",
|
||||
"shimmer",
|
||||
"verse",
|
||||
]
|
||||
AZURE_REALTIME_API_VERSIONS = [
|
||||
"2025-04-01-preview",
|
||||
"2024-10-01-preview",
|
||||
"2024-12-17",
|
||||
]
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class AzureRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = AZURE_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.AZURE_REALTIME] = ServiceProviders.AZURE_REALTIME
|
||||
model: str = Field(
|
||||
default="gpt-4o-realtime-preview",
|
||||
description="Azure OpenAI realtime deployment name.",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_REALTIME_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
endpoint: str = Field(
|
||||
description="Azure OpenAI resource endpoint (e.g. https://<resource>.openai.azure.com).",
|
||||
)
|
||||
voice: str = Field(
|
||||
default="alloy",
|
||||
description="Voice the model speaks in.",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_REALTIME_VOICES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
api_version: str = Field(
|
||||
default="2025-04-01-preview",
|
||||
description="Azure OpenAI API version.",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_REALTIME_API_VERSIONS,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
REALTIME_PROVIDERS = {
|
||||
ServiceProviders.OPENAI_REALTIME.value,
|
||||
ServiceProviders.GROK_REALTIME.value,
|
||||
ServiceProviders.ULTRAVOX_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME.value,
|
||||
ServiceProviders.AZURE_REALTIME.value,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -672,6 +737,7 @@ RealtimeConfig = Annotated[
|
|||
UltravoxRealtimeLLMConfiguration,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
GoogleVertexRealtimeLLMConfiguration,
|
||||
AzureRealtimeLLMConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -993,6 +1059,116 @@ class MiniMaxTTSConfiguration(BaseTTSConfiguration):
|
|||
)
|
||||
|
||||
|
||||
AZURE_SPEECH_REGIONS = [
|
||||
"eastus",
|
||||
"eastus2",
|
||||
"westus",
|
||||
"westus2",
|
||||
"westus3",
|
||||
"centralus",
|
||||
"northcentralus",
|
||||
"southcentralus",
|
||||
"westcentralus",
|
||||
"westeurope",
|
||||
"northeurope",
|
||||
"uksouth",
|
||||
"ukwest",
|
||||
"francecentral",
|
||||
"switzerlandnorth",
|
||||
"germanywestcentral",
|
||||
"norwayeast",
|
||||
"australiaeast",
|
||||
"eastasia",
|
||||
"southeastasia",
|
||||
"japaneast",
|
||||
"japanwest",
|
||||
"koreacentral",
|
||||
"centralindia",
|
||||
"southindia",
|
||||
"brazilsouth",
|
||||
]
|
||||
|
||||
AZURE_SPEECH_TTS_LANGUAGES = [
|
||||
"en-US", "en-GB", "en-AU", "en-CA", "en-IN",
|
||||
"es-ES", "es-MX",
|
||||
"fr-FR", "fr-CA",
|
||||
"de-DE",
|
||||
"it-IT",
|
||||
"ja-JP",
|
||||
"ko-KR",
|
||||
"zh-CN", "zh-HK", "zh-TW",
|
||||
"pt-BR", "pt-PT",
|
||||
"ru-RU",
|
||||
"ar-SA",
|
||||
"nl-NL",
|
||||
"pl-PL",
|
||||
"sv-SE",
|
||||
"hi-IN",
|
||||
]
|
||||
|
||||
AZURE_SPEECH_TTS_VOICES = [
|
||||
"en-US-AriaNeural",
|
||||
"en-US-GuyNeural",
|
||||
"en-US-JennyNeural",
|
||||
"en-US-DavisNeural",
|
||||
"en-US-AmberNeural",
|
||||
"en-US-AnaNeural",
|
||||
"en-US-AshleyNeural",
|
||||
"en-US-BrandonNeural",
|
||||
"en-US-ChristopherNeural",
|
||||
"en-US-ElizabethNeural",
|
||||
"en-US-EricNeural",
|
||||
"en-US-JacobNeural",
|
||||
"en-US-MichelleNeural",
|
||||
"en-US-MonicaNeural",
|
||||
"en-US-NancyNeural",
|
||||
"en-US-RogerNeural",
|
||||
"en-US-SaraNeural",
|
||||
"en-US-SteffanNeural",
|
||||
"en-US-TonyNeural",
|
||||
]
|
||||
|
||||
|
||||
@register_tts
|
||||
class AzureSpeechTTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = AZURE_SPEECH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.AZURE_SPEECH] = ServiceProviders.AZURE_SPEECH
|
||||
model: str = Field(
|
||||
default="neural",
|
||||
description="Azure Speech synthesis engine (neural voices only).",
|
||||
json_schema_extra={"examples": ["neural"]},
|
||||
)
|
||||
region: str = Field(
|
||||
default="eastus",
|
||||
description="Azure region for Speech Services (e.g. 'eastus', 'westeurope').",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_SPEECH_REGIONS,
|
||||
},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="en-US-AriaNeural",
|
||||
description="Azure Neural voice name (e.g. 'en-US-AriaNeural').",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_SPEECH_TTS_VOICES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en-US",
|
||||
description="BCP-47 language code for synthesis.",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_SPEECH_TTS_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
speed: float = Field(
|
||||
default=1.0,
|
||||
ge=0.5,
|
||||
le=2.0,
|
||||
description="Speech speed multiplier (0.5 to 2.0).",
|
||||
)
|
||||
|
||||
|
||||
TTSConfig = Annotated[
|
||||
Union[
|
||||
DeepgramTTSConfiguration,
|
||||
|
|
@ -1006,6 +1182,7 @@ TTSConfig = Annotated[
|
|||
RimeTTSConfiguration,
|
||||
SpeachesTTSConfiguration,
|
||||
MiniMaxTTSConfiguration,
|
||||
AzureSpeechTTSConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -1227,6 +1404,50 @@ class GladiaSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
AZURE_SPEECH_STT_LANGUAGES = [
|
||||
"en-US", "en-GB", "en-AU", "en-CA", "en-IN",
|
||||
"es-ES", "es-MX",
|
||||
"fr-FR", "fr-CA",
|
||||
"de-DE",
|
||||
"it-IT",
|
||||
"ja-JP",
|
||||
"ko-KR",
|
||||
"zh-CN",
|
||||
"pt-BR", "pt-PT",
|
||||
"ru-RU",
|
||||
"ar-SA",
|
||||
"nl-NL",
|
||||
"pl-PL",
|
||||
"hi-IN",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class AzureSpeechSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = AZURE_SPEECH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.AZURE_SPEECH] = ServiceProviders.AZURE_SPEECH
|
||||
model: str = Field(
|
||||
default="latest_long",
|
||||
description="Azure Speech recognition model (use 'latest_long' for continuous recognition).",
|
||||
json_schema_extra={"examples": ["latest_long", "latest_short"]},
|
||||
)
|
||||
region: str = Field(
|
||||
default="eastus",
|
||||
description="Azure region for Speech Services (e.g. 'eastus', 'westeurope').",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_SPEECH_REGIONS,
|
||||
},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en-US",
|
||||
description="BCP-47 language code for recognition.",
|
||||
json_schema_extra={
|
||||
"examples": AZURE_SPEECH_STT_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
STTConfig = Annotated[
|
||||
Union[
|
||||
DeepgramSTTConfiguration,
|
||||
|
|
@ -1239,6 +1460,7 @@ STTConfig = Annotated[
|
|||
SpeachesSTTConfiguration,
|
||||
AssemblyAISTTConfiguration,
|
||||
GladiaSTTConfiguration,
|
||||
AzureSpeechSTTConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -1278,8 +1500,33 @@ class OpenRouterEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
AZURE_EMBEDDING_MODELS = ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = AZURE_OPENAI_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.AZURE] = ServiceProviders.AZURE
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
description="Azure OpenAI embedding deployment name (must match the deployed model).",
|
||||
json_schema_extra={"examples": AZURE_EMBEDDING_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
endpoint: str = Field(
|
||||
description="Azure OpenAI resource endpoint (e.g. https://<resource>.openai.azure.com).",
|
||||
)
|
||||
api_version: str = Field(
|
||||
default="2024-02-15-preview",
|
||||
description="Azure OpenAI API version for embeddings.",
|
||||
)
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[OpenAIEmbeddingsConfiguration, OpenRouterEmbeddingsConfiguration],
|
||||
Union[
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenRouterEmbeddingsConfiguration,
|
||||
AzureOpenAIEmbeddingsConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""Generative AI services for embeddings and document processing."""
|
||||
|
||||
from .embedding import (
|
||||
AzureEmbeddingAPIKeyNotConfiguredError,
|
||||
AzureOpenAIEmbeddingService,
|
||||
BaseEmbeddingService,
|
||||
EmbeddingAPIKeyNotConfiguredError,
|
||||
OpenAIEmbeddingService,
|
||||
|
|
@ -8,6 +10,8 @@ from .embedding import (
|
|||
from .json_parser import parse_llm_json
|
||||
|
||||
__all__ = [
|
||||
"AzureEmbeddingAPIKeyNotConfiguredError",
|
||||
"AzureOpenAIEmbeddingService",
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"OpenAIEmbeddingService",
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
"""Embedding services for document processing and retrieval."""
|
||||
|
||||
from .azure_openai_service import AzureEmbeddingAPIKeyNotConfiguredError, AzureOpenAIEmbeddingService
|
||||
from .base import BaseEmbeddingService
|
||||
from .openai_service import EmbeddingAPIKeyNotConfiguredError, OpenAIEmbeddingService
|
||||
|
||||
__all__ = [
|
||||
"AzureEmbeddingAPIKeyNotConfiguredError",
|
||||
"AzureOpenAIEmbeddingService",
|
||||
"BaseEmbeddingService",
|
||||
"EmbeddingAPIKeyNotConfiguredError",
|
||||
"OpenAIEmbeddingService",
|
||||
|
|
|
|||
119
api/services/gen_ai/embedding/azure_openai_service.py
Normal file
119
api/services/gen_ai/embedding/azure_openai_service.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
"""Azure OpenAI embedding service.
|
||||
|
||||
Uses the Azure OpenAI REST API for text embeddings, compatible with
|
||||
text-embedding-3-small, text-embedding-3-large, and text-embedding-ada-002
|
||||
deployments.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
DEFAULT_MODEL_ID = "text-embedding-3-small"
|
||||
EMBEDDING_DIMENSION = 1536
|
||||
|
||||
|
||||
class AzureEmbeddingAPIKeyNotConfiguredError(Exception):
|
||||
"""Raised when Azure OpenAI credentials are not configured for embeddings."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"Azure OpenAI endpoint or API key not configured. Please set your "
|
||||
"endpoint and API key in Model Configurations > Embedding to use "
|
||||
"document processing."
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIEmbeddingService(BaseEmbeddingService):
|
||||
"""Embedding service using Azure OpenAI text-embedding deployments."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_client: DBClient,
|
||||
api_key: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
api_version: str = "2024-02-15-preview",
|
||||
):
|
||||
"""Initialize the Azure OpenAI embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for vector similarity search.
|
||||
api_key: Azure OpenAI API key.
|
||||
endpoint: Azure OpenAI resource endpoint (e.g. https://<resource>.openai.azure.com).
|
||||
model_id: Deployment name, used as both the deployment and model identifier.
|
||||
api_version: Azure OpenAI API version.
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
|
||||
self._configured = bool(api_key and endpoint)
|
||||
if self._configured:
|
||||
validate_user_configured_service_url(endpoint, field_name="endpoint")
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=endpoint,
|
||||
api_version=api_version,
|
||||
)
|
||||
logger.info(
|
||||
f"Azure OpenAI embedding service initialized with deployment: {model_id}"
|
||||
)
|
||||
else:
|
||||
self.client = None
|
||||
logger.warning(
|
||||
"Azure OpenAI embedding service initialized without credentials. "
|
||||
"Operations will fail until endpoint and API key are configured."
|
||||
)
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
return self.model_id
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
return EMBEDDING_DIMENSION
|
||||
|
||||
def _ensure_configured(self):
|
||||
if not self._configured or self.client is None:
|
||||
raise AzureEmbeddingAPIKeyNotConfiguredError()
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts using Azure OpenAI API."""
|
||||
self._ensure_configured()
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model_id,
|
||||
)
|
||||
return [item.embedding for item in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Azure OpenAI embeddings: {e}")
|
||||
raise
|
||||
|
||||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text using Azure OpenAI API."""
|
||||
self._ensure_configured()
|
||||
embeddings = await self.embed_texts([query])
|
||||
return embeddings[0]
|
||||
|
||||
async def search_similar_chunks(
|
||||
self,
|
||||
query: str,
|
||||
organization_id: int,
|
||||
limit: int = 5,
|
||||
document_uuids: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity."""
|
||||
self._ensure_configured()
|
||||
query_embedding = await self.embed_query(query)
|
||||
return await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
242
api/services/pipecat/realtime/azure_realtime.py
Normal file
242
api/services/pipecat/realtime/azure_realtime.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Dograh subclass of pipecat's Azure OpenAI Realtime LLM service.
|
||||
|
||||
Layers Dograh engine integration quirks (mute gating, TTSSpeakFrame greeting
|
||||
trigger, LLMMessagesAppendFrame handling, deferred tool calls) onto pipecat's
|
||||
AzureRealtimeLLMService, mirroring what DograhOpenAIRealtimeLLMService does
|
||||
for the standard OpenAI Realtime endpoint.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
TranscriptionFrame,
|
||||
TTSSpeakFrame,
|
||||
UserMuteStartedFrame,
|
||||
UserMuteStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.azure.realtime.llm import AzureRealtimeLLMService
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM
|
||||
from pipecat.services.openai.realtime import events
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class DograhAzureRealtimeLLMService(AzureRealtimeLLMService):
|
||||
"""Azure OpenAI Realtime with Dograh engine integration quirks.
|
||||
|
||||
Extends AzureRealtimeLLMService with the same Dograh-specific behaviours
|
||||
added to DograhOpenAIRealtimeLLMService:
|
||||
- User-mute audio gating
|
||||
- TTSSpeakFrame as initial-response trigger
|
||||
- One-off LLMMessagesAppendFrame handling
|
||||
- Deferred tool calls until bot finishes speaking
|
||||
- finalized=True on TranscriptionFrame for consistency
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._user_is_muted: bool = False
|
||||
self._handled_initial_context: bool = False
|
||||
self._bot_is_speaking: bool = False
|
||||
self._deferred_function_calls: list[FunctionCallFromLLM] = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, UserMuteStartedFrame):
|
||||
self._user_is_muted = True
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, UserMuteStoppedFrame):
|
||||
self._user_is_muted = False
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
if not self._handled_initial_context:
|
||||
await self._handle_context(self._context)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self}: TTSSpeakFrame after initial context already handled — "
|
||||
"Azure Realtime owns audio generation, ignoring"
|
||||
)
|
||||
return
|
||||
if isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_messages_append(frame)
|
||||
return
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_is_speaking = True
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_is_speaking = False
|
||||
await self._run_pending_function_calls()
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def _handle_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
if self._disconnecting:
|
||||
return
|
||||
|
||||
if not self._api_session_ready:
|
||||
if frame.run_llm:
|
||||
logger.debug(
|
||||
f"{self}: LLMMessagesAppendFrame received before session ready; "
|
||||
"deferring response until the session is initialized"
|
||||
)
|
||||
self._run_llm_when_api_session_ready = True
|
||||
return
|
||||
|
||||
appended_any = False
|
||||
for message in frame.messages:
|
||||
item = self._message_to_conversation_item(message)
|
||||
if item is None:
|
||||
continue
|
||||
evt = events.ConversationItemCreateEvent(item=item)
|
||||
self._messages_added_manually[evt.item.id] = True
|
||||
await self.send_client_event(evt)
|
||||
appended_any = True
|
||||
|
||||
if frame.run_llm and appended_any:
|
||||
await self._send_manual_response_create()
|
||||
|
||||
async def _handle_context(self, context: LLMContext):
|
||||
if not self._handled_initial_context:
|
||||
if context is None:
|
||||
logger.warning(
|
||||
f"{self}: received initial context trigger before context was set"
|
||||
)
|
||||
return
|
||||
self._handled_initial_context = True
|
||||
self._context = context
|
||||
await self._create_response()
|
||||
else:
|
||||
self._context = context
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
if self._user_is_muted:
|
||||
return
|
||||
await super()._send_user_audio(frame)
|
||||
|
||||
def _message_to_conversation_item(
|
||||
self, message: dict[str, Any]
|
||||
) -> events.ConversationItem | None:
|
||||
if not isinstance(message, dict):
|
||||
logger.warning(
|
||||
f"{self}: skipping unsupported appended message payload {message!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
role = message.get("role")
|
||||
if role not in {"user", "system", "developer"}:
|
||||
logger.warning(
|
||||
f"{self}: skipping unsupported appended message role {role!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
text = self._extract_text_content(message.get("content"))
|
||||
if not text:
|
||||
logger.warning(
|
||||
f"{self}: skipping appended message with unsupported content {message!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
item_role = "system" if role in {"system", "developer"} else "user"
|
||||
return events.ConversationItem(
|
||||
type="message",
|
||||
role=item_role,
|
||||
content=[events.ItemContent(type="input_text", text=text)],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content: Any) -> str | None:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
return None
|
||||
if part.get("type") != "text":
|
||||
return None
|
||||
text = part.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
parts.append(text)
|
||||
return "\n".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
async def _send_manual_response_create(self):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
await self.send_client_event(
|
||||
events.ResponseCreateEvent(
|
||||
response=events.ResponseProperties(
|
||||
output_modalities=self._get_enabled_modalities()
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _run_pending_function_calls(self):
|
||||
if not self._deferred_function_calls:
|
||||
return
|
||||
function_calls = self._deferred_function_calls
|
||||
self._deferred_function_calls = []
|
||||
logger.debug(
|
||||
f"{self}: executing {len(function_calls)} deferred function call(s) "
|
||||
"after bot turn ended"
|
||||
)
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
async def _handle_evt_function_call_arguments_done(self, evt):
|
||||
try:
|
||||
args = json.loads(evt.arguments)
|
||||
|
||||
function_call_item = self._pending_function_calls.get(evt.call_id)
|
||||
if function_call_item:
|
||||
del self._pending_function_calls[evt.call_id]
|
||||
|
||||
function_calls = [
|
||||
FunctionCallFromLLM(
|
||||
context=self._context,
|
||||
tool_call_id=evt.call_id,
|
||||
function_name=function_call_item.name,
|
||||
arguments=args,
|
||||
)
|
||||
]
|
||||
|
||||
if self._bot_is_speaking:
|
||||
self._deferred_function_calls.extend(function_calls)
|
||||
logger.debug(
|
||||
f"{self}: deferring function call {function_call_item.name} "
|
||||
"until bot stops speaking"
|
||||
)
|
||||
else:
|
||||
await self.run_function_calls(function_calls)
|
||||
logger.debug(f"Processed function call: {function_call_item.name}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No tracked function call found for call_id: {evt.call_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process function call arguments: {e}")
|
||||
|
||||
async def handle_evt_input_audio_transcription_completed(self, evt):
|
||||
await self._call_event_handler(
|
||||
"on_conversation_item_updated", evt.item_id, None
|
||||
)
|
||||
await self.broadcast_frame(
|
||||
TranscriptionFrame,
|
||||
text=evt.transcript,
|
||||
user_id="",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=evt,
|
||||
finalized=True,
|
||||
)
|
||||
await self._handle_user_transcription(evt.transcript, True, Language.EN)
|
||||
|
|
@ -504,10 +504,16 @@ async def _run_pipeline(
|
|||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
embeddings_provider = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if user_config and user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
|
||||
# Check if the workflow has any active recordings so the engine can
|
||||
# include recording response mode instructions in all node prompts.
|
||||
|
|
@ -532,6 +538,9 @@ async def _run_pipeline(
|
|||
embeddings_api_key=embeddings_api_key,
|
||||
embeddings_model=embeddings_model,
|
||||
embeddings_base_url=embeddings_base_url,
|
||||
embeddings_provider=embeddings_provider,
|
||||
embeddings_endpoint=embeddings_endpoint,
|
||||
embeddings_api_version=embeddings_api_version,
|
||||
has_recordings=has_recordings,
|
||||
context_compaction_enabled=context_compaction_enabled,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from api.utils.url_security import validate_user_configured_service_url
|
|||
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.azure.stt import AzureSTTService, AzureSTTSettings
|
||||
from pipecat.services.azure.tts import AzureTTSService, AzureTTSSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import (
|
||||
CartesiaTTSService,
|
||||
|
|
@ -246,6 +248,22 @@ def create_stt_service(
|
|||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.AZURE_SPEECH.value:
|
||||
from pipecat.transcriptions.language import Language as PipecatLanguage
|
||||
|
||||
language_code = getattr(user_config.stt, "language", None) or "en-US"
|
||||
region = getattr(user_config.stt, "region", None) or "eastus"
|
||||
# Try to map BCP-47 string to pipecat Language enum; fall back to string
|
||||
try:
|
||||
pipecat_language = PipecatLanguage(language_code)
|
||||
except ValueError:
|
||||
pipecat_language = PipecatLanguage.EN_US
|
||||
return AzureSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
region=region,
|
||||
settings=AzureSTTSettings(language=pipecat_language),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}"
|
||||
|
|
@ -492,6 +510,27 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
skip_aggregator_types=["recording_router", "recording"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.AZURE_SPEECH.value:
|
||||
region = getattr(user_config.tts, "region", None) or "eastus"
|
||||
voice = getattr(user_config.tts, "voice", None) or "en-US-AriaNeural"
|
||||
language = getattr(user_config.tts, "language", None) or "en-US"
|
||||
speed = getattr(user_config.tts, "speed", None) or 1.0
|
||||
# Map speed multiplier (0.5–2.0) to Azure SSML rate string (e.g. "1.25")
|
||||
rate = str(speed) if speed != 1.0 else None
|
||||
settings_kwargs: dict = {
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
}
|
||||
if rate:
|
||||
settings_kwargs["rate"] = rate
|
||||
return AzureTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
region=region,
|
||||
settings=AzureTTSSettings(**settings_kwargs),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router", "recording"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}"
|
||||
|
|
@ -724,6 +763,44 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
location=location,
|
||||
settings=DograhGeminiLiveVertexLLMService.Settings(**settings_kwargs),
|
||||
)
|
||||
elif provider == ServiceProviders.AZURE_REALTIME.value:
|
||||
from api.services.pipecat.realtime.azure_realtime import (
|
||||
DograhAzureRealtimeLLMService,
|
||||
)
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
AudioOutput,
|
||||
InputAudioTranscription,
|
||||
SessionProperties,
|
||||
)
|
||||
|
||||
endpoint = getattr(realtime_config, "endpoint", None) or ""
|
||||
api_version = getattr(realtime_config, "api_version", None) or "2025-04-01-preview"
|
||||
# Construct the Azure Realtime WebSocket URL
|
||||
# https://<resource>.openai.azure.com/openai/realtime?api-version=<ver>&deployment=<model>
|
||||
base_host = endpoint.rstrip("/").replace("https://", "").replace("http://", "")
|
||||
wss_url = (
|
||||
f"wss://{base_host}/openai/realtime"
|
||||
f"?api-version={api_version}&deployment={model}"
|
||||
)
|
||||
return DograhAzureRealtimeLLMService(
|
||||
api_key=api_key,
|
||||
base_url=wss_url,
|
||||
settings=DograhAzureRealtimeLLMService.Settings(
|
||||
model=model,
|
||||
session_properties=SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
),
|
||||
output=AudioOutput(
|
||||
voice=voice or "alloy",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid realtime LLM provider {provider}"
|
||||
|
|
|
|||
|
|
@ -73,6 +73,9 @@ class PipecatEngine:
|
|||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
embeddings_base_url: Optional[str] = None,
|
||||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
has_recordings: bool = False,
|
||||
context_compaction_enabled: bool = False,
|
||||
):
|
||||
|
|
@ -126,6 +129,9 @@ class PipecatEngine:
|
|||
self._embeddings_api_key: Optional[str] = embeddings_api_key
|
||||
self._embeddings_model: Optional[str] = embeddings_model
|
||||
self._embeddings_base_url: Optional[str] = embeddings_base_url
|
||||
self._embeddings_provider: Optional[str] = embeddings_provider
|
||||
self._embeddings_endpoint: Optional[str] = embeddings_endpoint
|
||||
self._embeddings_api_version: Optional[str] = embeddings_api_version
|
||||
|
||||
# Audio configuration (set via set_audio_config from _run_pipeline)
|
||||
self._audio_config = None
|
||||
|
|
@ -373,6 +379,9 @@ class PipecatEngine:
|
|||
embeddings_api_key=self._embeddings_api_key,
|
||||
embeddings_model=self._embeddings_model,
|
||||
embeddings_base_url=self._embeddings_base_url,
|
||||
embeddings_provider=self._embeddings_provider,
|
||||
embeddings_endpoint=self._embeddings_endpoint,
|
||||
embeddings_api_version=self._embeddings_api_version,
|
||||
tracing_context=self._get_otel_context(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ from loguru import logger
|
|||
from opentelemetry import trace
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
|
||||
from api.services.pipecat.tracing_config import ensure_tracing
|
||||
|
||||
|
||||
|
|
@ -25,6 +26,9 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
embeddings_base_url: Optional[str] = None,
|
||||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
tracing_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
|
@ -68,6 +72,9 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
embeddings_base_url,
|
||||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
|
|
@ -105,6 +112,9 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
embeddings_base_url,
|
||||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
|
|
@ -179,6 +189,9 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
embeddings_base_url,
|
||||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
|
|
@ -189,6 +202,10 @@ async def retrieve_from_knowledge_base(
|
|||
limit,
|
||||
embeddings_api_key,
|
||||
embeddings_model,
|
||||
embeddings_base_url,
|
||||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -200,6 +217,9 @@ async def _perform_retrieval(
|
|||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
embeddings_base_url: Optional[str] = None,
|
||||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
|
|
@ -240,12 +260,21 @@ async def _perform_retrieval(
|
|||
"Model Configurations > Embedding."
|
||||
)
|
||||
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint:
|
||||
embedding_service = AzureOpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
endpoint=embeddings_endpoint,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
query=query,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import AzureOpenAIEmbeddingService, OpenAIEmbeddingService
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
|
@ -148,21 +149,30 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
return
|
||||
|
||||
# Chunked mode: fetch user embedding config, embed via OpenAI, persist chunks.
|
||||
# Chunked mode: fetch user embedding config, embed, and persist chunks.
|
||||
embeddings_provider = None
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(document.created_by)
|
||||
if user_config.embeddings:
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
logger.info(f"Using user embeddings config: model={embeddings_model}")
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
logger.info(
|
||||
f"Using user embeddings config: provider={embeddings_provider}, "
|
||||
f"model={embeddings_model}"
|
||||
)
|
||||
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
"API key not configured. Please set your API key in "
|
||||
"Model Configurations > Embedding to process documents."
|
||||
)
|
||||
logger.warning(f"Document {document_id}: {error_message}")
|
||||
|
|
@ -171,12 +181,21 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
return
|
||||
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
if embeddings_provider == ServiceProviders.AZURE.value and embeddings_endpoint:
|
||||
embedding_service = AzureOpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
endpoint=embeddings_endpoint,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
mps_chunks = mps_response.get("chunks", [])
|
||||
if not mps_chunks:
|
||||
|
|
|
|||
81
api/tests/test_azure_speech_service_factory.py
Normal file
81
api/tests/test_azure_speech_service_factory.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""Tests for Azure Speech TTS/STT service factory dispatch."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.service_factory import create_stt_service, create_tts_service
|
||||
|
||||
|
||||
def _audio_config():
|
||||
return SimpleNamespace(
|
||||
transport_out_sample_rate=24000,
|
||||
transport_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
|
||||
def test_create_azure_speech_tts_service():
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.AZURE_SPEECH.value,
|
||||
api_key="test-subscription-key",
|
||||
region="eastus",
|
||||
voice="en-US-AriaNeural",
|
||||
language="en-US",
|
||||
speed=1.0,
|
||||
model="neural",
|
||||
)
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.AzureTTSService") as mock_service:
|
||||
create_tts_service(user_config, _audio_config())
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-subscription-key"
|
||||
assert kwargs["region"] == "eastus"
|
||||
assert kwargs["settings"].voice == "en-US-AriaNeural"
|
||||
assert kwargs["settings"].language == "en-US"
|
||||
|
||||
|
||||
def test_create_azure_speech_tts_service_with_speed():
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.AZURE_SPEECH.value,
|
||||
api_key="test-key",
|
||||
region="westeurope",
|
||||
voice="en-GB-SoniaNeural",
|
||||
language="en-GB",
|
||||
speed=1.5,
|
||||
model="neural",
|
||||
)
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.AzureTTSService") as mock_service:
|
||||
create_tts_service(user_config, _audio_config())
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["region"] == "westeurope"
|
||||
assert kwargs["settings"].rate == "1.5"
|
||||
|
||||
|
||||
def test_create_azure_speech_stt_service():
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.AZURE_SPEECH.value,
|
||||
api_key="test-subscription-key",
|
||||
region="eastus",
|
||||
language="en-US",
|
||||
model="latest_long",
|
||||
)
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.AzureSTTService") as mock_service:
|
||||
create_stt_service(user_config, _audio_config())
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-subscription-key"
|
||||
assert kwargs["region"] == "eastus"
|
||||
assert kwargs["sample_rate"] == 16000
|
||||
Loading…
Add table
Add a link
Reference in a new issue