diff --git a/api/routes/workflow.py b/api/routes/workflow.py index 15541674..c106a2ee 100644 --- a/api/routes/workflow.py +++ b/api/routes/workflow.py @@ -15,7 +15,7 @@ from api.db import db_client from api.db.agent_trigger_client import TriggerPathConflictError from api.db.models import UserModel from api.db.workflow_template_client import WorkflowTemplateClient -from api.enums import CallType, PostHogEvent, StorageBackend +from api.enums import CallType, PostHogEvent, StorageBackend, WorkflowStatus from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2 from api.schemas.workflow import WorkflowRunResponseSchema from api.sdk_expose import sdk_expose @@ -578,6 +578,31 @@ async def get_workflow_count( ) +def _validate_status_filter(status: Optional[str]) -> List[str]: + """Parse and validate a workflow ``status`` query filter. + + Accepts a single value or a comma-separated list. Returns the list of + validated status values (empty when no filter was supplied). Any value + outside the ``workflow_status`` enum raises 422 so the request fails as a + clean client error instead of a 500 from the Postgres enum cast. + """ + if status is None or status == "": + return [] + allowed = {s.value for s in WorkflowStatus} + requested = [s.strip() for s in status.split(",")] + invalid = sorted({s for s in requested if s not in allowed}) + if invalid: + invalid_display = ["" if s == "" else s for s in invalid] + raise HTTPException( + status_code=422, + detail=( + f"Invalid workflow status filter: {invalid_display}. " + f"Allowed values: {sorted(allowed)}." + ), + ) + return requested + + @router.get( "/fetch", **sdk_expose( @@ -597,21 +622,22 @@ async def get_workflows( Returns a lightweight response with only essential fields for listing. Use GET /workflow/fetch/{workflow_id} to get full workflow details. """ - # Handle comma-separated status values - if status and "," in status: - # Split comma-separated values and fetch workflows for each status - status_list = [s.strip() for s in status.split(",")] + statuses = _validate_status_filter(status) + if statuses: + # Fetch workflows for each requested status and combine the results. all_workflows = [] - for status_value in status_list: - workflows = await db_client.get_all_workflows_for_listing( - organization_id=user.selected_organization_id, status=status_value + for status_value in statuses: + all_workflows.extend( + await db_client.get_all_workflows_for_listing( + organization_id=user.selected_organization_id, + status=status_value, + ) ) - all_workflows.extend(workflows) workflows = all_workflows else: - # Single status or no status filter + # No status filter workflows = await db_client.get_all_workflows_for_listing( - organization_id=user.selected_organization_id, status=status + organization_id=user.selected_organization_id, status=None ) # Get run counts for all workflows in a single query @@ -820,10 +846,20 @@ async def get_workflows_summary( ), ) -> List[WorkflowSummaryResponse]: """Get minimal workflow information (id and name only) for all workflows""" - workflows = await db_client.get_all_workflows( - organization_id=user.selected_organization_id, - status=status, - ) + statuses = _validate_status_filter(status) + if statuses: + workflows = [] + for status_value in statuses: + workflows.extend( + await db_client.get_all_workflows( + organization_id=user.selected_organization_id, + status=status_value, + ) + ) + else: + workflows = await db_client.get_all_workflows( + organization_id=user.selected_organization_id, status=None + ) return [ WorkflowSummaryResponse(id=workflow.id, name=workflow.name) for workflow in workflows diff --git a/api/services/mps_service_key_client.py b/api/services/mps_service_key_client.py index 5f90380f..87b95fde 100644 --- a/api/services/mps_service_key_client.py +++ b/api/services/mps_service_key_client.py @@ -601,16 +601,15 @@ class MPSServiceKeyClient: if response.status_code == 200: return response.json() - should_retry = ( - response.status_code == 409 - and "usage_not_ready" in response.text - and attempt < max_attempts + usage_not_ready = ( + response.status_code == 409 and "usage_not_ready" in response.text ) - if should_retry: + if usage_not_ready and attempt < max_attempts: await asyncio.sleep(attempt) continue - logger.error( + log = logger.warning if usage_not_ready else logger.error + log( "Failed to report platform usage: " f"{response.status_code} - {response.text}" ) diff --git a/api/services/workflow_run_billing.py b/api/services/workflow_run_billing.py index ab8a3121..2c61dc8b 100644 --- a/api/services/workflow_run_billing.py +++ b/api/services/workflow_run_billing.py @@ -39,6 +39,13 @@ async def _organization_uses_mps_billing_v2(organization_id: int) -> bool: return bool(account and account.get("billing_mode") == "v2") +def _is_usage_not_ready_error(exc: Exception) -> bool: + response = getattr(exc, "response", None) + if getattr(response, "status_code", None) != 409: + return False + return "usage_not_ready" in (getattr(response, "text", "") or "") + + async def report_workflow_run_platform_usage(workflow_run) -> None: """Report hosted platform usage for a completed workflow run to MPS.""" if DEPLOYMENT_MODE == "oss": @@ -91,11 +98,21 @@ async def report_workflow_run_platform_usage(workflow_run) -> None: result, ) except Exception as e: - logger.error( - "Failed to report platform usage for workflow run {}: {}", - workflow_run.id, - e, - ) + if _is_usage_not_ready_error(e): + # A run can start and receive an MPS correlation id, then fail or end + # before billable STT usage is recorded. MPS returns usage_not_ready + # for that no-platform-fee path, so keep it out of error alerts. + logger.warning( + "Failed to report platform usage for workflow run {}: {}", + workflow_run.id, + e, + ) + else: + logger.error( + "Failed to report platform usage for workflow run {}: {}", + workflow_run.id, + e, + ) async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None: diff --git a/api/tests/test_workflow_list_route.py b/api/tests/test_workflow_list_route.py index 0f1864b4..dcc2ddd5 100644 --- a/api/tests/test_workflow_list_route.py +++ b/api/tests/test_workflow_list_route.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone from types import SimpleNamespace from unittest.mock import AsyncMock, patch +import pytest from fastapi import FastAPI from fastapi.testclient import TestClient @@ -50,3 +51,99 @@ def test_workflow_fetch_list_includes_workflow_uuid(): "workflow_uuid": workflow.workflow_uuid, } ] + + +def test_workflow_fetch_invalid_status_returns_422_without_db_query(): + """A status outside the workflow_status enum (e.g. 'published') must fail + as a clean 422 instead of a 500 from the Postgres enum cast.""" + app = _make_test_app() + client = TestClient(app) + + with patch("api.routes.workflow.db_client") as mock_db: + mock_db.get_all_workflows_for_listing = AsyncMock() + mock_db.get_workflow_run_counts = AsyncMock() + + response = client.get("/workflow/fetch?status=published") + + assert response.status_code == 422 + assert "published" in response.json()["detail"] + # The invalid value must never reach the database layer. + mock_db.get_all_workflows_for_listing.assert_not_called() + + +def test_workflow_fetch_valid_single_status_passes_through(): + app = _make_test_app() + client = TestClient(app) + + with patch("api.routes.workflow.db_client") as mock_db: + mock_db.get_all_workflows_for_listing = AsyncMock(return_value=[]) + mock_db.get_workflow_run_counts = AsyncMock(return_value={}) + + response = client.get("/workflow/fetch?status=active") + + assert response.status_code == 200 + mock_db.get_all_workflows_for_listing.assert_awaited_once_with( + organization_id=11, status="active" + ) + + +def test_workflow_fetch_comma_separated_status_queries_each_value(): + app = _make_test_app() + client = TestClient(app) + + with patch("api.routes.workflow.db_client") as mock_db: + mock_db.get_all_workflows_for_listing = AsyncMock(return_value=[]) + mock_db.get_workflow_run_counts = AsyncMock(return_value={}) + + response = client.get("/workflow/fetch?status=active,archived") + + assert response.status_code == 200 + assert mock_db.get_all_workflows_for_listing.await_count == 2 + statuses = { + call.kwargs["status"] + for call in mock_db.get_all_workflows_for_listing.await_args_list + } + assert statuses == {"active", "archived"} + + +def test_workflow_fetch_mixed_valid_and_invalid_status_returns_422(): + app = _make_test_app() + client = TestClient(app) + + with patch("api.routes.workflow.db_client") as mock_db: + mock_db.get_all_workflows_for_listing = AsyncMock() + mock_db.get_workflow_run_counts = AsyncMock() + + response = client.get("/workflow/fetch?status=active,published") + + assert response.status_code == 422 + mock_db.get_all_workflows_for_listing.assert_not_called() + + +@pytest.mark.parametrize("status", [" ", ",", "active,,archived"]) +def test_workflow_fetch_blank_status_token_returns_422_without_db_query(status: str): + app = _make_test_app() + client = TestClient(app) + + with patch("api.routes.workflow.db_client") as mock_db: + mock_db.get_all_workflows_for_listing = AsyncMock() + mock_db.get_workflow_run_counts = AsyncMock() + + response = client.get("/workflow/fetch", params={"status": status}) + + assert response.status_code == 422 + assert "" in response.json()["detail"] + mock_db.get_all_workflows_for_listing.assert_not_called() + + +def test_workflow_summary_blank_status_token_returns_422_without_db_query(): + app = _make_test_app() + client = TestClient(app) + + with patch("api.routes.workflow.db_client") as mock_db: + mock_db.get_all_workflows = AsyncMock() + + response = client.get("/workflow/summary", params={"status": ","}) + + assert response.status_code == 422 + mock_db.get_all_workflows.assert_not_called() diff --git a/api/tests/test_workflow_run_billing.py b/api/tests/test_workflow_run_billing.py index 2837317f..1dbe1828 100644 --- a/api/tests/test_workflow_run_billing.py +++ b/api/tests/test_workflow_run_billing.py @@ -5,6 +5,7 @@ import pytest from api.services import workflow_run_billing as workflow_run_billing_mod from api.services.workflow_run_billing import ( + _is_usage_not_ready_error, report_completed_workflow_run_platform_usage, report_workflow_run_platform_usage, ) @@ -24,6 +25,16 @@ def _make_workflow_run(): ) +def test_is_usage_not_ready_error_detects_mps_409(): + exc = Exception("Failed to report platform usage") + exc.response = SimpleNamespace( + status_code=409, + text='{"detail":"usage_not_ready"}', + ) + + assert _is_usage_not_ready_error(exc) is True + + @pytest.mark.asyncio async def test_report_workflow_run_platform_usage_reports_hosted_completion( monkeypatch,