diff --git a/api/db/organization_usage_client.py b/api/db/organization_usage_client.py index f845fc75..826275e6 100644 --- a/api/db/organization_usage_client.py +++ b/api/db/organization_usage_client.py @@ -19,7 +19,7 @@ from api.db.models import ( WorkflowRunModel, ) from api.enums import OrganizationConfigurationKey -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration class OrganizationUsageClient(BaseDBClient): @@ -473,11 +473,11 @@ class OrganizationUsageClient(BaseDBClient): ) config_obj = config_result.scalar_one_or_none() if config_obj and config_obj.configuration: - user_config = EffectiveAIModelConfiguration.model_validate( + effective_config = EffectiveAIModelConfiguration.model_validate( config_obj.configuration ) - if user_config.timezone and user_timezone == "UTC": - user_timezone = user_config.timezone + if effective_config.timezone and user_timezone == "UTC": + user_timezone = effective_config.timezone # Validate timezone string try: diff --git a/api/db/user_client.py b/api/db/user_client.py index 9c4476f2..4ea0bca9 100644 --- a/api/db/user_client.py +++ b/api/db/user_client.py @@ -8,7 +8,7 @@ from sqlalchemy.future import select from api.db.base_client import BaseDBClient from api.db.models import UserConfigurationModel, UserModel -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration class UserClient(BaseDBClient): diff --git a/api/routes/knowledge_base.py b/api/routes/knowledge_base.py index d9156871..bd0ba046 100644 --- a/api/routes/knowledge_base.py +++ b/api/routes/knowledge_base.py @@ -384,7 +384,7 @@ async def search_chunks( user_id=user.id, organization_id=user.selected_organization_id, ) - user_config = resolved_config.effective + effective_config = resolved_config.effective embeddings_api_key = None embeddings_model = None embeddings_provider = None @@ -392,17 +392,17 @@ async def search_chunks( embeddings_endpoint = None embeddings_api_version = None - if user_config.embeddings: - embeddings_api_key = user_config.embeddings.api_key - embeddings_model = user_config.embeddings.model - embeddings_provider = getattr(user_config.embeddings, "provider", None) - embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + if effective_config.embeddings: + embeddings_api_key = effective_config.embeddings.api_key + embeddings_model = effective_config.embeddings.model + embeddings_provider = getattr(effective_config.embeddings, "provider", None) + embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None) embeddings_base_url = apply_managed_embeddings_base_url( provider=embeddings_provider, - base_url=getattr(user_config.embeddings, "base_url", None), + base_url=getattr(effective_config.embeddings, "base_url", None), ) embeddings_api_version = getattr( - user_config.embeddings, "api_version", None + effective_config.embeddings, "api_version", None ) # Initialize embedding service based on provider diff --git a/api/routes/workflow.py b/api/routes/workflow.py index 9157c5cf..ab6f5da7 100644 --- a/api/routes/workflow.py +++ b/api/routes/workflow.py @@ -1053,13 +1053,15 @@ async def update_workflow( user_id=user.id, organization_id=user.selected_organization_id, ) - user_config = resolved_config.effective + effective_config = resolved_config.effective try: enriched_overrides = enrich_overrides_with_api_keys( workflow_configurations["model_overrides"], - user_config, + effective_config, + ) + effective = resolve_effective_config( + effective_config, enriched_overrides ) - effective = resolve_effective_config(user_config, enriched_overrides) if resolved_config.source == "organization_v2": v2_override = convert_legacy_ai_model_configuration_to_v2(effective) await UserConfigurationValidator().validate( diff --git a/api/schemas/ai_model_configuration.py b/api/schemas/ai_model_configuration.py index dcc3a6e7..c5403b04 100644 --- a/api/schemas/ai_model_configuration.py +++ b/api/schemas/ai_model_configuration.py @@ -1,10 +1,10 @@ from __future__ import annotations +from datetime import datetime from typing import Literal from pydantic import BaseModel, Field, model_validator -from api.schemas.user_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import ( DograhEmbeddingsConfiguration, DograhLLMService, @@ -23,6 +23,29 @@ DOGRAH_DEFAULT_VOICE = "default" DOGRAH_DEFAULT_LANGUAGE = "multi" +class EffectiveAIModelConfiguration(BaseModel): + llm: LLMConfig | None = None + stt: STTConfig | None = None + tts: TTSConfig | None = None + embeddings: EmbeddingsConfig | None = None + realtime: RealtimeConfig | None = None + is_realtime: bool = False + managed_service_version: int | None = None + test_phone_number: str | None = None + timezone: str | None = None + last_validated_at: datetime | None = None + + @model_validator(mode="before") + @classmethod + def strip_incomplete_realtime_when_disabled(cls, data): + """Skip realtime validation when is_realtime is False and api_key is missing.""" + if isinstance(data, dict) and not data.get("is_realtime", False): + realtime = data.get("realtime") + if isinstance(realtime, dict) and not realtime.get("api_key"): + data.pop("realtime", None) + return data + + class DograhManagedAIModelConfiguration(BaseModel): api_key: str voice: str = DOGRAH_DEFAULT_VOICE @@ -160,6 +183,7 @@ def _compile_dograh_configuration( model="default", ), is_realtime=False, + managed_service_version=2, ) diff --git a/api/schemas/user_configuration.py b/api/schemas/user_configuration.py deleted file mode 100644 index fc958a5b..00000000 --- a/api/schemas/user_configuration.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime - -from pydantic import BaseModel, model_validator - -from api.services.configuration.registry import ( - EmbeddingsConfig, - LLMConfig, - RealtimeConfig, - STTConfig, - TTSConfig, -) - - -class EffectiveAIModelConfiguration(BaseModel): - llm: LLMConfig | None = None - stt: STTConfig | None = None - tts: TTSConfig | None = None - embeddings: EmbeddingsConfig | None = None - realtime: RealtimeConfig | None = None - is_realtime: bool = False - test_phone_number: str | None = None - timezone: str | None = None - last_validated_at: datetime | None = None - - @model_validator(mode="before") - @classmethod - def strip_incomplete_realtime_when_disabled(cls, data): - """Skip realtime validation when is_realtime is False and api_key is missing.""" - if isinstance(data, dict) and not data.get("is_realtime", False): - realtime = data.get("realtime") - if isinstance(realtime, dict) and not realtime.get("api_key"): - data.pop("realtime", None) - return data diff --git a/api/services/auth/depends.py b/api/services/auth/depends.py index d9e24684..4bb4862a 100644 --- a/api/services/auth/depends.py +++ b/api/services/auth/depends.py @@ -9,7 +9,7 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL from api.db import db_client from api.db.models import UserModel from api.enums import PostHogEvent -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.auth.stack_auth import stackauth from api.services.configuration.registry import ServiceProviders from api.services.posthog_client import capture_event @@ -285,8 +285,8 @@ async def create_user_configuration_with_mps_key( "model": "default", }, } - user_config = EffectiveAIModelConfiguration(**configuration) - return user_config + effective_config = EffectiveAIModelConfiguration(**configuration) + return effective_config else: logger.warning( f"Failed to get MPS service key: {response.status_code} - {response.text}" diff --git a/api/services/configuration/ai_model_configuration.py b/api/services/configuration/ai_model_configuration.py index 1b9a00f6..c5331515 100644 --- a/api/services/configuration/ai_model_configuration.py +++ b/api/services/configuration/ai_model_configuration.py @@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import ( BYOKPipelineAIModelConfiguration, BYOKRealtimeAIModelConfiguration, DograhManagedAIModelConfiguration, + EffectiveAIModelConfiguration, OrganizationAIModelConfigurationV2, compile_ai_model_configuration_v2, ) -from api.schemas.user_configuration import EffectiveAIModelConfiguration from api.services.configuration.masking import ( SERVICE_SECRET_FIELDS, contains_masked_key, diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index e8f5bfa7..cc17481f 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -8,7 +8,7 @@ from groq import Groq # from pyneuphonic import Neuphonic # except ImportError: # Neuphonic = None -from api.schemas.user_configuration import ( +from api.schemas.ai_model_configuration import ( EffectiveAIModelConfiguration, ) from api.services.configuration.registry import ServiceConfig, ServiceProviders diff --git a/api/services/configuration/masking.py b/api/services/configuration/masking.py index c3fa4bfc..a7e1af6a 100644 --- a/api/services/configuration/masking.py +++ b/api/services/configuration/masking.py @@ -12,7 +12,7 @@ The rules are simple: import copy from typing import Any, Dict, Optional -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import ServiceConfig from api.services.integrations import get_node_secret_fields diff --git a/api/services/configuration/merge.py b/api/services/configuration/merge.py index 1b174ee8..3100fa45 100644 --- a/api/services/configuration/merge.py +++ b/api/services/configuration/merge.py @@ -7,7 +7,7 @@ stored, while honouring masked API keys. import copy from typing import Dict -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.masking import ( MODEL_OVERRIDE_FIELDS, SERVICE_SECRET_FIELDS, diff --git a/api/services/configuration/resolve.py b/api/services/configuration/resolve.py index a33f5c09..5cbf11ef 100644 --- a/api/services/configuration/resolve.py +++ b/api/services/configuration/resolve.py @@ -4,7 +4,7 @@ from __future__ import annotations import copy -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import ( REGISTRY, ServiceType, diff --git a/api/services/gen_ai/embedding/openai_service.py b/api/services/gen_ai/embedding/openai_service.py index da5d3d4d..1081889e 100644 --- a/api/services/gen_ai/embedding/openai_service.py +++ b/api/services/gen_ai/embedding/openai_service.py @@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService): api_key: Optional[str] = None, model_id: str = DEFAULT_MODEL_ID, base_url: Optional[str] = None, + default_headers: Optional[Dict[str, str]] = None, ): """Initialize the OpenAI embedding service. @@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService): field_name="base_url", ) client_kwargs["base_url"] = base_url + if default_headers: + client_kwargs["default_headers"] = default_headers self.client = AsyncOpenAI(**client_kwargs) logger.info(f"OpenAI embedding service initialized with model: {model_id}") else: diff --git a/api/services/managed_model_services.py b/api/services/managed_model_services.py new file mode 100644 index 00000000..b6992aaf --- /dev/null +++ b/api/services/managed_model_services.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Any + +from loguru import logger + +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration +from api.services.configuration.registry import ServiceProviders +from api.services.mps_service_key_client import mps_service_key_client + +MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id" + + +def uses_managed_model_services_v2( + ai_model_config: EffectiveAIModelConfiguration | None, +) -> bool: + if ( + ai_model_config is None + or getattr(ai_model_config, "managed_service_version", None) != 2 + ): + return False + + return any( + _is_dograh_service(getattr(ai_model_config, section_name, None)) + for section_name in ("llm", "tts", "stt", "embeddings") + ) + + +def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None: + if not initial_context: + return None + correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) + if correlation_id is None: + return None + return str(correlation_id) + + +async def ensure_mps_correlation_id( + *, + ai_model_config: EffectiveAIModelConfiguration, + workflow_run_id: int, + initial_context: dict[str, Any] | None, +) -> str | None: + existing = get_mps_correlation_id(initial_context) + if existing: + return existing + + if not uses_managed_model_services_v2(ai_model_config): + return None + + service_key = _get_dograh_service_api_key(ai_model_config) + if not service_key: + raise ValueError( + "Managed model services v2 requires a Dograh service key before the run starts." + ) + + response = await mps_service_key_client.create_correlation_id( + service_key=service_key, + workflow_run_id=workflow_run_id, + ) + correlation_id = response.get("correlation_id") + if not correlation_id: + raise ValueError("MPS correlation-id response did not include correlation_id") + + correlation_id = str(correlation_id) + logger.info( + "Minted MPS correlation id {} for workflow run {}", + correlation_id, + workflow_run_id, + ) + return correlation_id + + +def _is_dograh_service(service: Any) -> bool: + provider = getattr(service, "provider", None) + return ( + provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value + ) + + +def _get_dograh_service_api_key( + ai_model_config: EffectiveAIModelConfiguration, +) -> str | None: + for section_name in ("llm", "tts", "stt", "embeddings"): + service = getattr(ai_model_config, section_name, None) + if not _is_dograh_service(service): + continue + + if hasattr(service, "get_all_api_keys"): + keys = service.get_all_api_keys() + if keys: + return keys[0] + + api_key = getattr(service, "api_key", None) + if isinstance(api_key, str) and api_key: + return api_key + + return None diff --git a/api/services/mps_service_key_client.py b/api/services/mps_service_key_client.py index 2c7fc56b..f7ce749f 100644 --- a/api/services/mps_service_key_client.py +++ b/api/services/mps_service_key_client.py @@ -353,6 +353,40 @@ class MPSServiceKeyClient: response=response, ) + async def create_correlation_id( + self, + *, + service_key: str, + workflow_run_id: int | None = None, + ) -> dict: + """Mint a server-generated correlation ID for managed model services.""" + payload: dict[str, int] = {} + if workflow_run_id is not None: + payload["workflow_run_id"] = workflow_run_id + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.base_url}/api/v1/service-keys/correlation-id/self", + json=payload, + headers={ + "Authorization": f"Bearer {service_key}", + "Content-Type": "application/json", + }, + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to create correlation ID: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to create correlation ID: {response.text}", + request=response.request, + response=response, + ) + async def transcribe_audio( self, audio_data: bytes, diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 63c11f53..a5f2d077 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -340,7 +340,7 @@ async def _run_pipeline( if workflow_run.is_completed: raise HTTPException(status_code=400, detail="Workflow run already completed") - merged_call_context_vars = workflow_run.initial_context + merged_call_context_vars = dict(workflow_run.initial_context or {}) # If there is some extra call_context_vars, fold them in. Persistence # happens once below, after runtime_configuration is also resolved. if call_context_vars: @@ -398,6 +398,19 @@ async def _run_pipeline( else: user_config = resolved_user_config + from api.services.managed_model_services import ( + MPS_CORRELATION_ID_CONTEXT_KEY, + ensure_mps_correlation_id, + ) + + mps_correlation_id = await ensure_mps_correlation_id( + ai_model_config=user_config, + workflow_run_id=workflow_run_id, + initial_context=merged_call_context_vars, + ) + if mps_correlation_id: + merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id + # Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live) is_realtime = user_config.is_realtime and user_config.realtime is not None @@ -409,11 +422,23 @@ async def _run_pipeline( # Realtime services don't implement run_inference, so create a # separate text LLM for variable extraction and other out-of-band # inference calls. - inference_llm = create_llm_service(user_config) + inference_llm = create_llm_service( + user_config, + correlation_id=mps_correlation_id, + ) else: - stt = create_stt_service(user_config, audio_config, keyterms=keyterms) - tts = create_tts_service(user_config, audio_config) - llm = create_llm_service(user_config) + stt = create_stt_service( + user_config, + audio_config, + keyterms=keyterms, + correlation_id=mps_correlation_id, + ) + tts = create_tts_service( + user_config, + audio_config, + correlation_id=mps_correlation_id, + ) + llm = create_llm_service(user_config, correlation_id=mps_correlation_id) inference_llm = None # Stamp the providers/models actually resolved for this run onto @@ -695,7 +720,10 @@ async def _run_pipeline( # Create a separate LLM instance for the voicemail sub-pipeline # (can't share with main pipeline as it would mess up frame linking) if voicemail_config.get("use_workflow_llm", True): - voicemail_llm = create_llm_service(user_config) + voicemail_llm = create_llm_service( + user_config, + correlation_id=mps_correlation_id, + ) else: voicemail_llm = create_llm_service_from_provider( provider=voicemail_config.get("provider", "openai"), diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index 8ed96e40..ec5e9911 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None: def create_stt_service( - user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None + user_config, + audio_config: "AudioConfig", + keyterms: list[str] | None = None, + correlation_id: str | None = None, ): """Create and return appropriate STT service based on user configuration @@ -160,6 +163,7 @@ def create_stt_service( return DograhSTTService( base_url=base_url, api_key=user_config.stt.api_key, + correlation_id=correlation_id, settings=DograhSTTSettings( model=user_config.stt.model, language=language, @@ -286,7 +290,9 @@ def create_stt_service( ) -def create_tts_service(user_config, audio_config: "AudioConfig"): +def create_tts_service( + user_config, audio_config: "AudioConfig", correlation_id: str | None = None +): """Create and return appropriate TTS service based on user configuration Args: @@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): return DograhTTSService( base_url=base_url, api_key=user_config.tts.api_key, + correlation_id=correlation_id, settings=DograhTTSSettings( model=user_config.tts.model, voice=user_config.tts.voice, @@ -564,6 +571,7 @@ def create_llm_service_from_provider( model: str, api_key: str | None, *, + correlation_id: str | None = None, base_url: str | None = None, endpoint: str | None = None, aws_access_key: str | None = None, @@ -637,6 +645,7 @@ def create_llm_service_from_provider( return DograhLLMService( base_url=f"{MPS_API_URL}/api/v1/llm", api_key=api_key, + correlation_id=correlation_id, settings=OpenAILLMSettings(model=model), ) elif provider == ServiceProviders.AWS_BEDROCK.value: @@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"): ) -def create_llm_service(user_config): +def create_llm_service(user_config, correlation_id: str | None = None): """Create and return appropriate LLM service based on user configuration.""" provider = user_config.llm.provider model = user_config.llm.model @@ -880,4 +889,10 @@ def create_llm_service(user_config): elif provider == ServiceProviders.SARVAM.value: kwargs["temperature"] = user_config.llm.temperature - return create_llm_service_from_provider(provider, model, api_key, **kwargs) + return create_llm_service_from_provider( + provider, + model, + api_key, + correlation_id=correlation_id, + **kwargs, + ) diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index cea1d21f..a0d67947 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -35,6 +35,7 @@ import asyncio from loguru import logger +from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY from api.services.workflow import pipecat_engine_callbacks as engine_callbacks from api.services.workflow.mcp_tool_session import McpToolSession from api.services.workflow.pipecat_engine_context_composer import ( @@ -382,6 +383,9 @@ class PipecatEngine: embeddings_provider=self._embeddings_provider, embeddings_endpoint=self._embeddings_endpoint, embeddings_api_version=self._embeddings_api_version, + correlation_id=self._call_context_vars.get( + MPS_CORRELATION_ID_CONTEXT_KEY + ), tracing_context=self._get_otel_context(), ) diff --git a/api/services/workflow/text_chat_runner.py b/api/services/workflow/text_chat_runner.py index 59073c80..7f6c5a0b 100644 --- a/api/services/workflow/text_chat_runner.py +++ b/api/services/workflow/text_chat_runner.py @@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn( if user_config.llm is None: raise ValueError("Text chat requires an LLM configuration") - llm = create_llm_service(user_config) + from api.services.managed_model_services import ( + MPS_CORRELATION_ID_CONTEXT_KEY, + ensure_mps_correlation_id, + ) + + base_initial_context = dict(workflow_run.initial_context or {}) + mps_correlation_id = await ensure_mps_correlation_id( + ai_model_config=user_config, + workflow_run_id=workflow_run_id, + initial_context=base_initial_context, + ) + + llm = create_llm_service(user_config, correlation_id=mps_correlation_id) inference_llm = llm runtime_configuration = { @@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn( "llm_model": user_config.llm.model, } initial_context = { - **(workflow_run.initial_context or {}), + **base_initial_context, "runtime_configuration": runtime_configuration, } + if mps_correlation_id: + initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id + await db_client.update_workflow_run( + workflow_run_id, + initial_context=initial_context, + ) workflow_graph = WorkflowGraph( ReactFlowDTO.model_validate(run_definition.workflow_json) diff --git a/api/services/workflow/tools/knowledge_base.py b/api/services/workflow/tools/knowledge_base.py index 6ce8f8c7..7b93aea7 100644 --- a/api/services/workflow/tools/knowledge_base.py +++ b/api/services/workflow/tools/knowledge_base.py @@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base( embeddings_provider: Optional[str] = None, embeddings_endpoint: Optional[str] = None, embeddings_api_version: Optional[str] = None, + correlation_id: Optional[str] = None, tracing_context=None, ) -> Dict[str, Any]: """Retrieve relevant information from the knowledge base using vector similarity search. @@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) # Create span with parent context @@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) # Add result metadata to span @@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) else: # Tracing is disabled - perform retrieval without tracing @@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) @@ -220,6 +225,7 @@ async def _perform_retrieval( embeddings_provider: Optional[str] = None, embeddings_endpoint: Optional[str] = None, embeddings_api_version: Optional[str] = None, + correlation_id: Optional[str] = None, ) -> Dict[str, Any]: """Internal function to perform the actual retrieval operation. @@ -272,11 +278,20 @@ async def _perform_retrieval( api_version=embeddings_api_version or "2024-02-15-preview", ) else: + default_headers = None + if ( + embeddings_provider == ServiceProviders.DOGRAH.value + and correlation_id + ): + default_headers = { + "X-Dograh-Correlation-Id": correlation_id, + } embedding_service = OpenAIEmbeddingService( db_client=db_client, api_key=embeddings_api_key, model_id=embeddings_model or "text-embedding-3-small", base_url=embeddings_base_url, + default_headers=default_headers, ) results = await embedding_service.search_similar_chunks( diff --git a/api/tasks/knowledge_base_processing.py b/api/tasks/knowledge_base_processing.py index f496ac0e..a6ca0d6d 100644 --- a/api/tasks/knowledge_base_processing.py +++ b/api/tasks/knowledge_base_processing.py @@ -166,18 +166,22 @@ async def process_knowledge_base_document( user_id=document.created_by, organization_id=document.organization_id, ) - user_config = resolved_config.effective - if user_config.embeddings: - embeddings_provider = getattr(user_config.embeddings, "provider", None) - embeddings_api_key = user_config.embeddings.api_key - embeddings_model = user_config.embeddings.model + effective_config = resolved_config.effective + if effective_config.embeddings: + embeddings_provider = getattr( + effective_config.embeddings, "provider", None + ) + embeddings_api_key = effective_config.embeddings.api_key + embeddings_model = effective_config.embeddings.model embeddings_base_url = apply_managed_embeddings_base_url( provider=embeddings_provider, - base_url=getattr(user_config.embeddings, "base_url", None), + base_url=getattr(effective_config.embeddings, "base_url", None), + ) + embeddings_endpoint = getattr( + effective_config.embeddings, "endpoint", None ) - embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) embeddings_api_version = getattr( - user_config.embeddings, "api_version", None + effective_config.embeddings, "api_version", None ) logger.info( f"Using user embeddings config: provider={embeddings_provider}, " diff --git a/api/tests/integrations/_run_pipeline_helpers.py b/api/tests/integrations/_run_pipeline_helpers.py index 1a3251a0..58b4ffd2 100644 --- a/api/tests/integrations/_run_pipeline_helpers.py +++ b/api/tests/integrations/_run_pipeline_helpers.py @@ -203,7 +203,7 @@ async def create_workflow_run_rows( Returns: Tuple of (workflow_run, user, workflow). """ - from api.schemas.user_configuration import EffectiveAIModelConfiguration + from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}") async_session.add(org) diff --git a/api/tests/test_ai_model_configuration_v2.py b/api/tests/test_ai_model_configuration_v2.py index 98f431e8..71772b28 100644 --- a/api/tests/test_ai_model_configuration_v2.py +++ b/api/tests/test_ai_model_configuration_v2.py @@ -3,10 +3,10 @@ from pydantic import ValidationError from api.schemas.ai_model_configuration import ( DograhManagedAIModelConfiguration, + EffectiveAIModelConfiguration, OrganizationAIModelConfigurationV2, compile_ai_model_configuration_v2, ) -from api.schemas.user_configuration import EffectiveAIModelConfiguration from api.services.configuration.ai_model_configuration import ( WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY, check_for_masked_keys_in_ai_model_configuration_v2, @@ -49,6 +49,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings(): assert effective.stt.language == "multi" assert effective.embeddings.provider == "dograh" assert effective.embeddings.model == "default" + assert effective.managed_service_version == 2 def test_dograh_v2_rejects_non_predefined_speed(): diff --git a/api/tests/test_dograh_managed_correlation.py b/api/tests/test_dograh_managed_correlation.py new file mode 100644 index 00000000..b0cb52c0 --- /dev/null +++ b/api/tests/test_dograh_managed_correlation.py @@ -0,0 +1,110 @@ +import json + +import pytest +from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN +from pipecat.frames.frames import TTSStartedFrame +from pipecat.services.dograh.llm import DograhLLMService +from pipecat.services.dograh.stt import DograhSTTService +from pipecat.services.dograh.tts import DograhTTSService +from pipecat.services.openai.base_llm import OpenAILLMSettings +from websockets.protocol import State + + +class _FakeWebSocket: + def __init__(self): + self.state = State.OPEN + self.messages: list[dict] = [] + + async def send(self, message: str) -> None: + self.messages.append(json.loads(message)) + + async def close(self, *args, **kwargs) -> None: + self.state = State.CLOSED + + +def test_dograh_llm_uses_explicit_mps_correlation_id(): + service = DograhLLMService( + api_key="mps-secret", + correlation_id="mps-corr-123", + settings=OpenAILLMSettings(model="default"), + ) + service._start_metadata = {"workflow_run_id": 99} + + params = service.build_chat_completion_params( + { + "messages": [], + "tools": OPENAI_NOT_GIVEN, + "tool_choice": OPENAI_NOT_GIVEN, + } + ) + + assert params["metadata"]["correlation_id"] == "mps-corr-123" + assert params["metadata"]["mps_billing_version"] == "2" + + +@pytest.mark.asyncio +async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch): + fake_ws = _FakeWebSocket() + + async def fake_connect(url, additional_headers): + return fake_ws + + monkeypatch.setattr( + "pipecat.services.dograh.stt.websocket_connect", + fake_connect, + ) + + service = DograhSTTService( + api_key="mps-secret", + correlation_id="mps-corr-123", + sample_rate=16000, + ) + service._start_metadata = {"workflow_run_id": 99} + + await service._connect_websocket() + + assert fake_ws.messages[0]["type"] == "config" + assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123" + assert fake_ws.messages[0]["mps_billing_version"] == "2" + + +@pytest.mark.asyncio +async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch): + fake_ws = _FakeWebSocket() + + async def fake_connect(url, additional_headers): + return fake_ws + + monkeypatch.setattr( + "pipecat.services.dograh.tts.websocket_connect", + fake_connect, + ) + + service = DograhTTSService( + api_key="mps-secret", + correlation_id="mps-corr-123", + sample_rate=24000, + ) + service._start_metadata = {"workflow_run_id": 99} + + await service._connect_websocket() + assert fake_ws.messages[0]["type"] == "config" + assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123" + assert fake_ws.messages[0]["mps_billing_version"] == "2" + + async def _noop(*args, **kwargs): + return None + + service.audio_context_available = lambda context_id: False + service.create_audio_context = _noop + service.start_ttfb_metrics = _noop + service.start_tts_usage_metrics = _noop + + frames = [] + async for frame in service.run_tts("hello", "ctx-1"): + frames.append(frame) + + assert isinstance(frames[0], TTSStartedFrame) + assert fake_ws.messages[1]["type"] == "create_context" + assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123" + assert fake_ws.messages[1]["mps_billing_version"] == "2" diff --git a/api/tests/test_grok_realtime_wrapper.py b/api/tests/test_grok_realtime_wrapper.py index 7f7359dc..19cae657 100644 --- a/api/tests/test_grok_realtime_wrapper.py +++ b/api/tests/test_grok_realtime_wrapper.py @@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.frame_processor import FrameDirection from pipecat.services.xai.realtime import events -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import GrokRealtimeLLMConfiguration from api.services.pipecat.realtime.grok_realtime import ( DograhGrokRealtimeLLMService, @@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized(): def test_factory_creates_dograh_grok_realtime_service(): - user_config = EffectiveAIModelConfiguration( + effective_config = EffectiveAIModelConfiguration( is_realtime=True, realtime=GrokRealtimeLLMConfiguration( provider="grok_realtime", @@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service(): ) service = create_realtime_llm_service( - user_config, + effective_config, audio_config=SimpleNamespace(), ) diff --git a/api/tests/test_masked_key_rejection.py b/api/tests/test_masked_key_rejection.py index 2012c60b..45782335 100644 --- a/api/tests/test_masked_key_rejection.py +++ b/api/tests/test_masked_key_rejection.py @@ -5,7 +5,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from api.routes.user import router -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.auth.depends import get_user from api.services.configuration.masking import mask_key from api.services.configuration.registry import ( diff --git a/api/tests/test_mps_service_key_client.py b/api/tests/test_mps_service_key_client.py index 9cd629e3..7f42f13d 100644 --- a/api/tests/test_mps_service_key_client.py +++ b/api/tests/test_mps_service_key_client.py @@ -87,3 +87,44 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch): "Content-Type": "application/json", }, ) + + +@pytest.mark.asyncio +async def test_create_correlation_id_uses_bearer_auth(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + calls.append(("POST", url, json, headers)) + return _Response(200, {"correlation_id": "mps-corr-123"}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + + client = MPSServiceKeyClient() + + assert await client.create_correlation_id( + service_key="mps_sk_paid", + workflow_run_id=42, + ) == {"correlation_id": "mps-corr-123"} + assert calls == [ + ( + "POST", + f"{client.base_url}/api/v1/service-keys/correlation-id/self", + {"workflow_run_id": 42}, + { + "Authorization": "Bearer mps_sk_paid", + "Content-Type": "application/json", + }, + ) + ] diff --git a/api/tests/test_resolve_effective_config.py b/api/tests/test_resolve_effective_config.py index 85afe30f..1b9ad8c6 100644 --- a/api/tests/test_resolve_effective_config.py +++ b/api/tests/test_resolve_effective_config.py @@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve import pytest -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.masking import ( contains_masked_key, mask_workflow_configurations, diff --git a/api/tests/test_ultravox_realtime_wrapper.py b/api/tests/test_ultravox_realtime_wrapper.py index 65b062b6..32888439 100644 --- a/api/tests/test_ultravox_realtime_wrapper.py +++ b/api/tests/test_ultravox_realtime_wrapper.py @@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection from websockets.exceptions import ConnectionClosedError from websockets.frames import Close -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration from api.services.pipecat.realtime.ultravox_realtime import ( _RESUMPTION_USER_MESSAGE, @@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close(): def test_factory_creates_dograh_ultravox_realtime_service(): - user_config = EffectiveAIModelConfiguration( + effective_config = EffectiveAIModelConfiguration( is_realtime=True, realtime=UltravoxRealtimeLLMConfiguration( provider="ultravox_realtime", @@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service(): ) service = create_realtime_llm_service( - user_config, + effective_config, audio_config=SimpleNamespace(), ) diff --git a/api/tests/test_workflow_text_chat.py b/api/tests/test_workflow_text_chat.py index e69e7c0a..b3fb0d86 100644 --- a/api/tests/test_workflow_text_chat.py +++ b/api/tests/test_workflow_text_chat.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch import pytest from api.db.models import OrganizationModel, UserModel -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION from pipecat.tests import MockLLMService