dograh/api/tests/test_ai_model_configuration_v2.py
Abhishek 1f1149f4d5
feat: billing and credit management v2 (#429)
* feat: use mps generated correlation ID

* chore: update pipecat submodule

* feat: add credit purchase URL

* feat: carve out billing page and show credit ledger

* feat: deprecate dograh based quota tracking

* fix: remove cost calculation from dograh codebase

* fix: create mps account on migrate to v2

* chore: update pipecat
2026-06-12 14:55:30 +05:30

459 lines
14 KiB
Python

from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
DograhLLMService,
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
GoogleLLMService,
GoogleRealtimeLLMConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
    """Compile a Dograh-managed v2 config and inspect the effective result.

    The compiled configuration must be a non-realtime pipeline whose LLM,
    TTS, STT and embeddings all use the "dograh" provider, carry over the
    requested speed and language, and report managed service version 2.
    """
    managed = DograhManagedAIModelConfiguration(
        api_key="mps-secret",
        voice="default",
        speed=1.2,
        language="multi",
    )
    org_config = OrganizationAIModelConfigurationV2(mode="dograh", dograh=managed)

    compiled = compile_ai_model_configuration_v2(org_config)

    assert compiled.is_realtime is False

    # LLM
    assert compiled.llm.provider == "dograh"
    assert compiled.llm.model == "default"

    # TTS keeps the configured speed
    assert compiled.tts.provider == "dograh"
    assert compiled.tts.speed == 1.2

    # STT keeps the configured language
    assert compiled.stt.provider == "dograh"
    assert compiled.stt.language == "multi"

    # Embeddings
    assert compiled.embeddings.provider == "dograh"
    assert compiled.embeddings.model == "default"

    assert compiled.managed_service_version == 2
def test_dograh_v2_rejects_non_predefined_speed():
    """A Dograh-managed config with speed=1.5 must fail pydantic validation."""
    with pytest.raises(ValidationError):
        # Both constructors stay inside the context manager: this test does
        # not pin down which of the two models rejects the value.
        managed = DograhManagedAIModelConfiguration(
            api_key="mps-secret",
            speed=1.5,
        )
        OrganizationAIModelConfigurationV2(mode="dograh", dograh=managed)
def test_byok_v2_rejects_dograh_provider():
    """BYOK mode must not validate a pipeline whose services use "dograh"."""
    dograh_pipeline = {
        "llm": {
            "provider": "dograh",
            "api_key": "mps-secret",
            "model": "default",
        },
        "tts": {
            "provider": "dograh",
            "api_key": "mps-secret",
            "model": "default",
            "voice": "default",
        },
        "stt": {
            "provider": "dograh",
            "api_key": "mps-secret",
            "model": "default",
        },
    }
    payload = {
        "mode": "byok",
        "byok": {
            "mode": "pipeline",
            "pipeline": dograh_pipeline,
        },
    }

    with pytest.raises(ValidationError):
        OrganizationAIModelConfigurationV2.model_validate(payload)
@pytest.mark.asyncio
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
    """A BYOK realtime config compiles without STT/TTS and still validates.

    The payload supplies only a realtime service and an LLM. The compiled
    configuration must be flagged realtime, have no STT or TTS, and the
    validator must return the all-ok status.
    """
    realtime_section = {
        "realtime": {
            "provider": "google_realtime",
            "api_key": "google-realtime-key",
            "model": "gemini-3.1-flash-live-preview",
            "voice": "Puck",
            "language": "en",
        },
        "llm": {
            "provider": "google",
            "api_key": "google-llm-key",
            "model": "gemini-2.0-flash",
        },
    }
    payload = {
        "mode": "byok",
        "byok": {
            "mode": "realtime",
            "realtime": realtime_section,
        },
    }
    org_config = OrganizationAIModelConfigurationV2.model_validate(payload)

    compiled = compile_ai_model_configuration_v2(org_config)

    assert compiled.is_realtime is True
    assert compiled.stt is None
    assert compiled.tts is None

    validation_result = await UserConfigurationValidator().validate(compiled)
    assert validation_result == {"status": [{"model": "all", "message": "ok"}]}
@pytest.mark.asyncio
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
    """With is_realtime=False the validator reports missing STT and TTS.

    The configuration carries an LLM and a realtime service but no STT or
    TTS. Validation must raise ValueError whose first argument lists one
    "API key is missing" entry for "stt" and one for "tts", in that order.
    """
    pipeline_config = EffectiveAIModelConfiguration(
        llm=GoogleLLMService(
            provider="google",
            api_key="google-llm-key",
            model="gemini-2.0-flash",
        ),
        realtime=GoogleRealtimeLLMConfiguration(
            provider="google_realtime",
            api_key="google-realtime-key",
            model="gemini-3.1-flash-live-preview",
            voice="Puck",
            language="en",
        ),
        is_realtime=False,
    )
    expected_errors = [
        {"model": "stt", "message": "API key is missing"},
        {"model": "tts", "message": "API key is missing"},
    ]

    with pytest.raises(ValueError) as raised:
        await UserConfigurationValidator().validate(pipeline_config)

    # Checked after the `with` block: a statement placed after the raising
    # call inside the block would never execute.
    assert raised.value.args[0] == expected_errors
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
    """Merging a masked incoming key with the stored config keeps the real key.

    The incoming configuration carries the masked form of the stored API key;
    after the merge the Dograh API key must equal the stored, unmasked value.
    """
    real_key = "mps-real-secret"
    stored = OrganizationAIModelConfigurationV2(
        mode="dograh",
        dograh=DograhManagedAIModelConfiguration(api_key=real_key),
    )
    submitted = OrganizationAIModelConfigurationV2(
        mode="dograh",
        dograh=DograhManagedAIModelConfiguration(api_key=mask_key(real_key)),
    )

    merged = merge_ai_model_configuration_v2_secrets(submitted, stored)

    assert merged.dograh.api_key == real_key
    # The masked-key check is run on the merged result; the test only
    # requires that this call completes without raising.
    check_for_masked_keys_in_ai_model_configuration_v2(merged)
def test_masked_v2_configuration_masks_nested_service_keys():
    """Masking a BYOK pipeline config masks the API key of each service.

    For the LLM, TTS and STT entries under ``byok.pipeline`` the masked
    output must hold ``mask_key(<original secret>)`` as the ``api_key``.
    """
    org_config = OrganizationAIModelConfigurationV2(
        mode="byok",
        byok={
            "mode": "pipeline",
            "pipeline": {
                "llm": {
                    "provider": "openai",
                    "api_key": "sk-real-secret",
                    "model": "gpt-4.1",
                },
                "tts": {
                    "provider": "elevenlabs",
                    "api_key": "el-real-secret",
                    "model": "eleven_flash_v2_5",
                    "voice": "Rachel",
                },
                "stt": {
                    "provider": "deepgram",
                    "api_key": "dg-real-secret",
                    "model": "nova-3-general",
                },
            },
        },
    )

    masked = mask_ai_model_configuration_v2(org_config)

    masked_pipeline = masked["byok"]["pipeline"]
    secrets_by_service = (
        ("llm", "sk-real-secret"),
        ("tts", "el-real-secret"),
        ("stt", "dg-real-secret"),
    )
    for service_name, secret in secrets_by_service:
        assert masked_pipeline[service_name]["api_key"] == mask_key(secret)
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
    """A legacy config with all-Dograh services converts to v2 "dograh" mode.

    Every legacy service passes its API key as a one-element list; the
    converted configuration must expose that key as a plain string.
    """
    legacy_llm = DograhLLMService(
        provider="dograh",
        api_key=["mps-secret"],
        model="default",
    )
    legacy_tts = DograhTTSService(
        provider="dograh",
        api_key=["mps-secret"],
        model="default",
        voice="default",
        speed=1.0,
    )
    legacy_stt = DograhSTTService(
        provider="dograh",
        api_key=["mps-secret"],
        model="default",
        language="multi",
    )
    legacy_config = EffectiveAIModelConfiguration(
        llm=legacy_llm,
        tts=legacy_tts,
        stt=legacy_stt,
    )

    converted = convert_legacy_ai_model_configuration_to_v2(legacy_config)

    assert converted.mode == "dograh"
    assert converted.dograh.api_key == "mps-secret"
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
    """A legacy config mixing OpenAI and Dograh services converts to "dograh".

    The LLM and embeddings are OpenAI while TTS and STT are Dograh, each with
    a different key. The converted configuration must be in "dograh" mode,
    use the TTS service's API key, and have the "default" voice.
    """
    legacy_llm = OpenAILLMService(
        provider="openai",
        api_key="sk-llm",
        model="gpt-4.1",
    )
    legacy_tts = DograhTTSService(
        provider="dograh",
        api_key="mps-tts",
        model="default",
        voice="default",
    )
    legacy_stt = DograhSTTService(
        provider="dograh",
        api_key="mps-stt",
        model="default",
    )
    legacy_embeddings = OpenAIEmbeddingsConfiguration(
        provider="openai",
        api_key="sk-emb",
        model="text-embedding-3-small",
    )
    legacy_config = EffectiveAIModelConfiguration(
        llm=legacy_llm,
        tts=legacy_tts,
        stt=legacy_stt,
        embeddings=legacy_embeddings,
    )

    converted = convert_legacy_ai_model_configuration_to_v2(legacy_config)

    assert converted.mode == "dograh"
    assert converted.dograh.api_key == "mps-tts"
    assert converted.dograh.voice == "default"
def test_legacy_byok_pipeline_converts_to_byok_v2():
    """A legacy config with no Dograh service converts to a BYOK pipeline.

    The converted configuration must be in "byok" mode with a "pipeline"
    sub-mode, keeping OpenAI as the LLM provider and ElevenLabs for TTS.
    """
    legacy_llm = OpenAILLMService(
        provider="openai",
        api_key="sk-llm",
        model="gpt-4.1",
    )
    legacy_tts = ElevenlabsTTSConfiguration(
        provider="elevenlabs",
        api_key="el-tts",
        model="eleven_flash_v2_5",
        voice="Rachel",
    )
    legacy_stt = DeepgramSTTConfiguration(
        provider="deepgram",
        api_key="dg-stt",
        model="nova-3-general",
    )
    legacy_embeddings = OpenAIEmbeddingsConfiguration(
        provider="openai",
        api_key="sk-emb",
        model="text-embedding-3-small",
    )
    legacy_config = EffectiveAIModelConfiguration(
        llm=legacy_llm,
        tts=legacy_tts,
        stt=legacy_stt,
        embeddings=legacy_embeddings,
    )

    converted = convert_legacy_ai_model_configuration_to_v2(legacy_config)

    assert converted.mode == "byok"
    byok_section = converted.byok
    assert byok_section.mode == "pipeline"
    assert byok_section.pipeline.llm.provider == "openai"
    assert byok_section.pipeline.tts.provider == "elevenlabs"
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
    """Migrating a workflow with a v1 TTS override yields a v2 override.

    The workflow overrides only TTS with a Dograh service on top of a BYOK
    base. After migration the "model_overrides" key must be gone, unrelated
    settings must be untouched, and the v2 override must be in "dograh" mode
    with the workflow's API key.
    """
    base_config = EffectiveAIModelConfiguration(
        llm=OpenAILLMService(
            provider="openai",
            api_key="sk-llm",
            model="gpt-4.1",
        ),
        tts=ElevenlabsTTSConfiguration(
            provider="elevenlabs",
            api_key="el-tts",
            model="eleven_flash_v2_5",
            voice="Rachel",
        ),
        stt=DeepgramSTTConfiguration(
            provider="deepgram",
            api_key="dg-stt",
            model="nova-3-general",
        ),
    )
    v1_tts_override = {
        "provider": "dograh",
        "api_key": "mps-workflow",
        "model": "default",
        "voice": "default",
    }
    workflow_settings = {
        "ambient_noise_configuration": {"enabled": False},
        "model_overrides": {"tts": v1_tts_override},
    }

    migrated, changed = migrate_workflow_configuration_model_override_to_v2(
        workflow_settings,
        base_config,
    )

    assert changed is True
    assert "model_overrides" not in migrated
    assert migrated["ambient_noise_configuration"] == {"enabled": False}

    override_v2 = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
    assert override_v2["mode"] == "dograh"
    assert override_v2["dograh"]["api_key"] == "mps-workflow"
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
    """A "model_overrides" entry set to None is dropped by the migration.

    The migration must report a change, remove the key, and leave the
    unrelated ambient-noise setting as it was.
    """
    empty_base = EffectiveAIModelConfiguration()
    workflow_settings = {
        "ambient_noise_configuration": {"enabled": False},
        "model_overrides": None,
    }

    migrated, changed = migrate_workflow_configuration_model_override_to_v2(
        workflow_settings,
        empty_base,
    )

    assert changed is True
    assert "model_overrides" not in migrated
    assert migrated["ambient_noise_configuration"] == {"enabled": False}
@pytest.mark.asyncio
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
    monkeypatch,
):
    """The v2 migration route sets up hosted MPS billing before persisting.

    Every collaborator of the route is patched: deployment mode is "saas",
    no v2 configuration exists yet, and the stored user configuration is a
    legacy all-Dograh pipeline. The route must then await the billing helper
    once for organization 42 with the user's provider id, upsert the v2
    configuration once, migrate workflows with the legacy configuration as
    fallback, and return the response produced by the response builder.
    """
    from api.routes import organization as organization_routes

    legacy_config = EffectiveAIModelConfiguration(
        llm=DograhLLMService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
        ),
        tts=DograhTTSService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
            voice="default",
        ),
        stt=DograhSTTService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
        ),
    )
    canned_response = OrganizationAIModelConfigurationResponse(
        configuration={"version": 2, "mode": "dograh"},
        effective_configuration={},
        source="organization_v2",
    )

    class FakeValidator:
        """Validator stand-in whose ``validate`` always reports success."""

        async def validate(self, *args, **kwargs):
            return {"status": [{"model": "all", "message": "ok"}]}

    billing_mock = AsyncMock(return_value={"billing_mode": "v2"})
    upsert_mock = AsyncMock()
    workflow_migration_mock = AsyncMock()

    # Attributes replaced directly on the routes module.
    route_patches = {
        "DEPLOYMENT_MODE": "saas",
        "get_organization_ai_model_configuration_v2": AsyncMock(return_value=None),
        # The class itself serves as the zero-argument factory.
        "UserConfigurationValidator": FakeValidator,
        "ensure_hosted_mps_billing_account_v2": billing_mock,
        "upsert_organization_ai_model_configuration_v2": upsert_mock,
        "migrate_workflow_model_configurations_to_v2": workflow_migration_mock,
        "_model_configuration_v2_response": AsyncMock(return_value=canned_response),
    }
    for attr_name, replacement in route_patches.items():
        monkeypatch.setattr(organization_routes, attr_name, replacement)

    # The legacy configuration is served through the module's db client.
    monkeypatch.setattr(
        organization_routes.db_client,
        "get_user_configurations",
        AsyncMock(return_value=legacy_config),
    )

    requesting_user = SimpleNamespace(
        id=7,
        provider_id="provider-123",
        selected_organization_id=42,
    )

    response = await organization_routes.migrate_model_configuration_v2(
        force=False,
        user=requesting_user,
    )

    billing_mock.assert_awaited_once_with(42, created_by="provider-123")
    upsert_mock.assert_awaited_once()
    workflow_migration_mock.assert_awaited_once_with(
        organization_id=42,
        fallback_user_config=legacy_config,
    )
    assert response == canned_response