mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
feat: centralise workflow run authorization
This commit is contained in:
parent
5bf7518829
commit
281656b960
21 changed files with 1036 additions and 252 deletions
|
|
@ -33,7 +33,9 @@ def test_create_cartesia_tts_service_passes_selected_model():
|
|||
transport_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.CartesiaTTSService") as mock_service:
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.CartesiaTTSService"
|
||||
) as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
|
|
|
|||
|
|
@ -270,6 +270,12 @@ class TestDispatcherThreadsTelephonyConfig:
|
|||
"api.services.campaign.campaign_call_dispatcher.get_backend_endpoints",
|
||||
AsyncMock(return_value=("https://example.com", None)),
|
||||
),
|
||||
patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.authorize_workflow_run_start",
|
||||
AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
),
|
||||
),
|
||||
):
|
||||
mock_db.get_workflow_by_id = AsyncMock(return_value=SimpleNamespace(id=1))
|
||||
mock_db.create_workflow_run = AsyncMock(return_value=workflow_run)
|
||||
|
|
|
|||
|
|
@ -175,6 +175,76 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_start_uses_hosted_org_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
"correlation_id": "mps-corr-123",
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.authorize_workflow_run_start(
|
||||
organization_id=42,
|
||||
workflow_run_id=88,
|
||||
service_key="mps_sk_paid",
|
||||
require_correlation_id=True,
|
||||
minimum_credits=0.1,
|
||||
metadata={"workflow_id": 7},
|
||||
created_by="provider-123",
|
||||
) == {
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
"correlation_id": "mps-corr-123",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/run-authorization",
|
||||
{
|
||||
"workflow_run_id": 88,
|
||||
"service_key": "mps_sk_paid",
|
||||
"require_correlation_id": True,
|
||||
"minimum_credits": 0.1,
|
||||
"metadata": {"workflow_id": 7},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
|
||||
calls = []
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ def test_trigger_route_executes_as_workflow_owner():
|
|||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
"api.routes.public_agent.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -92,7 +92,10 @@ def test_trigger_route_executes_as_workflow_owner():
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
quota_mock.assert_awaited_once_with(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=501,
|
||||
)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
|
||||
|
|
@ -124,7 +127,7 @@ def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution():
|
|||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
"api.routes.public_agent.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
|
|||
369
api/tests/test_quota_service.py
Normal file
369
api/tests/test_quota_service.py
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services import quota_service
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
|
||||
|
||||
def _dograh_config(
|
||||
api_key: str = "mps_sk_12345678",
|
||||
*,
|
||||
managed_service_version: int = 1,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
managed_service_version=managed_service_version,
|
||||
llm=SimpleNamespace(provider=ServiceProviders.DOGRAH, api_key=api_key),
|
||||
stt=None,
|
||||
tts=None,
|
||||
embeddings=None,
|
||||
)
|
||||
|
||||
|
||||
def _byok_config():
|
||||
return SimpleNamespace(
|
||||
managed_service_version=2,
|
||||
llm=SimpleNamespace(provider="openai", api_key="sk-openai"),
|
||||
stt=None,
|
||||
tts=None,
|
||||
embeddings=None,
|
||||
)
|
||||
|
||||
|
||||
def _workflow():
|
||||
return SimpleNamespace(
|
||||
id=7,
|
||||
user_id=123,
|
||||
organization_id=42,
|
||||
workflow_configurations={"model_overrides": {}},
|
||||
)
|
||||
|
||||
|
||||
def _workflow_owner():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
provider_id="provider-123",
|
||||
)
|
||||
|
||||
|
||||
def _actor():
|
||||
return SimpleNamespace(
|
||||
id=456,
|
||||
provider_id="actor-456",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _patch_workflow_context(monkeypatch, *, workflow=None, owner=None):
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_workflow_by_id",
|
||||
AsyncMock(return_value=workflow or _workflow()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_user_by_id",
|
||||
AsyncMock(return_value=owner or _workflow_owner()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_uses_workflow_org_for_hosted_v2(
|
||||
monkeypatch,
|
||||
):
|
||||
get_config = AsyncMock(return_value=_dograh_config())
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
}
|
||||
)
|
||||
check_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
|
||||
|
||||
assert result.has_quota is True
|
||||
get_config.assert_awaited_once_with(
|
||||
user_id=123,
|
||||
organization_id=42,
|
||||
workflow_configurations={"model_overrides": {}},
|
||||
)
|
||||
authorize.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
workflow_run_id=None,
|
||||
service_key=None,
|
||||
require_correlation_id=False,
|
||||
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by="provider-123",
|
||||
metadata={"dograh_user_id": "123", "workflow_id": 7},
|
||||
)
|
||||
check_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_v2_insufficient_credits_prompts_billing(
|
||||
monkeypatch,
|
||||
):
|
||||
get_config = AsyncMock(return_value=_byok_config())
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": False,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "0.0000",
|
||||
"error": "insufficient_credits",
|
||||
}
|
||||
)
|
||||
check_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
|
||||
|
||||
assert result.has_quota is False
|
||||
assert result.error_code == "insufficient_credits"
|
||||
assert "/billing" in result.error_message
|
||||
assert "founders@dograh.com" not in result.error_message
|
||||
authorize.assert_awaited_once()
|
||||
check_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_v1_uses_legacy_key_usage(
|
||||
monkeypatch,
|
||||
):
|
||||
api_key = "mps_sk_12345678"
|
||||
get_config = AsyncMock(return_value=_dograh_config(api_key))
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": True,
|
||||
"billing_mode": "v1",
|
||||
"remaining_credits": "0.0000",
|
||||
}
|
||||
)
|
||||
check_usage = AsyncMock(
|
||||
return_value={"total_credits_used": 500.0, "remaining_credits": 0.0}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
|
||||
|
||||
assert result.has_quota is False
|
||||
assert result.error_code == "quota_exceeded"
|
||||
assert "founders@dograh.com" in result.error_message
|
||||
assert "/billing" not in result.error_message
|
||||
authorize.assert_awaited_once()
|
||||
check_usage.assert_awaited_once_with(
|
||||
api_key,
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
api_key = "mps_sk_12345678"
|
||||
workflow_run = SimpleNamespace(initial_context={"existing": "value"})
|
||||
get_config = AsyncMock(
|
||||
return_value=_dograh_config(api_key, managed_service_version=2)
|
||||
)
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
"correlation_id": "mps-corr-123",
|
||||
}
|
||||
)
|
||||
update_workflow_run = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
AsyncMock(return_value=workflow_run),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"update_workflow_run",
|
||||
update_workflow_run,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(
|
||||
workflow_id=7,
|
||||
workflow_run_id=88,
|
||||
)
|
||||
|
||||
assert result.has_quota is True
|
||||
authorize.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
workflow_run_id=88,
|
||||
service_key=api_key,
|
||||
require_correlation_id=True,
|
||||
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by="provider-123",
|
||||
metadata={"dograh_user_id": "123", "workflow_id": 7},
|
||||
)
|
||||
update_workflow_run.assert_awaited_once_with(
|
||||
88,
|
||||
initial_context={
|
||||
"existing": "value",
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY: "mps-corr-123",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org(
|
||||
monkeypatch,
|
||||
):
|
||||
api_key = "mps_sk_12345678"
|
||||
workflow_run = SimpleNamespace(initial_context={})
|
||||
get_config = AsyncMock(
|
||||
return_value=_dograh_config(api_key, managed_service_version=2)
|
||||
)
|
||||
hosted_authorize = AsyncMock()
|
||||
check_usage = AsyncMock(
|
||||
return_value={"total_credits_used": 1.0, "remaining_credits": 499.0}
|
||||
)
|
||||
create_correlation = AsyncMock(return_value={"correlation_id": "oss-corr-123"})
|
||||
update_workflow_run = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "oss")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
AsyncMock(return_value=workflow_run),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"update_workflow_run",
|
||||
update_workflow_run,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
hosted_authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"create_correlation_id",
|
||||
create_correlation,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(
|
||||
workflow_id=7,
|
||||
workflow_run_id=88,
|
||||
)
|
||||
|
||||
assert result.has_quota is True
|
||||
hosted_authorize.assert_not_awaited()
|
||||
check_usage.assert_awaited_once_with(
|
||||
api_key,
|
||||
organization_id=None,
|
||||
created_by="provider-123",
|
||||
)
|
||||
create_correlation.assert_awaited_once_with(
|
||||
service_key=api_key,
|
||||
workflow_run_id=88,
|
||||
)
|
||||
update_workflow_run.assert_awaited_once_with(
|
||||
88,
|
||||
initial_context={MPS_CORRELATION_ID_CONTEXT_KEY: "oss-corr-123"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_rejects_actor_from_another_org(monkeypatch):
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(
|
||||
workflow_id=7,
|
||||
actor_user=SimpleNamespace(selected_organization_id=999),
|
||||
)
|
||||
|
||||
assert result.has_quota is False
|
||||
assert result.error_code == "workflow_not_found"
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -54,7 +54,7 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -88,7 +88,11 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
quota_mock.assert_awaited_once_with(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=501,
|
||||
actor_user=ANY,
|
||||
)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_call = mock_db.create_workflow_run.await_args
|
||||
|
|
@ -119,7 +123,7 @@ def test_initiate_call_uses_organization_preference_phone_number():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -173,7 +177,7 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
|
|||
|
|
@ -1105,7 +1105,7 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
|
|||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.routes.workflow_text_chat.check_dograh_quota",
|
||||
"api.routes.workflow_text_chat.authorize_workflow_run_start",
|
||||
new=AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
has_quota=False,
|
||||
|
|
@ -1120,11 +1120,16 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
|
|||
|
||||
assert create_response.status_code == 402
|
||||
assert create_response.json()["detail"] == "Quota exceeded"
|
||||
_, total_count = await db_session.get_workflow_runs_by_workflow_id(
|
||||
runs, total_count = await db_session.get_workflow_runs_by_workflow_id(
|
||||
workflow.id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
assert total_count == 0
|
||||
assert total_count == 1
|
||||
text_session = await db_session.get_workflow_run_text_session(
|
||||
runs[0].id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
assert text_session is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1168,7 +1173,7 @@ async def test_text_chat_append_rejects_quota_without_mutating_session(
|
|||
async with test_client_factory(user) as client:
|
||||
with (
|
||||
patch(
|
||||
"api.routes.workflow_text_chat.check_dograh_quota",
|
||||
"api.routes.workflow_text_chat.authorize_workflow_run_start",
|
||||
new=AsyncMock(
|
||||
side_effect=[
|
||||
SimpleNamespace(has_quota=True, error_message=""),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue