mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
fix: disable duplicate trigger nodes in workflow builder (#402)
* fix: disable duplicate trigger nodes in workflow builder AddNodePanel: disable trigger buttons and show tooltip when a trigger already exists on the canvas, using bySpecName to identify trigger- category specs from the live node list. useWorkflowState: preflight in saveWorkflow rejects saves with multiple trigger nodes via a sonner toast before the network request is made. text_chat_session_service: include the original exception message in TextChatSessionExecutionError so the HTTP 500 detail surfaces the root cause without DB inspection. Closes #378 * style: format test_text_chat_session_service.py with ruff * chore: retrigger CI checks * fix(workflow): enforce node instance constraints --------- Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
7c31dd3eec
commit
7d053320df
27 changed files with 591 additions and 91 deletions
|
|
@ -375,7 +375,7 @@ async def create_campaign(
|
|||
if workflow_def:
|
||||
try:
|
||||
dto = ReactFlowDTO(**workflow_def)
|
||||
graph = WorkflowGraph(dto)
|
||||
graph = WorkflowGraph(dto, skip_instance_constraints_for={"trigger"})
|
||||
required_vars = graph.get_required_template_variables()
|
||||
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -58,7 +58,10 @@ from api.services.workflow.trigger_paths import (
|
|||
trigger_path_to_node_id,
|
||||
validate_trigger_paths,
|
||||
)
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
from api.services.workflow.workflow_graph import (
|
||||
WorkflowGraph,
|
||||
validate_node_instance_constraints,
|
||||
)
|
||||
from api.utils.artifacts import artifact_url
|
||||
from api.utils.recording_artifacts import (
|
||||
get_recording_storage_key,
|
||||
|
|
@ -192,6 +195,27 @@ def _validation_errors_http_exception(
|
|||
)
|
||||
|
||||
|
||||
def _node_instance_validation_errors(
|
||||
workflow_definition: Optional[dict],
|
||||
) -> list[WorkflowError]:
|
||||
"""Validate spec-driven max_instances without requiring a complete draft."""
|
||||
if not workflow_definition:
|
||||
return []
|
||||
nodes = workflow_definition.get("nodes")
|
||||
if not isinstance(nodes, list):
|
||||
return []
|
||||
|
||||
node_types = [
|
||||
node.get("type")
|
||||
for node in nodes
|
||||
if isinstance(node, dict) and isinstance(node.get("type"), str)
|
||||
]
|
||||
return validate_node_instance_constraints(
|
||||
node_types,
|
||||
enforce_min_instances=False,
|
||||
)
|
||||
|
||||
|
||||
class CallDispositionCodes(BaseModel):
|
||||
disposition_codes: list[str] = []
|
||||
|
||||
|
|
@ -384,6 +408,9 @@ async def create_workflow(
|
|||
trigger_path_issues = validate_trigger_paths(workflow_definition)
|
||||
if trigger_path_issues:
|
||||
raise _trigger_path_validation_http_exception(trigger_path_issues)
|
||||
instance_errors = _node_instance_validation_errors(workflow_definition)
|
||||
if instance_errors:
|
||||
raise _validation_errors_http_exception(instance_errors)
|
||||
|
||||
# Validate trigger path uniqueness BEFORE creating the workflow so we
|
||||
# don't leave an orphaned workflow record when the trigger conflicts.
|
||||
|
|
@ -990,6 +1017,9 @@ async def update_workflow(
|
|||
trigger_path_issues = validate_trigger_paths(workflow_definition)
|
||||
if trigger_path_issues:
|
||||
raise _trigger_path_validation_http_exception(trigger_path_issues)
|
||||
instance_errors = _node_instance_validation_errors(workflow_definition)
|
||||
if instance_errors:
|
||||
raise _validation_errors_http_exception(instance_errors, status_code=409)
|
||||
if workflow_definition:
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
|
|
|
|||
|
|
@ -40,10 +40,7 @@ from api.services.workflow.node_specs.model_spec import (
|
|||
)
|
||||
],
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
min_outgoing=0,
|
||||
max_outgoing=0,
|
||||
min_incoming=0, max_incoming=0, min_outgoing=0, max_outgoing=0, max_instances=1
|
||||
),
|
||||
property_order=(
|
||||
"name",
|
||||
|
|
|
|||
|
|
@ -470,7 +470,10 @@ async def _run_pipeline(
|
|||
workflow_run_id, initial_context=merged_call_context_vars
|
||||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(ReactFlowDTO.model_validate(run_workflow_json))
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_workflow_json),
|
||||
skip_instance_constraints_for={"trigger"},
|
||||
)
|
||||
|
||||
# Pre-call fetch: fire early so it runs concurrently with remaining setup
|
||||
pre_call_fetch_task = None
|
||||
|
|
|
|||
|
|
@ -7,15 +7,19 @@ script in `api/services/admin_utils/local_exec.py` is the production
|
|||
consumer.
|
||||
"""
|
||||
|
||||
from collections import Counter
|
||||
|
||||
from api.services.workflow.node_specs import all_specs
|
||||
|
||||
|
||||
def _build_type_rules() -> tuple[set[str], set[str]]:
|
||||
def _build_type_rules() -> tuple[set[str], set[str], dict[str, int], dict[str, int]]:
|
||||
"""From NodeSpec.graph_constraints, derive the set of types that are
|
||||
forbidden as edge sources (max_outgoing == 0) and as targets
|
||||
(max_incoming == 0)."""
|
||||
src_forbidden: set[str] = set()
|
||||
tgt_forbidden: set[str] = set()
|
||||
min_instances: dict[str, int] = {}
|
||||
max_instances: dict[str, int] = {}
|
||||
for spec in all_specs():
|
||||
gc = spec.graph_constraints
|
||||
if gc is None:
|
||||
|
|
@ -24,7 +28,11 @@ def _build_type_rules() -> tuple[set[str], set[str]]:
|
|||
src_forbidden.add(spec.name)
|
||||
if gc.max_incoming == 0:
|
||||
tgt_forbidden.add(spec.name)
|
||||
return src_forbidden, tgt_forbidden
|
||||
if gc.min_instances is not None:
|
||||
min_instances[spec.name] = gc.min_instances
|
||||
if gc.max_instances is not None:
|
||||
max_instances[spec.name] = gc.max_instances
|
||||
return src_forbidden, tgt_forbidden, min_instances, max_instances
|
||||
|
||||
|
||||
def _empty_violation(reason: str) -> dict:
|
||||
|
|
@ -49,7 +57,7 @@ def audit_definition(nodes, edges) -> list[dict]:
|
|||
if not isinstance(nodes, list) or not isinstance(edges, list):
|
||||
return []
|
||||
|
||||
src_forbidden, tgt_forbidden = _build_type_rules()
|
||||
src_forbidden, tgt_forbidden, min_instances, max_instances = _build_type_rules()
|
||||
nodes_by_id: dict = {}
|
||||
for n in nodes:
|
||||
if isinstance(n, dict) and "id" in n:
|
||||
|
|
@ -57,14 +65,25 @@ def audit_definition(nodes, edges) -> list[dict]:
|
|||
|
||||
violations: list[dict] = []
|
||||
|
||||
# Graph-level: WorkflowGraph._assert_start_node requires exactly one
|
||||
# startCall node. The DTO doesn't enforce this, so legacy or
|
||||
# script-edited rows can land in a state that fails at runtime.
|
||||
start_count = sum(1 for t in nodes_by_id.values() if t == "startCall")
|
||||
if start_count == 0:
|
||||
violations.append(_empty_violation("no_start_node"))
|
||||
elif start_count > 1:
|
||||
violations.append(_empty_violation(f"multiple_start_nodes:{start_count}"))
|
||||
node_counts = Counter(t for t in nodes_by_id.values() if isinstance(t, str))
|
||||
for node_type, min_count in min_instances.items():
|
||||
count = node_counts.get(node_type, 0)
|
||||
if count < min_count:
|
||||
reason = (
|
||||
"no_start_node"
|
||||
if node_type == "startCall" and min_count == 1
|
||||
else f"min_instances_{min_count}:{node_type}:{count}"
|
||||
)
|
||||
violations.append(_empty_violation(reason))
|
||||
for node_type, max_count in max_instances.items():
|
||||
count = node_counts.get(node_type, 0)
|
||||
if count > max_count:
|
||||
reason = (
|
||||
f"multiple_start_nodes:{count}"
|
||||
if node_type == "startCall" and max_count == 1
|
||||
else f"max_instances_{max_count}:{node_type}:{count}"
|
||||
)
|
||||
violations.append(_empty_violation(reason))
|
||||
for e in edges:
|
||||
if not isinstance(e, dict):
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -196,7 +196,12 @@ class _ToolDocumentRefsMixin(BaseModel):
|
|||
},
|
||||
)
|
||||
],
|
||||
graph_constraints=GraphConstraints(min_incoming=0, max_incoming=0),
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
min_instances=1,
|
||||
max_instances=1,
|
||||
),
|
||||
property_order=(
|
||||
"name",
|
||||
"greeting_type",
|
||||
|
|
@ -539,6 +544,7 @@ class EndCallNodeData(
|
|||
max_incoming=0,
|
||||
min_outgoing=0,
|
||||
max_outgoing=0,
|
||||
max_instances=1,
|
||||
),
|
||||
property_order=("name", "prompt"),
|
||||
field_overrides={
|
||||
|
|
@ -597,7 +603,11 @@ class GlobalNodeData(BaseNodeData, _PromptedNodeDataMixin):
|
|||
examples=[
|
||||
NodeExample(name="default", data={"name": "Inbound Trigger", "enabled": True})
|
||||
],
|
||||
graph_constraints=GraphConstraints(min_incoming=0, max_incoming=0),
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
max_instances=1,
|
||||
),
|
||||
property_order=("name", "enabled", "trigger_path"),
|
||||
field_overrides={
|
||||
"name": {
|
||||
|
|
|
|||
|
|
@ -243,6 +243,8 @@ class GraphConstraints(BaseModel):
|
|||
max_incoming: Optional[int] = None
|
||||
min_outgoing: Optional[int] = None
|
||||
max_outgoing: Optional[int] = None
|
||||
min_instances: Optional[int] = None
|
||||
max_instances: Optional[int] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
|
|
|||
|
|
@ -458,7 +458,8 @@ async def execute_text_chat_pending_turn(
|
|||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json),
|
||||
skip_instance_constraints_for={"trigger"},
|
||||
)
|
||||
base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)
|
||||
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ async def execute_pending_text_chat_turn(
|
|||
error_message=str(e),
|
||||
)
|
||||
raise TextChatSessionExecutionError(
|
||||
"Failed to execute text chat assistant turn"
|
||||
f"Failed to execute text chat assistant turn: {e}"
|
||||
) from e
|
||||
|
||||
completed_session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Set
|
|||
from api.services.workflow.dto import EdgeDataDTO, NodeType, ReactFlowDTO
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
from api.services.workflow.node_specs import get_spec
|
||||
from api.services.workflow.node_specs import all_specs, get_spec
|
||||
|
||||
# Regex for matching {{ variable }} template placeholders.
|
||||
# Captures: group(1) = variable path, group(2) = filter name, group(3) = filter value.
|
||||
|
|
@ -68,10 +68,11 @@ class Node:
|
|||
self.out: Dict[str, "Node"] = {} # forward nodes
|
||||
self.out_edges: List[Edge] = [] # forward edges with properties
|
||||
|
||||
# name/is_start/is_end live on every per-type data class (base).
|
||||
# Start/end semantics are defined by node type. The persisted
|
||||
# data flags are legacy UI/runtime state and may be stale.
|
||||
self.name = data.name
|
||||
self.is_start = data.is_start
|
||||
self.is_end = data.is_end
|
||||
self.is_start = node_type == NodeType.startNode.value
|
||||
self.is_end = node_type == NodeType.endNode.value
|
||||
|
||||
# Type-specific fields — read with getattr so this works for every
|
||||
# node variant in the discriminated union.
|
||||
|
|
@ -98,13 +99,89 @@ class Node:
|
|||
self.data = data
|
||||
|
||||
|
||||
def _instance_constraint_message(
|
||||
label: str,
|
||||
count: int,
|
||||
*,
|
||||
min_count: int | None = None,
|
||||
max_count: int | None = None,
|
||||
) -> str:
|
||||
if max_count is not None and count > max_count:
|
||||
if max_count == 1:
|
||||
return f"Workflow can have at most one {label}"
|
||||
return f"Workflow can have at most {max_count} {label} nodes"
|
||||
if min_count is not None and count < min_count:
|
||||
if min_count == 1:
|
||||
return f"Workflow must have at least one {label}"
|
||||
return f"Workflow must have at least {min_count} {label} nodes"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_node_instance_constraints(
|
||||
node_types: list[str],
|
||||
*,
|
||||
enforce_min_instances: bool = True,
|
||||
skip_types: Set[str] | None = None,
|
||||
) -> list[WorkflowError]:
|
||||
"""Validate workflow-level node type counts from NodeSpec.graph_constraints."""
|
||||
errors: list[WorkflowError] = []
|
||||
skip_types = skip_types or set()
|
||||
counts = Counter(node_types)
|
||||
|
||||
for spec in all_specs():
|
||||
if spec.name in skip_types:
|
||||
continue
|
||||
gc = spec.graph_constraints
|
||||
if gc is None:
|
||||
continue
|
||||
|
||||
count = counts.get(spec.name, 0)
|
||||
if gc.max_instances is not None and count > gc.max_instances:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message=_instance_constraint_message(
|
||||
spec.display_name,
|
||||
count,
|
||||
max_count=gc.max_instances,
|
||||
),
|
||||
)
|
||||
)
|
||||
if (
|
||||
enforce_min_instances
|
||||
and gc.min_instances is not None
|
||||
and count < gc.min_instances
|
||||
):
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message=_instance_constraint_message(
|
||||
spec.display_name,
|
||||
count,
|
||||
min_count=gc.min_instances,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class WorkflowGraph:
|
||||
"""
|
||||
*All* business invariants (acyclic, cardinality, etc.) are verified here.
|
||||
The constructor accepts a validated ReactFlowDTO.
|
||||
"""
|
||||
|
||||
def __init__(self, dto: ReactFlowDTO):
|
||||
def __init__(
|
||||
self,
|
||||
dto: ReactFlowDTO,
|
||||
*,
|
||||
skip_instance_constraints_for: Set[str] | None = None,
|
||||
):
|
||||
# Build adjacency list from validated DTO nodes. Core node comparisons
|
||||
# still use NodeType string enums; integration nodes remain plain
|
||||
# strings and resolve constraints through node specs.
|
||||
|
|
@ -131,10 +208,12 @@ class WorkflowGraph:
|
|||
# Set up the node references for backward compatibility
|
||||
source_node.out[target_node.id] = target_node
|
||||
|
||||
self._validate_graph()
|
||||
self._validate_graph(skip_instance_constraints_for or set())
|
||||
|
||||
# Get a reference to the start node
|
||||
self.start_node_id = [n.id for n in dto.nodes if n.data.is_start][0]
|
||||
self.start_node_id = [
|
||||
n.id for n in dto.nodes if n.type == NodeType.startNode.value
|
||||
][0]
|
||||
|
||||
# Get a reference to the global node
|
||||
try:
|
||||
|
|
@ -185,7 +264,7 @@ class WorkflowGraph:
|
|||
# -----------------------------------------------------------
|
||||
# validators
|
||||
# -----------------------------------------------------------
|
||||
def _validate_graph(self) -> None:
|
||||
def _validate_graph(self, skip_instance_constraints_for: Set[str]) -> None:
|
||||
errors: list[WorkflowError] = []
|
||||
|
||||
# TODO: Figure out what kind of cyclic contraints can be applied, since there can be a cycle in the graph
|
||||
|
|
@ -198,9 +277,13 @@ class WorkflowGraph:
|
|||
# )
|
||||
# )
|
||||
|
||||
errors.extend(self._assert_start_node())
|
||||
errors.extend(
|
||||
validate_node_instance_constraints(
|
||||
[n.node_type for n in self.nodes.values()],
|
||||
skip_types=skip_instance_constraints_for,
|
||||
)
|
||||
)
|
||||
errors.extend(self._assert_connection_counts())
|
||||
errors.extend(self._assert_global_node())
|
||||
errors.extend(self._assert_node_configs())
|
||||
if errors:
|
||||
raise ValueError(errors)
|
||||
|
|
@ -220,48 +303,6 @@ class WorkflowGraph:
|
|||
for n in self.nodes.values():
|
||||
dfs(n)
|
||||
|
||||
def _assert_start_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
start_nodes = [n for n in self.nodes.values() if n.data.is_start]
|
||||
if not start_nodes:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow has no start node — exactly one is required",
|
||||
)
|
||||
)
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message=(
|
||||
f"Workflow has {len(start_nodes)} start nodes — "
|
||||
f"exactly one is required"
|
||||
),
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_global_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
global_node = [
|
||||
n for n in self.nodes.values() if n.node_type == NodeType.globalNode.value
|
||||
]
|
||||
if not len(global_node) <= 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have at most one global node",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_connection_counts(self):
|
||||
"""Enforce per-type incoming/outgoing edge constraints.
|
||||
|
||||
|
|
|
|||
23
api/tests/dto_fixtures/multiple_global_nodes.json
Normal file
23
api/tests/dto_fixtures/multiple_global_nodes.json
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "s1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {"name": "Start", "prompt": "Greet.", "is_start": true}
|
||||
},
|
||||
{
|
||||
"id": "g1",
|
||||
"type": "globalNode",
|
||||
"position": {"x": 0, "y": 200},
|
||||
"data": {"name": "Global A", "prompt": "Use a calm tone."}
|
||||
},
|
||||
{
|
||||
"id": "g2",
|
||||
"type": "globalNode",
|
||||
"position": {"x": 0, "y": 400},
|
||||
"data": {"name": "Global B", "prompt": "Keep answers short."}
|
||||
}
|
||||
],
|
||||
"edges": []
|
||||
}
|
||||
23
api/tests/dto_fixtures/multiple_trigger_nodes.json
Normal file
23
api/tests/dto_fixtures/multiple_trigger_nodes.json
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "s1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {"name": "Start", "prompt": "Greet.", "is_start": true}
|
||||
},
|
||||
{
|
||||
"id": "t1",
|
||||
"type": "trigger",
|
||||
"position": {"x": 0, "y": 200},
|
||||
"data": {"name": "Trigger A", "trigger_path": "trigger_a"}
|
||||
},
|
||||
{
|
||||
"id": "t2",
|
||||
"type": "trigger",
|
||||
"position": {"x": 0, "y": 400},
|
||||
"data": {"name": "Trigger B", "trigger_path": "trigger_b"}
|
||||
}
|
||||
],
|
||||
"edges": []
|
||||
}
|
||||
66
api/tests/test_mcp_create_workflow.py
Normal file
66
api/tests/test_mcp_create_workflow.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.mcp_server.tools.create_workflow import create_workflow
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workflow_rejects_duplicate_api_triggers():
|
||||
user = MagicMock()
|
||||
user.id = 1
|
||||
user.selected_organization_id = 1
|
||||
payload = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {"name": "Start", "prompt": "Greet."},
|
||||
},
|
||||
{
|
||||
"id": "trigger-1",
|
||||
"type": "trigger",
|
||||
"position": {"x": 0, "y": 200},
|
||||
"data": {"name": "Trigger A", "trigger_path": "support_west"},
|
||||
},
|
||||
{
|
||||
"id": "trigger-2",
|
||||
"type": "trigger",
|
||||
"position": {"x": 0, "y": 400},
|
||||
"data": {"name": "Trigger B", "trigger_path": "support_east"},
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.create_workflow.authenticate_mcp_request",
|
||||
AsyncMock(return_value=user),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools.create_workflow.parse_code",
|
||||
AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"workflowName": "duplicate-trigger-test",
|
||||
"workflow": payload,
|
||||
}
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools.create_workflow.reconcile_positions",
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools.create_workflow.db_client.create_workflow",
|
||||
AsyncMock(),
|
||||
) as create_mock,
|
||||
):
|
||||
result = await create_workflow(code="ignored")
|
||||
|
||||
assert result["created"] is False
|
||||
assert result["error_code"] == "graph_validation"
|
||||
assert "at most one API Trigger" in result["error"]
|
||||
create_mock.assert_not_awaited()
|
||||
|
|
@ -244,6 +244,58 @@ const only = wf.addTyped(endCall({ name: "only", prompt: "bye" }));
|
|||
update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_validation_catches_duplicate_api_triggers(mock_backends):
|
||||
save_mock, update_mock = mock_backends
|
||||
payload = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {"name": "Start", "prompt": "Greet."},
|
||||
},
|
||||
{
|
||||
"id": "trigger-1",
|
||||
"type": "trigger",
|
||||
"position": {"x": 0, "y": 200},
|
||||
"data": {"name": "Trigger A", "trigger_path": "support_west"},
|
||||
},
|
||||
{
|
||||
"id": "trigger-2",
|
||||
"type": "trigger",
|
||||
"position": {"x": 0, "y": 400},
|
||||
"data": {"name": "Trigger B", "trigger_path": "support_east"},
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.save_workflow.parse_code",
|
||||
AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"workflowName": _FakeWorkflowModel.name,
|
||||
"workflow": payload,
|
||||
}
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools.save_workflow.reconcile_positions",
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
result = await save_workflow(workflow_id=1, code="ignored")
|
||||
|
||||
assert result["saved"] is False
|
||||
assert result["error_code"] == "graph_validation"
|
||||
assert "at most one API Trigger" in result["error"]
|
||||
save_mock.assert_not_awaited()
|
||||
update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── Workflow not found / unauthorized ───────────────────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -414,4 +414,9 @@ def test_to_mcp_dict_retains_authoring_signal_startcall():
|
|||
]
|
||||
|
||||
# graph_constraints drops its null sub-fields.
|
||||
assert projected["graph_constraints"] == {"min_incoming": 0, "max_incoming": 0}
|
||||
assert projected["graph_constraints"] == {
|
||||
"min_incoming": 0,
|
||||
"max_incoming": 0,
|
||||
"min_instances": 1,
|
||||
"max_instances": 1,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,15 +42,15 @@ async def test_whole_call_qa_tolerates_array_llm_response():
|
|||
"resolve_llm_config",
|
||||
new=AsyncMock(return_value=("openai", "gpt-4o", "sk-test", {})),
|
||||
),
|
||||
patch.object(qa_analysis, "create_llm_service_from_provider", return_value=object()),
|
||||
patch.object(
|
||||
qa_analysis, "create_llm_service_from_provider", return_value=object()
|
||||
),
|
||||
patch.object(
|
||||
qa_analysis,
|
||||
"_run_llm_inference",
|
||||
new=AsyncMock(return_value='["tag1", "tag2"]'),
|
||||
),
|
||||
patch.object(
|
||||
qa_analysis, "setup_langfuse_parent_context", return_value=None
|
||||
),
|
||||
patch.object(qa_analysis, "setup_langfuse_parent_context", return_value=None),
|
||||
patch.object(qa_analysis, "add_qa_span_to_trace", return_value=None),
|
||||
patch.object(qa_analysis.logger, "warning", warning_mock),
|
||||
):
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.services.workflow.text_chat_session_service import (
|
|||
TextChatTurnNotFoundError,
|
||||
_reload_text_chat_session,
|
||||
build_pending_text_chat_turn,
|
||||
execute_pending_text_chat_turn,
|
||||
truncate_text_chat_future_turns,
|
||||
validate_text_chat_turn_cursor,
|
||||
)
|
||||
|
|
@ -77,6 +78,36 @@ async def test_reload_text_chat_session_uses_run_id_to_resolve_organization(
|
|||
get_text_session.assert_awaited_once_with(123, organization_id=77)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_pending_turn_surfaces_original_exception_message(monkeypatch):
|
||||
session = WorkflowRunTextSessionModel(workflow_run_id=42)
|
||||
session.session_data = {
|
||||
"turns": [{"id": "turn-1", "status": "pending"}],
|
||||
"cursor_turn_id": "turn-1",
|
||||
}
|
||||
session.checkpoint = None
|
||||
|
||||
monkeypatch.setattr(
|
||||
text_chat_session_service,
|
||||
"execute_text_chat_pending_turn",
|
||||
AsyncMock(side_effect=RuntimeError("Workflow has 2 start nodes")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
text_chat_session_service,
|
||||
"_mark_pending_turn_failed",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
TextChatSessionExecutionError, match="Workflow has 2 start nodes"
|
||||
):
|
||||
await execute_pending_text_chat_turn(
|
||||
workflow_id=1,
|
||||
run_id=42,
|
||||
text_session=session,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_text_chat_session_raises_when_run_organization_is_missing(
|
||||
monkeypatch,
|
||||
|
|
|
|||
|
|
@ -47,3 +47,38 @@ def test_create_workflow_rejects_invalid_trigger_path_before_db_write():
|
|||
assert detail["errors"][0]["field"] == "data.trigger_path"
|
||||
assert "single URL path segment" in detail["errors"][0]["message"]
|
||||
assert mock_db.mock_calls == []
|
||||
|
||||
|
||||
def test_create_workflow_rejects_duplicate_api_triggers_before_db_write():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
response = client.post(
|
||||
"/workflow/create/definition",
|
||||
json={
|
||||
"name": "Support Agent",
|
||||
"workflow_definition": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "trigger-1",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": "support_west"},
|
||||
},
|
||||
{
|
||||
"id": "trigger-2",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": "support_east"},
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
detail = response.json()["detail"]
|
||||
assert detail["is_valid"] is False
|
||||
assert detail["errors"][0]["kind"] == "workflow"
|
||||
assert "at most one API Trigger" in detail["errors"][0]["message"]
|
||||
assert mock_db.mock_calls == []
|
||||
|
|
|
|||
|
|
@ -72,14 +72,24 @@ _SCENARIOS = [
|
|||
(
|
||||
"no_start_node",
|
||||
["no_start_node"],
|
||||
["Workflow has no start node"],
|
||||
["Workflow must have at least one Start Call"],
|
||||
),
|
||||
# Two startCall nodes — surfaced separately from no_start_node so
|
||||
# the editor can show a count-specific message.
|
||||
(
|
||||
"multiple_start_nodes",
|
||||
["multiple_start_nodes:2"],
|
||||
["Workflow has 2 start nodes"],
|
||||
["Workflow can have at most one Start Call"],
|
||||
),
|
||||
(
|
||||
"multiple_trigger_nodes",
|
||||
["max_instances_1:trigger:2"],
|
||||
["Workflow can have at most one API Trigger"],
|
||||
),
|
||||
(
|
||||
"multiple_global_nodes",
|
||||
["max_instances_1:globalNode:2"],
|
||||
["Workflow can have at most one Global Node"],
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -122,3 +132,35 @@ def test_workflow_graph_rejects_violations(name, expected_graph_messages):
|
|||
assert any(expected in m for m in actual_messages), (
|
||||
f"Expected substring {expected!r} not found in graph errors: {actual_messages}"
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_graph_can_skip_duplicate_api_trigger_check_for_runtime():
|
||||
raw, _ = _load("multiple_trigger_nodes")
|
||||
dto = ReactFlowDTO.model_validate_json(raw)
|
||||
|
||||
WorkflowGraph(dto, skip_instance_constraints_for={"trigger"})
|
||||
|
||||
|
||||
def test_workflow_graph_start_semantics_come_from_node_type_not_legacy_flag():
|
||||
dto = ReactFlowDTO.model_validate(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"name": "Start",
|
||||
"prompt": "Greet.",
|
||||
"is_start": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
)
|
||||
|
||||
graph = WorkflowGraph(dto)
|
||||
|
||||
assert graph.start_node_id == "start-1"
|
||||
assert graph.nodes["start-1"].is_start is True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue