mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
feat: add config v2 to simplify billing (#428)
* feat: add model config v2 * chore: centralize user org selection * chore: move preferences to platform settings * fix: decouple org preference and ai model preferences
This commit is contained in:
parent
49e68b49d5
commit
cdbd06c8d9
42 changed files with 5135 additions and 264 deletions
295
api/tests/test_ai_model_configuration_v2.py
Normal file
295
api/tests/test_ai_model_configuration_v2.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_configuration_model_override_to_v2,
|
||||
)
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
DograhLLMService,
|
||||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
|
||||
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
||||
config = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key="mps-secret",
|
||||
voice="default",
|
||||
speed=1.2,
|
||||
language="multi",
|
||||
),
|
||||
)
|
||||
|
||||
effective = compile_ai_model_configuration_v2(config)
|
||||
|
||||
assert effective.is_realtime is False
|
||||
assert effective.llm.provider == "dograh"
|
||||
assert effective.llm.model == "default"
|
||||
assert effective.tts.provider == "dograh"
|
||||
assert effective.tts.speed == 1.2
|
||||
assert effective.stt.provider == "dograh"
|
||||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
|
||||
with pytest.raises(ValidationError):
|
||||
OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key="mps-secret",
|
||||
speed=1.5,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_byok_v2_rejects_dograh_provider():
|
||||
with pytest.raises(ValidationError):
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
{
|
||||
"mode": "byok",
|
||||
"byok": {
|
||||
"mode": "pipeline",
|
||||
"pipeline": {
|
||||
"llm": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
},
|
||||
"tts": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
"voice": "default",
|
||||
},
|
||||
"stt": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
|
||||
existing = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(api_key="mps-real-secret"),
|
||||
)
|
||||
incoming = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(api_key=mask_key("mps-real-secret")),
|
||||
)
|
||||
|
||||
merged = merge_ai_model_configuration_v2_secrets(incoming, existing)
|
||||
|
||||
assert merged.dograh.api_key == "mps-real-secret"
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(merged)
|
||||
|
||||
|
||||
def test_masked_v2_configuration_masks_nested_service_keys():
|
||||
config = OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok={
|
||||
"mode": "pipeline",
|
||||
"pipeline": {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"api_key": "sk-real-secret",
|
||||
"model": "gpt-4.1",
|
||||
},
|
||||
"tts": {
|
||||
"provider": "elevenlabs",
|
||||
"api_key": "el-real-secret",
|
||||
"model": "eleven_flash_v2_5",
|
||||
"voice": "Rachel",
|
||||
},
|
||||
"stt": {
|
||||
"provider": "deepgram",
|
||||
"api_key": "dg-real-secret",
|
||||
"model": "nova-3-general",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
masked = mask_ai_model_configuration_v2(config)
|
||||
|
||||
assert masked["byok"]["pipeline"]["llm"]["api_key"] == mask_key("sk-real-secret")
|
||||
assert masked["byok"]["pipeline"]["tts"]["api_key"] == mask_key("el-real-secret")
|
||||
assert masked["byok"]["pipeline"]["stt"]["api_key"] == mask_key("dg-real-secret")
|
||||
|
||||
|
||||
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
speed=1.0,
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
language="multi",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "dograh"
|
||||
assert config.dograh.api_key == "mps-secret"
|
||||
|
||||
|
||||
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key="mps-tts",
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key="mps-stt",
|
||||
model="default",
|
||||
),
|
||||
embeddings=OpenAIEmbeddingsConfiguration(
|
||||
provider="openai",
|
||||
api_key="sk-emb",
|
||||
model="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "dograh"
|
||||
assert config.dograh.api_key == "mps-tts"
|
||||
assert config.dograh.voice == "default"
|
||||
|
||||
|
||||
def test_legacy_byok_pipeline_converts_to_byok_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-tts",
|
||||
model="eleven_flash_v2_5",
|
||||
voice="Rachel",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-stt",
|
||||
model="nova-3-general",
|
||||
),
|
||||
embeddings=OpenAIEmbeddingsConfiguration(
|
||||
provider="openai",
|
||||
api_key="sk-emb",
|
||||
model="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "byok"
|
||||
assert config.byok.mode == "pipeline"
|
||||
assert config.byok.pipeline.llm.provider == "openai"
|
||||
assert config.byok.pipeline.tts.provider == "elevenlabs"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
|
||||
base = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-tts",
|
||||
model="eleven_flash_v2_5",
|
||||
voice="Rachel",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-stt",
|
||||
model="nova-3-general",
|
||||
),
|
||||
)
|
||||
workflow_configurations = {
|
||||
"ambient_noise_configuration": {"enabled": False},
|
||||
"model_overrides": {
|
||||
"tts": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-workflow",
|
||||
"model": "default",
|
||||
"voice": "default",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations,
|
||||
base,
|
||||
)
|
||||
|
||||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
v2_override = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
|
||||
assert v2_override["mode"] == "dograh"
|
||||
assert v2_override["dograh"]["api_key"] == "mps-workflow"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
||||
base = UserConfiguration()
|
||||
workflow_configurations = {
|
||||
"ambient_noise_configuration": {"enabled": False},
|
||||
"model_overrides": None,
|
||||
}
|
||||
|
||||
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations,
|
||||
base,
|
||||
)
|
||||
|
||||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
|
@ -14,14 +15,14 @@ from api.services.configuration.registry import (
|
|||
)
|
||||
|
||||
|
||||
def _make_test_app():
|
||||
def _make_test_app(selected_organization_id=None):
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 1
|
||||
mock_user.is_superuser = False
|
||||
mock_user.selected_organization_id = None
|
||||
mock_user.selected_organization_id = selected_organization_id
|
||||
|
||||
app.dependency_overrides[get_user] = lambda: mock_user
|
||||
return app
|
||||
|
|
@ -210,3 +211,38 @@ class TestMaskedKeyRejection:
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_preference_only_update_does_not_validate_or_save_model_config(self):
|
||||
"""Saving a test phone number through the legacy endpoint must not touch models."""
|
||||
app = _make_test_app(selected_organization_id=11)
|
||||
client = TestClient(app)
|
||||
preferences = SimpleNamespace(test_phone_number=None, timezone=None)
|
||||
|
||||
with (
|
||||
patch("api.routes.user.db_client") as mock_db,
|
||||
patch("api.routes.user.UserConfigurationValidator") as mock_validator,
|
||||
patch(
|
||||
"api.routes.user.get_organization_preferences",
|
||||
new=AsyncMock(return_value=preferences),
|
||||
),
|
||||
patch(
|
||||
"api.routes.user.upsert_organization_preferences",
|
||||
new=AsyncMock(return_value=preferences),
|
||||
) as upsert_preferences,
|
||||
):
|
||||
existing = _existing_openai_config()
|
||||
mock_db.get_user_configurations = AsyncMock(return_value=existing)
|
||||
mock_db.update_user_configuration = AsyncMock()
|
||||
mock_db.get_organization_by_id = AsyncMock(return_value=None)
|
||||
mock_validator.return_value.validate = AsyncMock()
|
||||
|
||||
response = client.put(
|
||||
"/user/configurations/user",
|
||||
json={"test_phone_number": "+15551234567"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["test_phone_number"] == "+15551234567"
|
||||
mock_db.update_user_configuration.assert_not_called()
|
||||
mock_validator.return_value.validate.assert_not_called()
|
||||
upsert_preferences.assert_awaited_once()
|
||||
|
|
|
|||
|
|
@ -103,6 +103,61 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
assert initiate_kwargs["workflow_id"] == workflow.id
|
||||
assert initiate_kwargs["user_id"] == workflow.user_id
|
||||
assert "user_id=99" in initiate_kwargs["webhook_url"]
|
||||
mock_db.get_user_configurations.assert_not_called()
|
||||
|
||||
|
||||
def test_initiate_call_uses_organization_preference_phone_number():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = _workflow()
|
||||
provider = _provider()
|
||||
quota_mock = AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
"api.routes.telephony.get_default_telephony_provider",
|
||||
new=AsyncMock(return_value=provider),
|
||||
),
|
||||
patch(
|
||||
"api.routes.telephony.get_backend_endpoints",
|
||||
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
|
||||
),
|
||||
):
|
||||
mock_db.get_user_configurations = AsyncMock(
|
||||
return_value=SimpleNamespace(test_phone_number="+15550000000")
|
||||
)
|
||||
mock_db.get_configuration = Mock(
|
||||
return_value=SimpleNamespace(value={"test_phone_number": "+15557654321"})
|
||||
)
|
||||
mock_db.get_default_telephony_configuration = AsyncMock(
|
||||
return_value=SimpleNamespace(id=55)
|
||||
)
|
||||
mock_db.get_workflow = AsyncMock(return_value=workflow)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=501,
|
||||
name="WR-TEL-OUT-00000001",
|
||||
initial_context={},
|
||||
)
|
||||
)
|
||||
mock_db.update_workflow_run = AsyncMock()
|
||||
|
||||
response = client.post(
|
||||
"/telephony/initiate-call",
|
||||
json={"workflow_id": workflow.id},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert provider.initiate_call.await_args.kwargs["to_number"] == "+15557654321"
|
||||
mock_db.get_user_configurations.assert_not_called()
|
||||
|
||||
|
||||
def test_initiate_call_rejects_existing_run_for_different_workflow():
|
||||
|
|
|
|||
|
|
@ -51,6 +51,38 @@ async def _create_user_and_workflow(
|
|||
return user, workflow
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_requires_selected_organization():
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from api.app import app
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
user = UserModel(provider_id="textchat-user-no-selected-org")
|
||||
|
||||
async def mock_get_user():
|
||||
return user
|
||||
|
||||
original_override = app.dependency_overrides.get(get_user)
|
||||
app.dependency_overrides[get_user] = mock_get_user
|
||||
|
||||
try:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/workflow/123/text-chat/sessions", json={}
|
||||
)
|
||||
finally:
|
||||
if original_override:
|
||||
app.dependency_overrides[get_user] = original_override
|
||||
else:
|
||||
app.dependency_overrides.pop(get_user, None)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "No organization selected"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_executes_initial_assistant_turn(
|
||||
db_session,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue