feat: use mps generated correlation ID

This commit is contained in:
Abhishek Kumar 2026-06-09 18:24:40 +05:30
parent 91ac460799
commit 3336c6e794
30 changed files with 453 additions and 89 deletions

View file

@ -203,7 +203,7 @@ async def create_workflow_run_rows(
Returns:
Tuple of (workflow_run, user, workflow).
"""
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
async_session.add(org)

View file

@ -3,10 +3,10 @@ from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
@ -49,6 +49,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
assert effective.embeddings.model == "default"
assert effective.managed_service_version == 2
def test_dograh_v2_rejects_non_predefined_speed():

View file

@ -0,0 +1,110 @@
import json
import pytest
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.frames.frames import TTSStartedFrame
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.openai.base_llm import OpenAILLMSettings
from websockets.protocol import State
class _FakeWebSocket:
def __init__(self):
self.state = State.OPEN
self.messages: list[dict] = []
async def send(self, message: str) -> None:
self.messages.append(json.loads(message))
async def close(self, *args, **kwargs) -> None:
self.state = State.CLOSED
def test_dograh_llm_uses_explicit_mps_correlation_id():
service = DograhLLMService(
api_key="mps-secret",
correlation_id="mps-corr-123",
settings=OpenAILLMSettings(model="default"),
)
service._start_metadata = {"workflow_run_id": 99}
params = service.build_chat_completion_params(
{
"messages": [],
"tools": OPENAI_NOT_GIVEN,
"tool_choice": OPENAI_NOT_GIVEN,
}
)
assert params["metadata"]["correlation_id"] == "mps-corr-123"
assert params["metadata"]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.stt.websocket_connect",
fake_connect,
)
service = DograhSTTService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=16000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.tts.websocket_connect",
fake_connect,
)
service = DograhTTSService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=24000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
async def _noop(*args, **kwargs):
return None
service.audio_context_available = lambda context_id: False
service.create_audio_context = _noop
service.start_ttfb_metrics = _noop
service.start_tts_usage_metrics = _noop
frames = []
async for frame in service.run_tts("hello", "ctx-1"):
frames.append(frame)
assert isinstance(frames[0], TTSStartedFrame)
assert fake_ws.messages[1]["type"] == "create_context"
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[1]["mps_billing_version"] == "2"

View file

@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.xai.realtime import events
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
def test_factory_creates_dograh_grok_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=GrokRealtimeLLMConfiguration(
provider="grok_realtime",
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -5,7 +5,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.routes.user import router
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.depends import get_user
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (

View file

@ -87,3 +87,44 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
"Content-Type": "application/json",
},
)
@pytest.mark.asyncio
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"correlation_id": "mps-corr-123"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
client = MPSServiceKeyClient()
assert await client.create_correlation_id(
service_key="mps_sk_paid",
workflow_run_id=42,
) == {"correlation_id": "mps-corr-123"}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
{"workflow_run_id": 42},
{
"Authorization": "Bearer mps_sk_paid",
"Content-Type": "application/json",
},
)
]

View file

@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
import pytest
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
contains_masked_key,
mask_workflow_configurations,

View file

@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
from websockets.exceptions import ConnectionClosedError
from websockets.frames import Close
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
from api.services.pipecat.realtime.ultravox_realtime import (
_RESUMPTION_USER_MESSAGE,
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
def test_factory_creates_dograh_ultravox_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=UltravoxRealtimeLLMConfiguration(
provider="ultravox_realtime",
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from api.db.models import OrganizationModel, UserModel
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
from pipecat.tests import MockLLMService