dograh/api/tests/test_workflow_run_billing.py
2026-06-12 13:26:33 +05:30

212 lines
6.4 KiB
Python

from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
is_completed=True,
initial_context={"mps_correlation_id": "mps-corr-123"},
usage_info={"call_duration_seconds": 87},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id="mps-corr-123",
duration_seconds=None,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "mps_correlation",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id=None,
duration_seconds=87.0,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "dograh_usage_info",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v1"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_awaited_once_with(organization_id=42)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
workflow_run.usage_info = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_not_awaited()
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
workflow_run = _make_workflow_run()
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
workflow_run = _make_workflow_run()
workflow_run.is_completed = False
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
workflow_run = _make_workflow_run()
get_run = AsyncMock(return_value=workflow_run)
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.db_client,
"get_workflow_run_by_id",
get_run,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_completed_workflow_run_platform_usage(workflow_run.id)
get_run.assert_awaited_once_with(workflow_run.id)
report_usage.assert_awaited_once()