feat: use mps generated correlation ID

This commit is contained in:
Abhishek Kumar 2026-06-09 18:24:40 +05:30
parent 91ac460799
commit 3336c6e794
30 changed files with 453 additions and 89 deletions

View file

@ -19,7 +19,7 @@ from api.db.models import (
WorkflowRunModel,
)
from api.enums import OrganizationConfigurationKey
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
class OrganizationUsageClient(BaseDBClient):
@ -473,11 +473,11 @@ class OrganizationUsageClient(BaseDBClient):
)
config_obj = config_result.scalar_one_or_none()
if config_obj and config_obj.configuration:
user_config = EffectiveAIModelConfiguration.model_validate(
effective_config = EffectiveAIModelConfiguration.model_validate(
config_obj.configuration
)
if user_config.timezone and user_timezone == "UTC":
user_timezone = user_config.timezone
if effective_config.timezone and user_timezone == "UTC":
user_timezone = effective_config.timezone
# Validate timezone string
try:

View file

@ -8,7 +8,7 @@ from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import UserConfigurationModel, UserModel
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
class UserClient(BaseDBClient):

View file

@ -384,7 +384,7 @@ async def search_chunks(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
effective_config = resolved_config.effective
embeddings_api_key = None
embeddings_model = None
embeddings_provider = None
@ -392,17 +392,17 @@ async def search_chunks(
embeddings_endpoint = None
embeddings_api_version = None
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
if effective_config.embeddings:
embeddings_api_key = effective_config.embeddings.api_key
embeddings_model = effective_config.embeddings.model
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
base_url=getattr(effective_config.embeddings, "base_url", None),
)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
effective_config.embeddings, "api_version", None
)
# Initialize embedding service based on provider

View file

@ -1053,13 +1053,15 @@ async def update_workflow(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
effective_config = resolved_config.effective
try:
enriched_overrides = enrich_overrides_with_api_keys(
workflow_configurations["model_overrides"],
user_config,
effective_config,
)
effective = resolve_effective_config(
effective_config, enriched_overrides
)
effective = resolve_effective_config(user_config, enriched_overrides)
if resolved_config.source == "organization_v2":
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
await UserConfigurationValidator().validate(

View file

@ -1,10 +1,10 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, Field, model_validator
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
DograhEmbeddingsConfiguration,
DograhLLMService,
@ -23,6 +23,29 @@ DOGRAH_DEFAULT_VOICE = "default"
DOGRAH_DEFAULT_LANGUAGE = "multi"
class EffectiveAIModelConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
realtime: RealtimeConfig | None = None
is_realtime: bool = False
managed_service_version: int | None = None
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None
@model_validator(mode="before")
@classmethod
def strip_incomplete_realtime_when_disabled(cls, data):
"""Skip realtime validation when is_realtime is False and api_key is missing."""
if isinstance(data, dict) and not data.get("is_realtime", False):
realtime = data.get("realtime")
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data
class DograhManagedAIModelConfiguration(BaseModel):
api_key: str
voice: str = DOGRAH_DEFAULT_VOICE
@ -160,6 +183,7 @@ def _compile_dograh_configuration(
model="default",
),
is_realtime=False,
managed_service_version=2,
)

View file

@ -1,33 +0,0 @@
from datetime import datetime
from pydantic import BaseModel, model_validator
from api.services.configuration.registry import (
EmbeddingsConfig,
LLMConfig,
RealtimeConfig,
STTConfig,
TTSConfig,
)
class EffectiveAIModelConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
realtime: RealtimeConfig | None = None
is_realtime: bool = False
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None
@model_validator(mode="before")
@classmethod
def strip_incomplete_realtime_when_disabled(cls, data):
"""Skip realtime validation when is_realtime is False and api_key is missing."""
if isinstance(data, dict) and not data.get("is_realtime", False):
realtime = data.get("realtime")
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data

View file

@ -9,7 +9,7 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.posthog_client import capture_event
@ -285,8 +285,8 @@ async def create_user_configuration_with_mps_key(
"model": "default",
},
}
user_config = EffectiveAIModelConfiguration(**configuration)
return user_config
effective_config = EffectiveAIModelConfiguration(**configuration)
return effective_config
else:
logger.warning(
f"Failed to get MPS service key: {response.status_code} - {response.text}"

View file

@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import (
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,

View file

@ -8,7 +8,7 @@ from groq import Groq
# from pyneuphonic import Neuphonic
# except ImportError:
# Neuphonic = None
from api.schemas.user_configuration import (
from api.schemas.ai_model_configuration import (
EffectiveAIModelConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders

View file

@ -12,7 +12,7 @@ The rules are simple:
import copy
from typing import Any, Dict, Optional
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceConfig
from api.services.integrations import get_node_secret_fields

View file

@ -7,7 +7,7 @@ stored, while honouring masked API keys.
import copy
from typing import Dict
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
MODEL_OVERRIDE_FIELDS,
SERVICE_SECRET_FIELDS,

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import copy
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
REGISTRY,
ServiceType,

View file

@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
base_url: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""Initialize the OpenAI embedding service.
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
field_name="base_url",
)
client_kwargs["base_url"] = base_url
if default_headers:
client_kwargs["default_headers"] = default_headers
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
else:

View file

@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Any
from loguru import logger
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
def uses_managed_model_services_v2(
ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
if (
ai_model_config is None
or getattr(ai_model_config, "managed_service_version", None) != 2
):
return False
return any(
_is_dograh_service(getattr(ai_model_config, section_name, None))
for section_name in ("llm", "tts", "stt", "embeddings")
)
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
if not initial_context:
return None
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
if correlation_id is None:
return None
return str(correlation_id)
async def ensure_mps_correlation_id(
*,
ai_model_config: EffectiveAIModelConfiguration,
workflow_run_id: int,
initial_context: dict[str, Any] | None,
) -> str | None:
existing = get_mps_correlation_id(initial_context)
if existing:
return existing
if not uses_managed_model_services_v2(ai_model_config):
return None
service_key = _get_dograh_service_api_key(ai_model_config)
if not service_key:
raise ValueError(
"Managed model services v2 requires a Dograh service key before the run starts."
)
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
)
correlation_id = response.get("correlation_id")
if not correlation_id:
raise ValueError("MPS correlation-id response did not include correlation_id")
correlation_id = str(correlation_id)
logger.info(
"Minted MPS correlation id {} for workflow run {}",
correlation_id,
workflow_run_id,
)
return correlation_id
def _is_dograh_service(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def _get_dograh_service_api_key(
ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
for section_name in ("llm", "tts", "stt", "embeddings"):
service = getattr(ai_model_config, section_name, None)
if not _is_dograh_service(service):
continue
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if keys:
return keys[0]
api_key = getattr(service, "api_key", None)
if isinstance(api_key, str) and api_key:
return api_key
return None

View file

@ -353,6 +353,40 @@ class MPSServiceKeyClient:
response=response,
)
async def create_correlation_id(
self,
*,
service_key: str,
workflow_run_id: int | None = None,
) -> dict:
"""Mint a server-generated correlation ID for managed model services."""
payload: dict[str, int] = {}
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
json=payload,
headers={
"Authorization": f"Bearer {service_key}",
"Content-Type": "application/json",
},
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create correlation ID: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create correlation ID: {response.text}",
request=response.request,
response=response,
)
async def transcribe_audio(
self,
audio_data: bytes,

View file

@ -340,7 +340,7 @@ async def _run_pipeline(
if workflow_run.is_completed:
raise HTTPException(status_code=400, detail="Workflow run already completed")
merged_call_context_vars = workflow_run.initial_context
merged_call_context_vars = dict(workflow_run.initial_context or {})
# If there is some extra call_context_vars, fold them in. Persistence
# happens once below, after runtime_configuration is also resolved.
if call_context_vars:
@ -398,6 +398,19 @@ async def _run_pipeline(
else:
user_config = resolved_user_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=merged_call_context_vars,
)
if mps_correlation_id:
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
is_realtime = user_config.is_realtime and user_config.realtime is not None
@ -409,11 +422,23 @@ async def _run_pipeline(
# Realtime services don't implement run_inference, so create a
# separate text LLM for variable extraction and other out-of-band
# inference calls.
inference_llm = create_llm_service(user_config)
inference_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
stt = create_stt_service(
user_config,
audio_config,
keyterms=keyterms,
correlation_id=mps_correlation_id,
)
tts = create_tts_service(
user_config,
audio_config,
correlation_id=mps_correlation_id,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = None
# Stamp the providers/models actually resolved for this run onto
@ -695,7 +720,10 @@ async def _run_pipeline(
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
if voicemail_config.get("use_workflow_llm", True):
voicemail_llm = create_llm_service(user_config)
voicemail_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
voicemail_llm = create_llm_service_from_provider(
provider=voicemail_config.get("provider", "openai"),

View file

@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
user_config,
audio_config: "AudioConfig",
keyterms: list[str] | None = None,
correlation_id: str | None = None,
):
"""Create and return appropriate STT service based on user configuration
@ -160,6 +163,7 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
correlation_id=correlation_id,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
@ -286,7 +290,9 @@ def create_stt_service(
)
def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_tts_service(
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
):
"""Create and return appropriate TTS service based on user configuration
Args:
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
correlation_id=correlation_id,
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
model: str,
api_key: str | None,
*,
correlation_id: str | None = None,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=api_key,
correlation_id=correlation_id,
settings=OpenAILLMSettings(model=model),
)
elif provider == ServiceProviders.AWS_BEDROCK.value:
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
)
def create_llm_service(user_config):
def create_llm_service(user_config, correlation_id: str | None = None):
"""Create and return appropriate LLM service based on user configuration."""
provider = user_config.llm.provider
model = user_config.llm.model
@ -880,4 +889,10 @@ def create_llm_service(user_config):
elif provider == ServiceProviders.SARVAM.value:
kwargs["temperature"] = user_config.llm.temperature
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
return create_llm_service_from_provider(
provider,
model,
api_key,
correlation_id=correlation_id,
**kwargs,
)

View file

@ -35,6 +35,7 @@ import asyncio
from loguru import logger
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_context_composer import (
@ -382,6 +383,9 @@ class PipecatEngine:
embeddings_provider=self._embeddings_provider,
embeddings_endpoint=self._embeddings_endpoint,
embeddings_api_version=self._embeddings_api_version,
correlation_id=self._call_context_vars.get(
MPS_CORRELATION_ID_CONTEXT_KEY
),
tracing_context=self._get_otel_context(),
)

View file

@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn(
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
llm = create_llm_service(user_config)
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
base_initial_context = dict(workflow_run.initial_context or {})
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=base_initial_context,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = llm
runtime_configuration = {
@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn(
"llm_model": user_config.llm.model,
}
initial_context = {
**(workflow_run.initial_context or {}),
**base_initial_context,
"runtime_configuration": runtime_configuration,
}
if mps_correlation_id:
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)

View file

@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
tracing_context=None,
) -> Dict[str, Any]:
"""Retrieve relevant information from the knowledge base using vector similarity search.
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Create span with parent context
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Add result metadata to span
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
else:
# Tracing is disabled - perform retrieval without tracing
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
@ -220,6 +225,7 @@ async def _perform_retrieval(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Internal function to perform the actual retrieval operation.
@ -272,11 +278,20 @@ async def _perform_retrieval(
api_version=embeddings_api_version or "2024-02-15-preview",
)
else:
default_headers = None
if (
embeddings_provider == ServiceProviders.DOGRAH.value
and correlation_id
):
default_headers = {
"X-Dograh-Correlation-Id": correlation_id,
}
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
default_headers=default_headers,
)
results = await embedding_service.search_similar_chunks(

View file

@ -166,18 +166,22 @@ async def process_knowledge_base_document(
user_id=document.created_by,
organization_id=document.organization_id,
)
user_config = resolved_config.effective
if user_config.embeddings:
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
effective_config = resolved_config.effective
if effective_config.embeddings:
embeddings_provider = getattr(
effective_config.embeddings, "provider", None
)
embeddings_api_key = effective_config.embeddings.api_key
embeddings_model = effective_config.embeddings.model
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
base_url=getattr(effective_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(
effective_config.embeddings, "endpoint", None
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
effective_config.embeddings, "api_version", None
)
logger.info(
f"Using user embeddings config: provider={embeddings_provider}, "

View file

@ -203,7 +203,7 @@ async def create_workflow_run_rows(
Returns:
Tuple of (workflow_run, user, workflow).
"""
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
async_session.add(org)

View file

@ -3,10 +3,10 @@ from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
@ -49,6 +49,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
assert effective.embeddings.model == "default"
assert effective.managed_service_version == 2
def test_dograh_v2_rejects_non_predefined_speed():

View file

@ -0,0 +1,110 @@
import json
import pytest
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.frames.frames import TTSStartedFrame
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.openai.base_llm import OpenAILLMSettings
from websockets.protocol import State
class _FakeWebSocket:
def __init__(self):
self.state = State.OPEN
self.messages: list[dict] = []
async def send(self, message: str) -> None:
self.messages.append(json.loads(message))
async def close(self, *args, **kwargs) -> None:
self.state = State.CLOSED
def test_dograh_llm_uses_explicit_mps_correlation_id():
service = DograhLLMService(
api_key="mps-secret",
correlation_id="mps-corr-123",
settings=OpenAILLMSettings(model="default"),
)
service._start_metadata = {"workflow_run_id": 99}
params = service.build_chat_completion_params(
{
"messages": [],
"tools": OPENAI_NOT_GIVEN,
"tool_choice": OPENAI_NOT_GIVEN,
}
)
assert params["metadata"]["correlation_id"] == "mps-corr-123"
assert params["metadata"]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.stt.websocket_connect",
fake_connect,
)
service = DograhSTTService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=16000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.tts.websocket_connect",
fake_connect,
)
service = DograhTTSService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=24000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
async def _noop(*args, **kwargs):
return None
service.audio_context_available = lambda context_id: False
service.create_audio_context = _noop
service.start_ttfb_metrics = _noop
service.start_tts_usage_metrics = _noop
frames = []
async for frame in service.run_tts("hello", "ctx-1"):
frames.append(frame)
assert isinstance(frames[0], TTSStartedFrame)
assert fake_ws.messages[1]["type"] == "create_context"
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[1]["mps_billing_version"] == "2"

View file

@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.xai.realtime import events
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
def test_factory_creates_dograh_grok_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=GrokRealtimeLLMConfiguration(
provider="grok_realtime",
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -5,7 +5,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.routes.user import router
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.depends import get_user
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (

View file

@ -87,3 +87,44 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
"Content-Type": "application/json",
},
)
@pytest.mark.asyncio
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"correlation_id": "mps-corr-123"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
client = MPSServiceKeyClient()
assert await client.create_correlation_id(
service_key="mps_sk_paid",
workflow_run_id=42,
) == {"correlation_id": "mps-corr-123"}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
{"workflow_run_id": 42},
{
"Authorization": "Bearer mps_sk_paid",
"Content-Type": "application/json",
},
)
]

View file

@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
import pytest
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
contains_masked_key,
mask_workflow_configurations,

View file

@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
from websockets.exceptions import ConnectionClosedError
from websockets.frames import Close
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
from api.services.pipecat.realtime.ultravox_realtime import (
_RESUMPTION_USER_MESSAGE,
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
def test_factory_creates_dograh_ultravox_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=UltravoxRealtimeLLMConfiguration(
provider="ultravox_realtime",
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from api.db.models import OrganizationModel, UserModel
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
from pipecat.tests import MockLLMService