feat: use mps generated correlation ID

This commit is contained in:
Abhishek Kumar 2026-06-09 18:24:40 +05:30
parent 91ac460799
commit 3336c6e794
30 changed files with 453 additions and 89 deletions

View file

@ -9,7 +9,7 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.posthog_client import capture_event
@ -285,8 +285,8 @@ async def create_user_configuration_with_mps_key(
"model": "default",
},
}
user_config = EffectiveAIModelConfiguration(**configuration)
return user_config
effective_config = EffectiveAIModelConfiguration(**configuration)
return effective_config
else:
logger.warning(
f"Failed to get MPS service key: {response.status_code} - {response.text}"

View file

@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import (
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,

View file

@ -8,7 +8,7 @@ from groq import Groq
# from pyneuphonic import Neuphonic
# except ImportError:
# Neuphonic = None
from api.schemas.user_configuration import (
from api.schemas.ai_model_configuration import (
EffectiveAIModelConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders

View file

@ -12,7 +12,7 @@ The rules are simple:
import copy
from typing import Any, Dict, Optional
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceConfig
from api.services.integrations import get_node_secret_fields

View file

@ -7,7 +7,7 @@ stored, while honouring masked API keys.
import copy
from typing import Dict
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
MODEL_OVERRIDE_FIELDS,
SERVICE_SECRET_FIELDS,

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import copy
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
REGISTRY,
ServiceType,

View file

@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
base_url: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""Initialize the OpenAI embedding service.
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
field_name="base_url",
)
client_kwargs["base_url"] = base_url
if default_headers:
client_kwargs["default_headers"] = default_headers
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
else:

View file

@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Any
from loguru import logger
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
# Key under which a run's MPS correlation ID is looked up in (and, by callers,
# stored into) the workflow run's initial-context dict.
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
def uses_managed_model_services_v2(
    ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
    """Report whether the configuration routes any service through Dograh v2.

    Args:
        ai_model_config: The effective AI model configuration, or ``None``.

    Returns:
        ``True`` only when the configuration exists, its
        ``managed_service_version`` attribute equals ``2``, and at least one
        of the ``llm``/``tts``/``stt``/``embeddings`` sections is a Dograh
        service; ``False`` otherwise.
    """
    if ai_model_config is None:
        return False

    # A missing attribute is treated the same as a non-v2 version.
    if getattr(ai_model_config, "managed_service_version", None) != 2:
        return False

    for section_name in ("llm", "tts", "stt", "embeddings"):
        if _is_dograh_service(getattr(ai_model_config, section_name, None)):
            return True
    return False
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
    """Read the MPS correlation ID out of an initial-context mapping.

    Args:
        initial_context: The context dict to inspect; may be ``None`` or empty.

    Returns:
        The value stored under ``MPS_CORRELATION_ID_CONTEXT_KEY`` converted
        with ``str()``, or ``None`` when the context is empty/``None`` or the
        stored value is ``None``.
    """
    stored_id = (
        initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
        if initial_context
        else None
    )
    return None if stored_id is None else str(stored_id)
async def ensure_mps_correlation_id(
    *,
    ai_model_config: EffectiveAIModelConfiguration,
    workflow_run_id: int,
    initial_context: dict[str, Any] | None,
) -> str | None:
    """Return the run's MPS correlation ID, minting one when required.

    This function does not write to ``initial_context``; storing the returned
    ID is left to the caller.

    Args:
        ai_model_config: Effective AI model configuration for the run.
        workflow_run_id: ID of the workflow run, forwarded to the mint call.
        initial_context: Context dict that may already hold a correlation ID.

    Returns:
        The existing ID from ``initial_context`` when one is present and
        truthy; ``None`` when the configuration does not use managed model
        services v2; otherwise the newly minted ID as a string.

    Raises:
        ValueError: If no Dograh service key can be found in the
            configuration, or the mint response has no ``correlation_id``.
    """
    # Reuse an ID that was stored on a previous invocation.
    already_minted = get_mps_correlation_id(initial_context)
    if already_minted:
        return already_minted

    if not uses_managed_model_services_v2(ai_model_config):
        return None

    dograh_key = _get_dograh_service_api_key(ai_model_config)
    if not dograh_key:
        raise ValueError(
            "Managed model services v2 requires a Dograh service key before the run starts."
        )

    mint_response = await mps_service_key_client.create_correlation_id(
        service_key=dograh_key,
        workflow_run_id=workflow_run_id,
    )

    minted = mint_response.get("correlation_id")
    if not minted:
        raise ValueError("MPS correlation-id response did not include correlation_id")

    minted_id = str(minted)
    logger.info(
        "Minted MPS correlation id {} for workflow run {}",
        minted_id,
        workflow_run_id,
    )
    return minted_id
def _is_dograh_service(service: Any) -> bool:
    """Tell whether a service section's ``provider`` is Dograh.

    The provider is matched against both the ``ServiceProviders.DOGRAH``
    member and its ``.value``. A service without a ``provider`` attribute
    (including ``None``) is compared as provider ``None``.
    """
    dograh = ServiceProviders.DOGRAH
    provider = getattr(service, "provider", None)
    return provider == dograh or provider == dograh.value
def _get_dograh_service_api_key(
    ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
    """Find an API key from the first Dograh section that has one.

    Sections are scanned in the order ``llm``, ``tts``, ``stt``,
    ``embeddings``. For each Dograh section, the first entry of
    ``get_all_api_keys()`` is preferred when that method exists and returns a
    non-empty result; otherwise a non-empty string ``api_key`` attribute is
    used.

    Returns:
        The key that was found, or ``None`` when no Dograh section yields one.
    """
    for section_name in ("llm", "tts", "stt", "embeddings"):
        candidate = getattr(ai_model_config, section_name, None)
        if _is_dograh_service(candidate):
            if hasattr(candidate, "get_all_api_keys"):
                all_keys = candidate.get_all_api_keys()
                if all_keys:
                    return all_keys[0]

            # Fall back to the single-key attribute.
            single_key = getattr(candidate, "api_key", None)
            if isinstance(single_key, str) and single_key:
                return single_key
    return None

View file

@ -353,6 +353,40 @@ class MPSServiceKeyClient:
response=response,
)
async def create_correlation_id(
    self,
    *,
    service_key: str,
    workflow_run_id: int | None = None,
) -> dict:
    """Mint a server-generated correlation ID for managed model services.

    Sends a POST to ``/api/v1/service-keys/correlation-id/self`` under
    ``self.base_url``, authenticated with ``service_key`` as a bearer token.

    Args:
        service_key: Service key placed in the ``Authorization`` header.
        workflow_run_id: Optional run ID; included in the JSON body only when
            it is not ``None``.

    Returns:
        The decoded JSON body of a ``200`` response.

    Raises:
        httpx.HTTPStatusError: If the response status is anything but ``200``.
    """
    body: dict[str, int] = (
        {} if workflow_run_id is None else {"workflow_run_id": workflow_run_id}
    )
    request_headers = {
        "Authorization": f"Bearer {service_key}",
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient(timeout=self.timeout) as client:
        response = await client.post(
            f"{self.base_url}/api/v1/service-keys/correlation-id/self",
            json=body,
            headers=request_headers,
        )

        if response.status_code == 200:
            return response.json()

        # Any other status is logged and surfaced to the caller.
        logger.error(
            f"Failed to create correlation ID: {response.status_code} - {response.text}"
        )
        raise httpx.HTTPStatusError(
            f"Failed to create correlation ID: {response.text}",
            request=response.request,
            response=response,
        )
async def transcribe_audio(
self,
audio_data: bytes,

View file

@ -340,7 +340,7 @@ async def _run_pipeline(
if workflow_run.is_completed:
raise HTTPException(status_code=400, detail="Workflow run already completed")
merged_call_context_vars = workflow_run.initial_context
merged_call_context_vars = dict(workflow_run.initial_context or {})
# If there is some extra call_context_vars, fold them in. Persistence
# happens once below, after runtime_configuration is also resolved.
if call_context_vars:
@ -398,6 +398,19 @@ async def _run_pipeline(
else:
user_config = resolved_user_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=merged_call_context_vars,
)
if mps_correlation_id:
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
is_realtime = user_config.is_realtime and user_config.realtime is not None
@ -409,11 +422,23 @@ async def _run_pipeline(
# Realtime services don't implement run_inference, so create a
# separate text LLM for variable extraction and other out-of-band
# inference calls.
inference_llm = create_llm_service(user_config)
inference_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
stt = create_stt_service(
user_config,
audio_config,
keyterms=keyterms,
correlation_id=mps_correlation_id,
)
tts = create_tts_service(
user_config,
audio_config,
correlation_id=mps_correlation_id,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = None
# Stamp the providers/models actually resolved for this run onto
@ -695,7 +720,10 @@ async def _run_pipeline(
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
if voicemail_config.get("use_workflow_llm", True):
voicemail_llm = create_llm_service(user_config)
voicemail_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
voicemail_llm = create_llm_service_from_provider(
provider=voicemail_config.get("provider", "openai"),

View file

@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
user_config,
audio_config: "AudioConfig",
keyterms: list[str] | None = None,
correlation_id: str | None = None,
):
"""Create and return appropriate STT service based on user configuration
@ -160,6 +163,7 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
correlation_id=correlation_id,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
@ -286,7 +290,9 @@ def create_stt_service(
)
def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_tts_service(
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
):
"""Create and return appropriate TTS service based on user configuration
Args:
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
correlation_id=correlation_id,
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
model: str,
api_key: str | None,
*,
correlation_id: str | None = None,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=api_key,
correlation_id=correlation_id,
settings=OpenAILLMSettings(model=model),
)
elif provider == ServiceProviders.AWS_BEDROCK.value:
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
)
def create_llm_service(user_config):
def create_llm_service(user_config, correlation_id: str | None = None):
"""Create and return appropriate LLM service based on user configuration."""
provider = user_config.llm.provider
model = user_config.llm.model
@ -880,4 +889,10 @@ def create_llm_service(user_config):
elif provider == ServiceProviders.SARVAM.value:
kwargs["temperature"] = user_config.llm.temperature
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
return create_llm_service_from_provider(
provider,
model,
api_key,
correlation_id=correlation_id,
**kwargs,
)

View file

@ -35,6 +35,7 @@ import asyncio
from loguru import logger
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_context_composer import (
@ -382,6 +383,9 @@ class PipecatEngine:
embeddings_provider=self._embeddings_provider,
embeddings_endpoint=self._embeddings_endpoint,
embeddings_api_version=self._embeddings_api_version,
correlation_id=self._call_context_vars.get(
MPS_CORRELATION_ID_CONTEXT_KEY
),
tracing_context=self._get_otel_context(),
)

View file

@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn(
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
llm = create_llm_service(user_config)
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
base_initial_context = dict(workflow_run.initial_context or {})
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=base_initial_context,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = llm
runtime_configuration = {
@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn(
"llm_model": user_config.llm.model,
}
initial_context = {
**(workflow_run.initial_context or {}),
**base_initial_context,
"runtime_configuration": runtime_configuration,
}
if mps_correlation_id:
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)

View file

@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
tracing_context=None,
) -> Dict[str, Any]:
"""Retrieve relevant information from the knowledge base using vector similarity search.
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Create span with parent context
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Add result metadata to span
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
else:
# Tracing is disabled - perform retrieval without tracing
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
@ -220,6 +225,7 @@ async def _perform_retrieval(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Internal function to perform the actual retrieval operation.
@ -272,11 +278,20 @@ async def _perform_retrieval(
api_version=embeddings_api_version or "2024-02-15-preview",
)
else:
default_headers = None
if (
embeddings_provider == ServiceProviders.DOGRAH.value
and correlation_id
):
default_headers = {
"X-Dograh-Correlation-Id": correlation_id,
}
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
default_headers=default_headers,
)
results = await embedding_service.search_similar_chunks(