dograh/api/tests/test_workflow_run_billing.py
Abhishek d2cda85b78
fix: validate workflow status filter to prevent 500 on invalid enum value (#450)
* Validate workflow status filter to prevent 500 on invalid enum value

The /workflow/fetch and /workflow/summary endpoints accepted a free-form
status query param and passed it straight into a query that casts to the
workflow_status PG enum (active/archived). Any other value — e.g. an
external caller passing 'published' (a workflow_definitions version state,
not a workflow status) — failed deep in Postgres as
InvalidTextRepresentationError, surfacing as an unhandled HTTP 500.

Add _validate_status_filter() to reject values outside WorkflowStatus with
a clean 422 before any DB query, for both the single and comma-separated
paths. Add route tests covering invalid, valid-single, comma-separated, and
mixed valid/invalid cases.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* chore: add tests

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-18 08:39:59 +05:30

223 lines
6.7 KiB
Python

from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
_is_usage_not_ready_error,
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
is_completed=True,
initial_context={"mps_correlation_id": "mps-corr-123"},
usage_info={"call_duration_seconds": 87},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
def test_is_usage_not_ready_error_detects_mps_409():
exc = Exception("Failed to report platform usage")
exc.response = SimpleNamespace(
status_code=409,
text='{"detail":"usage_not_ready"}',
)
assert _is_usage_not_ready_error(exc) is True
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id="mps-corr-123",
duration_seconds=None,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "mps_correlation",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id=None,
duration_seconds=87.0,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "dograh_usage_info",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v1"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_awaited_once_with(organization_id=42)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
workflow_run.usage_info = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_not_awaited()
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
workflow_run = _make_workflow_run()
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
workflow_run = _make_workflow_run()
workflow_run.is_completed = False
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
workflow_run = _make_workflow_run()
get_run = AsyncMock(return_value=workflow_run)
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.db_client,
"get_workflow_run_by_id",
get_run,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_completed_workflow_run_platform_usage(workflow_run.id)
get_run.assert_awaited_once_with(workflow_run.id)
report_usage.assert_awaited_once()