mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
fix: validate workflow status filter to prevent 500 on invalid enum value (#450)
* Validate workflow status filter to prevent 500 on invalid enum value The /workflow/fetch and /workflow/summary endpoints accepted a free-form status query param and passed it straight into a query that casts to the workflow_status PG enum (active/archived). Any other value — e.g. an external caller passing 'published' (a workflow_definitions version state, not a workflow status) — failed deep in Postgres as InvalidTextRepresentationError, surfacing as an unhandled HTTP 500. Add _validate_status_filter() to reject values outside WorkflowStatus with a clean 422 before any DB query, for both the single and comma-separated paths. Add route tests covering invalid, valid-single, comma-separated, and mixed valid/invalid cases. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * chore: add tests --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
9a1b980f91
commit
d2cda85b78
5 changed files with 186 additions and 26 deletions
|
|
@ -15,7 +15,7 @@ from api.db import db_client
|
|||
from api.db.agent_trigger_client import TriggerPathConflictError
|
||||
from api.db.models import UserModel
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend, WorkflowStatus
|
||||
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.sdk_expose import sdk_expose
|
||||
|
|
@ -578,6 +578,31 @@ async def get_workflow_count(
|
|||
)
|
||||
|
||||
|
||||
def _validate_status_filter(status: Optional[str]) -> List[str]:
|
||||
"""Parse and validate a workflow ``status`` query filter.
|
||||
|
||||
Accepts a single value or a comma-separated list. Returns the list of
|
||||
validated status values (empty when no filter was supplied). Any value
|
||||
outside the ``workflow_status`` enum raises 422 so the request fails as a
|
||||
clean client error instead of a 500 from the Postgres enum cast.
|
||||
"""
|
||||
if status is None or status == "":
|
||||
return []
|
||||
allowed = {s.value for s in WorkflowStatus}
|
||||
requested = [s.strip() for s in status.split(",")]
|
||||
invalid = sorted({s for s in requested if s not in allowed})
|
||||
if invalid:
|
||||
invalid_display = ["<empty>" if s == "" else s for s in invalid]
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=(
|
||||
f"Invalid workflow status filter: {invalid_display}. "
|
||||
f"Allowed values: {sorted(allowed)}."
|
||||
),
|
||||
)
|
||||
return requested
|
||||
|
||||
|
||||
@router.get(
|
||||
"/fetch",
|
||||
**sdk_expose(
|
||||
|
|
@ -597,21 +622,22 @@ async def get_workflows(
|
|||
Returns a lightweight response with only essential fields for listing.
|
||||
Use GET /workflow/fetch/{workflow_id} to get full workflow details.
|
||||
"""
|
||||
# Handle comma-separated status values
|
||||
if status and "," in status:
|
||||
# Split comma-separated values and fetch workflows for each status
|
||||
status_list = [s.strip() for s in status.split(",")]
|
||||
statuses = _validate_status_filter(status)
|
||||
if statuses:
|
||||
# Fetch workflows for each requested status and combine the results.
|
||||
all_workflows = []
|
||||
for status_value in status_list:
|
||||
workflows = await db_client.get_all_workflows_for_listing(
|
||||
organization_id=user.selected_organization_id, status=status_value
|
||||
for status_value in statuses:
|
||||
all_workflows.extend(
|
||||
await db_client.get_all_workflows_for_listing(
|
||||
organization_id=user.selected_organization_id,
|
||||
status=status_value,
|
||||
)
|
||||
)
|
||||
all_workflows.extend(workflows)
|
||||
workflows = all_workflows
|
||||
else:
|
||||
# Single status or no status filter
|
||||
# No status filter
|
||||
workflows = await db_client.get_all_workflows_for_listing(
|
||||
organization_id=user.selected_organization_id, status=status
|
||||
organization_id=user.selected_organization_id, status=None
|
||||
)
|
||||
|
||||
# Get run counts for all workflows in a single query
|
||||
|
|
@ -820,10 +846,20 @@ async def get_workflows_summary(
|
|||
),
|
||||
) -> List[WorkflowSummaryResponse]:
|
||||
"""Get minimal workflow information (id and name only) for all workflows"""
|
||||
workflows = await db_client.get_all_workflows(
|
||||
organization_id=user.selected_organization_id,
|
||||
status=status,
|
||||
)
|
||||
statuses = _validate_status_filter(status)
|
||||
if statuses:
|
||||
workflows = []
|
||||
for status_value in statuses:
|
||||
workflows.extend(
|
||||
await db_client.get_all_workflows(
|
||||
organization_id=user.selected_organization_id,
|
||||
status=status_value,
|
||||
)
|
||||
)
|
||||
else:
|
||||
workflows = await db_client.get_all_workflows(
|
||||
organization_id=user.selected_organization_id, status=None
|
||||
)
|
||||
return [
|
||||
WorkflowSummaryResponse(id=workflow.id, name=workflow.name)
|
||||
for workflow in workflows
|
||||
|
|
|
|||
|
|
@ -601,16 +601,15 @@ class MPSServiceKeyClient:
|
|||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
should_retry = (
|
||||
response.status_code == 409
|
||||
and "usage_not_ready" in response.text
|
||||
and attempt < max_attempts
|
||||
usage_not_ready = (
|
||||
response.status_code == 409 and "usage_not_ready" in response.text
|
||||
)
|
||||
if should_retry:
|
||||
if usage_not_ready and attempt < max_attempts:
|
||||
await asyncio.sleep(attempt)
|
||||
continue
|
||||
|
||||
logger.error(
|
||||
log = logger.warning if usage_not_ready else logger.error
|
||||
log(
|
||||
"Failed to report platform usage: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,13 @@ async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
|
|||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
def _is_usage_not_ready_error(exc: Exception) -> bool:
|
||||
response = getattr(exc, "response", None)
|
||||
if getattr(response, "status_code", None) != 409:
|
||||
return False
|
||||
return "usage_not_ready" in (getattr(response, "text", "") or "")
|
||||
|
||||
|
||||
async def report_workflow_run_platform_usage(workflow_run) -> None:
|
||||
"""Report hosted platform usage for a completed workflow run to MPS."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
|
|
@ -91,11 +98,21 @@ async def report_workflow_run_platform_usage(workflow_run) -> None:
|
|||
result,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to report platform usage for workflow run {}: {}",
|
||||
workflow_run.id,
|
||||
e,
|
||||
)
|
||||
if _is_usage_not_ready_error(e):
|
||||
# A run can start and receive an MPS correlation id, then fail or end
|
||||
# before billable STT usage is recorded. MPS returns usage_not_ready
|
||||
# for that no-platform-fee path, so keep it out of error alerts.
|
||||
logger.warning(
|
||||
"Failed to report platform usage for workflow run {}: {}",
|
||||
workflow_run.id,
|
||||
e,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to report platform usage for workflow run {}: {}",
|
||||
workflow_run.id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from datetime import datetime, timezone
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
|
@ -50,3 +51,99 @@ def test_workflow_fetch_list_includes_workflow_uuid():
|
|||
"workflow_uuid": workflow.workflow_uuid,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_workflow_fetch_invalid_status_returns_422_without_db_query():
|
||||
"""A status outside the workflow_status enum (e.g. 'published') must fail
|
||||
as a clean 422 instead of a 500 from the Postgres enum cast."""
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
mock_db.get_all_workflows_for_listing = AsyncMock()
|
||||
mock_db.get_workflow_run_counts = AsyncMock()
|
||||
|
||||
response = client.get("/workflow/fetch?status=published")
|
||||
|
||||
assert response.status_code == 422
|
||||
assert "published" in response.json()["detail"]
|
||||
# The invalid value must never reach the database layer.
|
||||
mock_db.get_all_workflows_for_listing.assert_not_called()
|
||||
|
||||
|
||||
def test_workflow_fetch_valid_single_status_passes_through():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
mock_db.get_all_workflows_for_listing = AsyncMock(return_value=[])
|
||||
mock_db.get_workflow_run_counts = AsyncMock(return_value={})
|
||||
|
||||
response = client.get("/workflow/fetch?status=active")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_db.get_all_workflows_for_listing.assert_awaited_once_with(
|
||||
organization_id=11, status="active"
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_fetch_comma_separated_status_queries_each_value():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
mock_db.get_all_workflows_for_listing = AsyncMock(return_value=[])
|
||||
mock_db.get_workflow_run_counts = AsyncMock(return_value={})
|
||||
|
||||
response = client.get("/workflow/fetch?status=active,archived")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_db.get_all_workflows_for_listing.await_count == 2
|
||||
statuses = {
|
||||
call.kwargs["status"]
|
||||
for call in mock_db.get_all_workflows_for_listing.await_args_list
|
||||
}
|
||||
assert statuses == {"active", "archived"}
|
||||
|
||||
|
||||
def test_workflow_fetch_mixed_valid_and_invalid_status_returns_422():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
mock_db.get_all_workflows_for_listing = AsyncMock()
|
||||
mock_db.get_workflow_run_counts = AsyncMock()
|
||||
|
||||
response = client.get("/workflow/fetch?status=active,published")
|
||||
|
||||
assert response.status_code == 422
|
||||
mock_db.get_all_workflows_for_listing.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status", [" ", ",", "active,,archived"])
|
||||
def test_workflow_fetch_blank_status_token_returns_422_without_db_query(status: str):
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
mock_db.get_all_workflows_for_listing = AsyncMock()
|
||||
mock_db.get_workflow_run_counts = AsyncMock()
|
||||
|
||||
response = client.get("/workflow/fetch", params={"status": status})
|
||||
|
||||
assert response.status_code == 422
|
||||
assert "<empty>" in response.json()["detail"]
|
||||
mock_db.get_all_workflows_for_listing.assert_not_called()
|
||||
|
||||
|
||||
def test_workflow_summary_blank_status_token_returns_422_without_db_query():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
mock_db.get_all_workflows = AsyncMock()
|
||||
|
||||
response = client.get("/workflow/summary", params={"status": ","})
|
||||
|
||||
assert response.status_code == 422
|
||||
mock_db.get_all_workflows.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
|
||||
from api.services import workflow_run_billing as workflow_run_billing_mod
|
||||
from api.services.workflow_run_billing import (
|
||||
_is_usage_not_ready_error,
|
||||
report_completed_workflow_run_platform_usage,
|
||||
report_workflow_run_platform_usage,
|
||||
)
|
||||
|
|
@ -24,6 +25,16 @@ def _make_workflow_run():
|
|||
)
|
||||
|
||||
|
||||
def test_is_usage_not_ready_error_detects_mps_409():
|
||||
exc = Exception("Failed to report platform usage")
|
||||
exc.response = SimpleNamespace(
|
||||
status_code=409,
|
||||
text='{"detail":"usage_not_ready"}',
|
||||
)
|
||||
|
||||
assert _is_usage_not_ready_error(exc) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
|
||||
monkeypatch,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue