feat: centralise workflow run authorization

This commit is contained in:
Abhishek Kumar 2026-06-12 18:16:30 +05:30
parent 5bf7518829
commit 281656b960
21 changed files with 1036 additions and 252 deletions

View file

@ -33,7 +33,9 @@ def test_create_cartesia_tts_service_passes_selected_model():
transport_in_sample_rate=16000,
)
with patch("api.services.pipecat.service_factory.CartesiaTTSService") as mock_service:
with patch(
"api.services.pipecat.service_factory.CartesiaTTSService"
) as mock_service:
create_tts_service(user_config, audio_config)
assert mock_service.call_count == 1

View file

@ -270,6 +270,12 @@ class TestDispatcherThreadsTelephonyConfig:
"api.services.campaign.campaign_call_dispatcher.get_backend_endpoints",
AsyncMock(return_value=("https://example.com", None)),
),
patch(
"api.services.campaign.campaign_call_dispatcher.authorize_workflow_run_start",
AsyncMock(
return_value=SimpleNamespace(has_quota=True, error_message="")
),
),
):
mock_db.get_workflow_by_id = AsyncMock(return_value=SimpleNamespace(id=1))
mock_db.create_workflow_run = AsyncMock(return_value=workflow_run)

View file

@ -175,6 +175,76 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
]
@pytest.mark.asyncio
async def test_authorize_workflow_run_start_uses_hosted_org_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(
200,
{
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.authorize_workflow_run_start(
organization_id=42,
workflow_run_id=88,
service_key="mps_sk_paid",
require_correlation_id=True,
minimum_credits=0.1,
metadata={"workflow_id": 7},
created_by="provider-123",
) == {
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/run-authorization",
{
"workflow_run_id": 88,
"service_key": "mps_sk_paid",
"require_correlation_id": True,
"minimum_credits": 0.1,
"metadata": {"workflow_id": 7},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
calls = []

View file

@ -57,7 +57,7 @@ def test_trigger_route_executes_as_workflow_owner():
with (
patch("api.routes.public_agent.db_client") as mock_db,
patch(
"api.routes.public_agent.check_dograh_quota_by_user_id",
"api.routes.public_agent.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -92,7 +92,10 @@ def test_trigger_route_executes_as_workflow_owner():
)
assert response.status_code == 200
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
quota_mock.assert_awaited_once_with(
workflow_id=workflow.id,
workflow_run_id=501,
)
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
@ -124,7 +127,7 @@ def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution():
with (
patch("api.routes.public_agent.db_client") as mock_db,
patch(
"api.routes.public_agent.check_dograh_quota_by_user_id",
"api.routes.public_agent.authorize_workflow_run_start",
new=quota_mock,
),
patch(

View file

@ -0,0 +1,369 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import quota_service
from api.services.configuration.registry import ServiceProviders
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
def _dograh_config(
api_key: str = "mps_sk_12345678",
*,
managed_service_version: int = 1,
):
return SimpleNamespace(
managed_service_version=managed_service_version,
llm=SimpleNamespace(provider=ServiceProviders.DOGRAH, api_key=api_key),
stt=None,
tts=None,
embeddings=None,
)
def _byok_config():
return SimpleNamespace(
managed_service_version=2,
llm=SimpleNamespace(provider="openai", api_key="sk-openai"),
stt=None,
tts=None,
embeddings=None,
)
def _workflow():
return SimpleNamespace(
id=7,
user_id=123,
organization_id=42,
workflow_configurations={"model_overrides": {}},
)
def _workflow_owner():
return SimpleNamespace(
id=123,
provider_id="provider-123",
)
def _actor():
return SimpleNamespace(
id=456,
provider_id="actor-456",
selected_organization_id=42,
)
def _patch_workflow_context(monkeypatch, *, workflow=None, owner=None):
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_by_id",
AsyncMock(return_value=workflow or _workflow()),
)
monkeypatch.setattr(
quota_service.db_client,
"get_user_by_id",
AsyncMock(return_value=owner or _workflow_owner()),
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_uses_workflow_org_for_hosted_v2(
monkeypatch,
):
get_config = AsyncMock(return_value=_dograh_config())
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
}
)
check_usage = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is True
get_config.assert_awaited_once_with(
user_id=123,
organization_id=42,
workflow_configurations={"model_overrides": {}},
)
authorize.assert_awaited_once_with(
organization_id=42,
workflow_run_id=None,
service_key=None,
require_correlation_id=False,
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by="provider-123",
metadata={"dograh_user_id": "123", "workflow_id": 7},
)
check_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_authorize_workflow_run_v2_insufficient_credits_prompts_billing(
monkeypatch,
):
get_config = AsyncMock(return_value=_byok_config())
authorize = AsyncMock(
return_value={
"allowed": False,
"billing_mode": "v2",
"remaining_credits": "0.0000",
"error": "insufficient_credits",
}
)
check_usage = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is False
assert result.error_code == "insufficient_credits"
assert "/billing" in result.error_message
assert "founders@dograh.com" not in result.error_message
authorize.assert_awaited_once()
check_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_authorize_workflow_run_v1_uses_legacy_key_usage(
monkeypatch,
):
api_key = "mps_sk_12345678"
get_config = AsyncMock(return_value=_dograh_config(api_key))
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v1",
"remaining_credits": "0.0000",
}
)
check_usage = AsyncMock(
return_value={"total_credits_used": 500.0, "remaining_credits": 0.0}
)
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is False
assert result.error_code == "quota_exceeded"
assert "founders@dograh.com" in result.error_message
assert "/billing" not in result.error_message
authorize.assert_awaited_once()
check_usage.assert_awaited_once_with(
api_key,
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation(
monkeypatch,
):
api_key = "mps_sk_12345678"
workflow_run = SimpleNamespace(initial_context={"existing": "value"})
get_config = AsyncMock(
return_value=_dograh_config(api_key, managed_service_version=2)
)
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
}
)
update_workflow_run = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
quota_service.db_client,
"update_workflow_run",
update_workflow_run,
)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
AsyncMock(),
)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
workflow_run_id=88,
)
assert result.has_quota is True
authorize.assert_awaited_once_with(
organization_id=42,
workflow_run_id=88,
service_key=api_key,
require_correlation_id=True,
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by="provider-123",
metadata={"dograh_user_id": "123", "workflow_id": 7},
)
update_workflow_run.assert_awaited_once_with(
88,
initial_context={
"existing": "value",
MPS_CORRELATION_ID_CONTEXT_KEY: "mps-corr-123",
},
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org(
monkeypatch,
):
api_key = "mps_sk_12345678"
workflow_run = SimpleNamespace(initial_context={})
get_config = AsyncMock(
return_value=_dograh_config(api_key, managed_service_version=2)
)
hosted_authorize = AsyncMock()
check_usage = AsyncMock(
return_value={"total_credits_used": 1.0, "remaining_credits": 499.0}
)
create_correlation = AsyncMock(return_value={"correlation_id": "oss-corr-123"})
update_workflow_run = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "oss")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
quota_service.db_client,
"update_workflow_run",
update_workflow_run,
)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
hosted_authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"create_correlation_id",
create_correlation,
)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
workflow_run_id=88,
)
assert result.has_quota is True
hosted_authorize.assert_not_awaited()
check_usage.assert_awaited_once_with(
api_key,
organization_id=None,
created_by="provider-123",
)
create_correlation.assert_awaited_once_with(
service_key=api_key,
workflow_run_id=88,
)
update_workflow_run.assert_awaited_once_with(
88,
initial_context={MPS_CORRELATION_ID_CONTEXT_KEY: "oss-corr-123"},
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_rejects_actor_from_another_org(monkeypatch):
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
actor_user=SimpleNamespace(selected_organization_id=999),
)
assert result.has_quota is False
assert result.error_code == "workflow_not_found"

View file

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import ANY, AsyncMock, Mock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
@ -54,7 +54,7 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -88,7 +88,11 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
)
assert response.status_code == 200
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
quota_mock.assert_awaited_once_with(
workflow_id=workflow.id,
workflow_run_id=501,
actor_user=ANY,
)
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
create_call = mock_db.create_workflow_run.await_args
@ -119,7 +123,7 @@ def test_initiate_call_uses_organization_preference_phone_number():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -173,7 +177,7 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(

View file

@ -1105,7 +1105,7 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
async with test_client_factory(user) as client:
with patch(
"api.routes.workflow_text_chat.check_dograh_quota",
"api.routes.workflow_text_chat.authorize_workflow_run_start",
new=AsyncMock(
return_value=SimpleNamespace(
has_quota=False,
@ -1120,11 +1120,16 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
assert create_response.status_code == 402
assert create_response.json()["detail"] == "Quota exceeded"
_, total_count = await db_session.get_workflow_runs_by_workflow_id(
runs, total_count = await db_session.get_workflow_runs_by_workflow_id(
workflow.id,
organization_id=workflow.organization_id,
)
assert total_count == 0
assert total_count == 1
text_session = await db_session.get_workflow_run_text_session(
runs[0].id,
organization_id=workflow.organization_id,
)
assert text_session is None
@pytest.mark.asyncio
@ -1168,7 +1173,7 @@ async def test_text_chat_append_rejects_quota_without_mutating_session(
async with test_client_factory(user) as client:
with (
patch(
"api.routes.workflow_text_chat.check_dograh_quota",
"api.routes.workflow_text_chat.authorize_workflow_run_start",
new=AsyncMock(
side_effect=[
SimpleNamespace(has_quota=True, error_message=""),