fix: disable duplicate trigger nodes in workflow builder (#402)

* fix: disable duplicate trigger nodes in workflow builder

AddNodePanel: disable trigger buttons and show tooltip when a trigger
already exists on the canvas, using bySpecName to identify trigger-
category specs from the live node list.
useWorkflowState: preflight in saveWorkflow rejects saves with multiple
trigger nodes via a sonner toast before the network request is made.
text_chat_session_service: include the original exception message in
TextChatSessionExecutionError so the HTTP 500 detail surfaces the root
cause without DB inspection.

Closes #378

* style: format test_text_chat_session_service.py with ruff

* chore: retrigger CI checks

* fix(workflow): enforce node instance constraints

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
nuthalapativarun 2026-06-19 03:29:30 -07:00 committed by GitHub
parent 7c31dd3eec
commit 7d053320df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 591 additions and 91 deletions

View file

@ -40,10 +40,7 @@ from api.services.workflow.node_specs.model_spec import (
)
],
graph_constraints=GraphConstraints(
min_incoming=0,
max_incoming=0,
min_outgoing=0,
max_outgoing=0,
min_incoming=0, max_incoming=0, min_outgoing=0, max_outgoing=0, max_instances=1
),
property_order=(
"name",

View file

@ -470,7 +470,10 @@ async def _run_pipeline(
workflow_run_id, initial_context=merged_call_context_vars
)
workflow_graph = WorkflowGraph(ReactFlowDTO.model_validate(run_workflow_json))
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_workflow_json),
skip_instance_constraints_for={"trigger"},
)
# Pre-call fetch: fire early so it runs concurrently with remaining setup
pre_call_fetch_task = None

View file

