fix: validate workflow status filter to prevent 500 on invalid enum value (#450)

* Validate workflow status filter to prevent 500 on invalid enum value

The /workflow/fetch and /workflow/summary endpoints accepted a free-form
status query param and passed it straight into a query that casts to the
workflow_status PG enum (active/archived). Any other value — e.g. an
external caller passing 'published' (a workflow_definitions version state,
not a workflow status) — failed deep in Postgres as
InvalidTextRepresentationError, surfacing as an unhandled HTTP 500.

Add _validate_status_filter() to reject values outside WorkflowStatus with
a clean 422 before any DB query, for both the single and comma-separated
paths. Add route tests covering invalid, valid-single, comma-separated, and
mixed valid/invalid cases.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* chore: add tests

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Abhishek 2026-06-18 08:39:59 +05:30 committed by GitHub
parent 9a1b980f91
commit d2cda85b78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 186 additions and 26 deletions

View file

@ -15,7 +15,7 @@ from api.db import db_client
from api.db.agent_trigger_client import TriggerPathConflictError
from api.db.models import UserModel
from api.db.workflow_template_client import WorkflowTemplateClient
from api.enums import CallType, PostHogEvent, StorageBackend
from api.enums import CallType, PostHogEvent, StorageBackend, WorkflowStatus
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
from api.schemas.workflow import WorkflowRunResponseSchema
from api.sdk_expose import sdk_expose
@ -578,6 +578,31 @@ async def get_workflow_count(
)
def _validate_status_filter(status: Optional[str]) -> List[str]:
"""Parse and validate a workflow ``status`` query filter.
Accepts a single value or a comma-separated list. Returns the list of
validated status values (empty when no filter was supplied). Any value
outside the ``workflow_status`` enum raises 422 so the request fails as a
clean client error instead of a 500 from the Postgres enum cast.
"""
if status is None or status == "":
return []
allowed = {s.value for s in WorkflowStatus}
requested = [s.strip() for s in status.split(",")]
invalid = sorted({s for s in requested if s not in allowed})
if invalid:
invalid_display = ["<empty>" if s == "" else s for s in invalid]
raise HTTPException(
status_code=422,
detail=(
f"Invalid workflow status filter: {invalid_display}. "
f"Allowed values: {sorted(allowed)}."
),
)
return requested
@router.get(
"/fetch",
**sdk_expose(
@ -597,21 +622,22 @@ async def get_workflows(
Returns a lightweight response with only essential fields for listing.
Use GET /workflow/fetch/{workflow_id} to get full workflow details.
"""
# Handle comma-separated status values
if status and "," in status:
# Split comma-separated values and fetch workflows for each status
status_list = [s.strip() for s in status.split(",")]
statuses = _validate_status_filter(status)
if statuses:
# Fetch workflows for each requested status and combine the results.
all_workflows = []
for status_value in status_list:
workflows = await db_client.get_all_workflows_for_listing(
organization_id=user.selected_organization_id, status=status_value
for status_value in statuses:
all_workflows.extend(
await db_client.get_all_workflows_for_listing(
organization_id=user.selected_organization_id,
status=status_value,
)
)
all_workflows.extend(workflows)
workflows = all_workflows
else:
# Single status or no status filter
# No status filter
workflows = await db_client.get_all_workflows_for_listing(
organization_id=user.selected_organization_id, status=status
organization_id=user.selected_organization_id, status=None
)
# Get run counts for all workflows in a single query
@ -820,10 +846,20 @@ async def get_workflows_summary(
),
) -> List[WorkflowSummaryResponse]:
"""Get minimal workflow information (id and name only) for all workflows"""
workflows = await db_client.get_all_workflows(
organization_id=user.selected_organization_id,
status=status,
)
statuses = _validate_status_filter(status)
if statuses:
workflows = []
for status_value in statuses:
workflows.extend(
await db_client.get_all_workflows(
organization_id=user.selected_organization_id,
status=status_value,
)
)
else:
workflows = await db_client.get_all_workflows(
organization_id=user.selected_organization_id, status=None
)
return [
WorkflowSummaryResponse(id=workflow.id, name=workflow.name)
for workflow in workflows

View file

@ -601,16 +601,15 @@ class MPSServiceKeyClient:
if response.status_code == 200:
return response.json()
should_retry = (
response.status_code == 409
and "usage_not_ready" in response.text
and attempt < max_attempts
usage_not_ready = (
response.status_code == 409 and "usage_not_ready" in response.text
)
if should_retry:
if usage_not_ready and attempt < max_attempts:
await asyncio.sleep(attempt)
continue
logger.error(
log = logger.warning if usage_not_ready else logger.error
log(
"Failed to report platform usage: "
f"{response.status_code} - {response.text}"
)

View file

@ -39,6 +39,13 @@ async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
return bool(account and account.get("billing_mode") == "v2")
def _is_usage_not_ready_error(exc: Exception) -> bool:
response = getattr(exc, "response", None)
if getattr(response, "status_code", None) != 409:
return False
return "usage_not_ready" in (getattr(response, "text", "") or "")
async def report_workflow_run_platform_usage(workflow_run) -> None:
"""Report hosted platform usage for a completed workflow run to MPS."""
if DEPLOYMENT_MODE == "oss":
@ -91,11 +98,21 @@ async def report_workflow_run_platform_usage(workflow_run) -> None:
result,
)
except Exception as e:
logger.error(
"Failed to report platform usage for workflow run {}: {}",
workflow_run.id,
e,
)
if _is_usage_not_ready_error(e):
# A run can start and receive an MPS correlation id, then fail or end
# before billable STT usage is recorded. MPS returns usage_not_ready
# for that no-platform-fee path, so keep it out of error alerts.
logger.warning(
"Failed to report platform usage for workflow run {}: {}",
workflow_run.id,
e,
)
else:
logger.error(
"Failed to report platform usage for workflow run {}: {}",
workflow_run.id,
e,
)
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:

View file

@ -2,6 +2,7 @@ from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
@ -50,3 +51,99 @@ def test_workflow_fetch_list_includes_workflow_uuid():
"workflow_uuid": workflow.workflow_uuid,
}
]
def test_workflow_fetch_invalid_status_returns_422_without_db_query():
"""A status outside the workflow_status enum (e.g. 'published') must fail
as a clean 422 instead of a 500 from the Postgres enum cast."""
app = _make_test_app()
client = TestClient(app)
with patch("api.routes.workflow.db_client") as mock_db:
mock_db.get_all_workflows_for_listing = AsyncMock()
mock_db.get_workflow_run_counts = AsyncMock()
response = client.get("/workflow/fetch?status=published")
assert response.status_code == 422
assert "published" in response.json()["detail"]
# The invalid value must never reach the database layer.
mock_db.get_all_workflows_for_listing.assert_not_called()
def test_workflow_fetch_valid_single_status_passes_through():
app = _make_test_app()
client = TestClient(app)
with patch("api.routes.workflow.db_client") as mock_db:
mock_db.get_all_workflows_for_listing = AsyncMock(return_value=[])
mock_db.get_workflow_run_counts = AsyncMock(return_value={})
response = client.get("/workflow/fetch?status=active")
assert response.status_code == 200
mock_db.get_all_workflows_for_listing.assert_awaited_once_with(
organization_id=11, status="active"
)
def test_workflow_fetch_comma_separated_status_queries_each_value():
app = _make_test_app()
client = TestClient(app)
with patch("api.routes.workflow.db_client") as mock_db:
mock_db.get_all_workflows_for_listing = AsyncMock(return_value=[])
mock_db.get_workflow_run_counts = AsyncMock(return_value={})
response = client.get("/workflow/fetch?status=active,archived")
assert response.status_code == 200
assert mock_db.get_all_workflows_for_listing.await_count == 2
statuses = {
call.kwargs["status"]
for call in mock_db.get_all_workflows_for_listing.await_args_list
}
assert statuses == {"active", "archived"}
def test_workflow_fetch_mixed_valid_and_invalid_status_returns_422():
app = _make_test_app()
client = TestClient(app)
with patch("api.routes.workflow.db_client") as mock_db:
mock_db.get_all_workflows_for_listing = AsyncMock()
mock_db.get_workflow_run_counts = AsyncMock()
response = client.get("/workflow/fetch?status=active,published")
assert response.status_code == 422
mock_db.get_all_workflows_for_listing.assert_not_called()
@pytest.mark.parametrize("status", [" ", ",", "active,,archived"])
def test_workflow_fetch_blank_status_token_returns_422_without_db_query(status: str):
app = _make_test_app()
client = TestClient(app)
with patch("api.routes.workflow.db_client") as mock_db:
mock_db.get_all_workflows_for_listing = AsyncMock()
mock_db.get_workflow_run_counts = AsyncMock()
response = client.get("/workflow/fetch", params={"status": status})
assert response.status_code == 422
assert "<empty>" in response.json()["detail"]
mock_db.get_all_workflows_for_listing.assert_not_called()
def test_workflow_summary_blank_status_token_returns_422_without_db_query():
app = _make_test_app()
client = TestClient(app)
with patch("api.routes.workflow.db_client") as mock_db:
mock_db.get_all_workflows = AsyncMock()
response = client.get("/workflow/summary", params={"status": ","})
assert response.status_code == 422
mock_db.get_all_workflows.assert_not_called()

View file

@ -5,6 +5,7 @@ import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
_is_usage_not_ready_error,
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
@ -24,6 +25,16 @@ def _make_workflow_run():
)
def test_is_usage_not_ready_error_detects_mps_409():
exc = Exception("Failed to report platform usage")
exc.response = SimpleNamespace(
status_code=409,
text='{"detail":"usage_not_ready"}',
)
assert _is_usage_not_ready_error(exc) is True
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
monkeypatch,