from types import SimpleNamespace from unittest.mock import AsyncMock import pytest from api.services import workflow_run_billing as workflow_run_billing_mod from api.services.workflow_run_billing import ( report_completed_workflow_run_platform_usage, report_workflow_run_platform_usage, ) def _make_workflow_run(): return SimpleNamespace( id=123, workflow_id=456, is_completed=True, initial_context={"mps_correlation_id": "mps-corr-123"}, usage_info={"call_duration_seconds": 87}, workflow=SimpleNamespace( organization_id=42, user=SimpleNamespace(selected_organization_id=42), ), ) @pytest.mark.asyncio async def test_report_workflow_run_platform_usage_reports_hosted_completion( monkeypatch, ): workflow_run = _make_workflow_run() get_status = AsyncMock(return_value={"billing_mode": "v2"}) report_usage = AsyncMock(return_value={"metered": True}) monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "get_billing_account_status", get_status, ) monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "report_platform_usage", report_usage, ) await report_workflow_run_platform_usage(workflow_run) report_usage.assert_awaited_once_with( organization_id=42, correlation_id="mps-corr-123", duration_seconds=None, workflow_run_id=workflow_run.id, metadata={ "source": "workflow_run_completion", "workflow_id": workflow_run.workflow_id, "duration_source": "mps_correlation", }, ) @pytest.mark.asyncio async def test_report_workflow_run_platform_usage_reports_duration_without_correlation( monkeypatch, ): workflow_run = _make_workflow_run() workflow_run.initial_context = {} get_status = AsyncMock(return_value={"billing_mode": "v2"}) report_usage = AsyncMock(return_value={"metered": True}) monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "get_billing_account_status", get_status, ) monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "report_platform_usage", report_usage, ) await report_workflow_run_platform_usage(workflow_run) report_usage.assert_awaited_once_with( organization_id=42, correlation_id=None, duration_seconds=87.0, workflow_run_id=workflow_run.id, metadata={ "source": "workflow_run_completion", "workflow_id": workflow_run.workflow_id, "duration_source": "dograh_usage_info", }, ) @pytest.mark.asyncio async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch): workflow_run = _make_workflow_run() get_status = AsyncMock(return_value={"billing_mode": "v1"}) report_usage = AsyncMock() monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "get_billing_account_status", get_status, ) monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "report_platform_usage", report_usage, ) await report_workflow_run_platform_usage(workflow_run) get_status.assert_awaited_once_with(organization_id=42) report_usage.assert_not_awaited() @pytest.mark.asyncio async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation( monkeypatch, ): workflow_run = _make_workflow_run() workflow_run.initial_context = {} workflow_run.usage_info = {} get_status = AsyncMock(return_value={"billing_mode": "v2"}) report_usage = AsyncMock() monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "get_billing_account_status", get_status, ) monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "report_platform_usage", report_usage, ) await report_workflow_run_platform_usage(workflow_run) get_status.assert_not_awaited() report_usage.assert_not_awaited() @pytest.mark.asyncio async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch): workflow_run = _make_workflow_run() report_usage = AsyncMock() monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss") monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "report_platform_usage", report_usage, ) await report_workflow_run_platform_usage(workflow_run) report_usage.assert_not_awaited() @pytest.mark.asyncio async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch): workflow_run = _make_workflow_run() workflow_run.is_completed = False report_usage = AsyncMock() monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "report_platform_usage", report_usage, ) await report_workflow_run_platform_usage(workflow_run) report_usage.assert_not_awaited() @pytest.mark.asyncio async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch): workflow_run = _make_workflow_run() get_run = AsyncMock(return_value=workflow_run) get_status = AsyncMock(return_value={"billing_mode": "v2"}) report_usage = AsyncMock(return_value={"metered": True}) monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") monkeypatch.setattr( workflow_run_billing_mod.db_client, "get_workflow_run_by_id", get_run, ) monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "get_billing_account_status", get_status, ) monkeypatch.setattr( workflow_run_billing_mod.mps_service_key_client, "report_platform_usage", report_usage, ) await report_completed_workflow_run_platform_usage(workflow_run.id) get_run.assert_awaited_once_with(workflow_run.id) report_usage.assert_awaited_once()