@ -7,15 +7,19 @@ script in `api/services/admin_utils/local_exec.py` is the production
consumer.
"""
from collections import Counter
from api.services.workflow.node_specs import all_specs
def _build_type_rules() -> tuple[set[str], set[str]]:
def _build_type_rules() -> tuple[set[str], set[str], dict[str, int], dict[str, int]]:
"""From NodeSpec.graph_constraints, derive the set of types that are
forbidden as edge sources (max_outgoing == 0) and as targets
(max_incoming == 0)."""
src_forbidden: set[str] = set()
tgt_forbidden: set[str] = set()
min_instances: dict[str, int] = {}
max_instances: dict[str, int] = {}
for spec in all_specs():
gc = spec.graph_constraints
if gc is None:
@ -24,7 +28,11 @@ def _build_type_rules() -> tuple[set[str], set[str]]:
src_forbidden.add(spec.name)
if gc.max_incoming == 0:
tgt_forbidden.add(spec.name)
return src_forbidden, tgt_forbidden
if gc.min_instances is not None:
min_instances[spec.name] = gc.min_instances
if gc.max_instances is not None:
max_instances[spec.name] = gc.max_instances
return src_forbidden, tgt_forbidden, min_instances, max_instances
def _empty_violation(reason: str) -> dict:
@ -49,7 +57,7 @@ def audit_definition(nodes, edges) -> list[dict]:
if not isinstance(nodes, list) or not isinstance(edges, list):
return []
src_forbidden, tgt_forbidden = _build_type_rules()
src_forbidden, tgt_forbidden, min_instances, max_instances = _build_type_rules()
nodes_by_id: dict = {}
for n in nodes:
if isinstance(n, dict) and "id" in n:
@ -57,14 +65,25 @@ def audit_definition(nodes, edges) -> list[dict]:
violations: list[dict] = []
# Graph-level: WorkflowGraph._assert_start_node requires exactly one
# startCall node. The DTO doesn't enforce this, so legacy or
# script-edited rows can land in a state that fails at runtime.
start_count = sum(1 for t in nodes_by_id.values() if t == "startCall")
if start_count == 0:
violations.append(_empty_violation("no_start_node"))
elif start_count > 1:
violations.append(_empty_violation(f"multiple_start_nodes:{start_count}"))
node_counts = Counter(t for t in nodes_by_id.values() if isinstance(t, str))
for node_type, min_count in min_instances.items():
count = node_counts.get(node_type, 0)
if count < min_count:
reason = (
"no_start_node"
if node_type == "startCall" and min_count == 1
else f"min_instances_{min_count}:{node_type}:{count}"
)
violations.append(_empty_violation(reason))
for node_type, max_count in max_instances.items():
count = node_counts.get(node_type, 0)
if count > max_count:
reason = (
f"multiple_start_nodes:{count}"
if node_type == "startCall" and max_count == 1
else f"max_instances_{max_count}:{node_type}:{count}"
)
violations.append(_empty_violation(reason))
for e in edges:
if not isinstance(e, dict):
continue

View file

@ -196,7 +196,12 @@ class _ToolDocumentRefsMixin(BaseModel):
},
)
],
graph_constraints=GraphConstraints(min_incoming=0, max_incoming=0),
graph_constraints=GraphConstraints(
min_incoming=0,
max_incoming=0,
min_instances=1,
max_instances=1,
),
property_order=(
"name",
"greeting_type",
@ -539,6 +544,7 @@ class EndCallNodeData(
max_incoming=0,
min_outgoing=0,
max_outgoing=0,
max_instances=1,
),
property_order=("name", "prompt"),
field_overrides={
@ -597,7 +603,11 @@ class GlobalNodeData(BaseNodeData, _PromptedNodeDataMixin):
examples=[
NodeExample(name="default", data={"name": "Inbound Trigger", "enabled": True})
],
graph_constraints=GraphConstraints(min_incoming=0, max_incoming=0),
graph_constraints=GraphConstraints(
min_incoming=0,
max_incoming=0,
max_instances=1,
),
property_order=("name", "enabled", "trigger_path"),
field_overrides={
"name": {

View file

@ -243,6 +243,8 @@ class GraphConstraints(BaseModel):
max_incoming: Optional[int] = None
min_outgoing: Optional[int] = None
max_outgoing: Optional[int] = None
min_instances: Optional[int] = None
max_instances: Optional[int] = None
model_config = ConfigDict(extra="forbid")

View file

@ -458,7 +458,8 @@ async def execute_text_chat_pending_turn(
)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)
ReactFlowDTO.model_validate(run_definition.workflow_json),
skip_instance_constraints_for={"trigger"},
)
base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)

View file

@ -201,7 +201,7 @@ async def execute_pending_text_chat_turn(
error_message=str(e),
)
raise TextChatSessionExecutionError(
"Failed to execute text chat assistant turn"
f"Failed to execute text chat assistant turn: {e}"
) from e
completed_session_data = normalize_text_chat_session_data(text_session.session_data)

View file

@ -5,7 +5,7 @@ from typing import Dict, List, Set
from api.services.workflow.dto import EdgeDataDTO, NodeType, ReactFlowDTO
from api.services.workflow.errors import ItemKind, WorkflowError
from api.services.workflow.node_data import BaseNodeData
from api.services.workflow.node_specs import get_spec
from api.services.workflow.node_specs import all_specs, get_spec
# Regex for matching {{ variable }} template placeholders.
# Captures: group(1) = variable path, group(2) = filter name, group(3) = filter value.
@ -68,10 +68,11 @@ class Node:
self.out: Dict[str, "Node"] = {} # forward nodes
self.out_edges: List[Edge] = [] # forward edges with properties
# name/is_start/is_end live on every per-type data class (base).
# Start/end semantics are defined by node type. The persisted
# data flags are legacy UI/runtime state and may be stale.
self.name = data.name
self.is_start = data.is_start
self.is_end = data.is_end
self.is_start = node_type == NodeType.startNode.value
self.is_end = node_type == NodeType.endNode.value
# Type-specific fields — read with getattr so this works for every
# node variant in the discriminated union.
@ -98,13 +99,89 @@ class Node:
self.data = data
def _instance_constraint_message(
label: str,
count: int,
*,
min_count: int | None = None,
max_count: int | None = None,
) -> str:
if max_count is not None and count > max_count:
if max_count == 1:
return f"Workflow can have at most one {label}"
return f"Workflow can have at most {max_count} {label} nodes"
if min_count is not None and count < min_count:
if min_count == 1:
return f"Workflow must have at least one {label}"
return f"Workflow must have at least {min_count} {label} nodes"
return ""
def validate_node_instance_constraints(
node_types: list[str],
*,
enforce_min_instances: bool = True,
skip_types: Set[str] | None = None,
) -> list[WorkflowError]:
"""Validate workflow-level node type counts from NodeSpec.graph_constraints."""
errors: list[WorkflowError] = []
skip_types = skip_types or set()
counts = Counter(node_types)
for spec in all_specs():
if spec.name in skip_types:
continue
gc = spec.graph_constraints
if gc is None:
continue
count = counts.get(spec.name, 0)
if gc.max_instances is not None and count > gc.max_instances:
errors.append(
WorkflowError(
kind=ItemKind.workflow,
id=None,
field=None,
message=_instance_constraint_message(
spec.display_name,
count,
max_count=gc.max_instances,
),
)
)
if (
enforce_min_instances
and gc.min_instances is not None
and count < gc.min_instances
):
errors.append(
WorkflowError(
kind=ItemKind.workflow,
id=None,
field=None,
message=_instance_constraint_message(
spec.display_name,
count,
min_count=gc.min_instances,
),
)
)
return errors
class WorkflowGraph:
"""
*All* business invariants (acyclic, cardinality, etc.) are verified here.
The constructor accepts a validated ReactFlowDTO.
"""
def __init__(self, dto: ReactFlowDTO):
def __init__(
self,
dto: ReactFlowDTO,
*,
skip_instance_constraints_for: Set[str] | None = None,
):
# Build adjacency list from validated DTO nodes. Core node comparisons
# still use NodeType string enums; integration nodes remain plain
# strings and resolve constraints through node specs.
@ -131,10 +208,12 @@ class WorkflowGraph:
# Set up the node references for backward compatibility
source_node.out[target_node.id] = target_node
self._validate_graph()
self._validate_graph(skip_instance_constraints_for or set())
# Get a reference to the start node
self.start_node_id = [n.id for n in dto.nodes if n.data.is_start][0]
self.start_node_id = [
n.id for n in dto.nodes if n.type == NodeType.startNode.value
][0]
# Get a reference to the global node
try:
@ -185,7 +264,7 @@ class WorkflowGraph:
# -----------------------------------------------------------
# validators
# -----------------------------------------------------------
def _validate_graph(self) -> None:
def _validate_graph(self, skip_instance_constraints_for: Set[str]) -> None:
errors: list[WorkflowError] = []
# TODO: Figure out what kind of cyclic contraints can be applied, since there can be a cycle in the graph
@ -198,9 +277,13 @@ class WorkflowGraph:
# )
# )
errors.extend(self._assert_start_node())
errors.extend(
validate_node_instance_constraints(
[n.node_type for n in self.nodes.values()],
skip_types=skip_instance_constraints_for,
)
)
errors.extend(self._assert_connection_counts())
errors.extend(self._assert_global_node())
errors.extend(self._assert_node_configs())
if errors:
raise ValueError(errors)
@ -220,48 +303,6 @@ class WorkflowGraph:
for n in self.nodes.values():
dfs(n)
def _assert_start_node(self):
errors: list[WorkflowError] = []
start_nodes = [n for n in self.nodes.values() if n.data.is_start]
if not start_nodes:
errors.append(
WorkflowError(
kind=ItemKind.workflow,
id=None,
field=None,
message="Workflow has no start node — exactly one is required",
)
)
elif len(start_nodes) > 1:
errors.append(
WorkflowError(
kind=ItemKind.workflow,
id=None,
field=None,
message=(
f"Workflow has {len(start_nodes)} start nodes — "
f"exactly one is required"
),
)
)
return errors
def _assert_global_node(self):
errors: list[WorkflowError] = []
global_node = [
n for n in self.nodes.values() if n.node_type == NodeType.globalNode.value
]
if not len(global_node) <= 1:
errors.append(
WorkflowError(
kind=ItemKind.workflow,
id=None,
field=None,
message="Workflow must have at most one global node",
)
)
return errors
def _assert_connection_counts(self):
"""Enforce per-type incoming/outgoing edge constraints.