mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
fix: disable duplicate trigger nodes in workflow builder (#402)
* fix: disable duplicate trigger nodes in workflow builder AddNodePanel: disable trigger buttons and show tooltip when a trigger already exists on the canvas, using bySpecName to identify trigger- category specs from the live node list. useWorkflowState: preflight in saveWorkflow rejects saves with multiple trigger nodes via a sonner toast before the network request is made. text_chat_session_service: include the original exception message in TextChatSessionExecutionError so the HTTP 500 detail surfaces the root cause without DB inspection. Closes #378 * style: format test_text_chat_session_service.py with ruff * chore: retrigger CI checks * fix(workflow): enforce node instance constraints --------- Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
7c31dd3eec
commit
7d053320df
27 changed files with 591 additions and 91 deletions
|
|
@ -7,15 +7,19 @@ script in `api/services/admin_utils/local_exec.py` is the production
|
|||
consumer.
|
||||
"""
|
||||
|
||||
from collections import Counter
|
||||
|
||||
from api.services.workflow.node_specs import all_specs
|
||||
|
||||
|
||||
def _build_type_rules() -> tuple[set[str], set[str]]:
|
||||
def _build_type_rules() -> tuple[set[str], set[str], dict[str, int], dict[str, int]]:
|
||||
"""From NodeSpec.graph_constraints, derive the set of types that are
|
||||
forbidden as edge sources (max_outgoing == 0) and as targets
|
||||
(max_incoming == 0)."""
|
||||
src_forbidden: set[str] = set()
|
||||
tgt_forbidden: set[str] = set()
|
||||
min_instances: dict[str, int] = {}
|
||||
max_instances: dict[str, int] = {}
|
||||
for spec in all_specs():
|
||||
gc = spec.graph_constraints
|
||||
if gc is None:
|
||||
|
|
@ -24,7 +28,11 @@ def _build_type_rules() -> tuple[set[str], set[str]]:
|
|||
src_forbidden.add(spec.name)
|
||||
if gc.max_incoming == 0:
|
||||
tgt_forbidden.add(spec.name)
|
||||
return src_forbidden, tgt_forbidden
|
||||
if gc.min_instances is not None:
|
||||
min_instances[spec.name] = gc.min_instances
|
||||
if gc.max_instances is not None:
|
||||
max_instances[spec.name] = gc.max_instances
|
||||
return src_forbidden, tgt_forbidden, min_instances, max_instances
|
||||
|
||||
|
||||
def _empty_violation(reason: str) -> dict:
|
||||
|
|
@ -49,7 +57,7 @@ def audit_definition(nodes, edges) -> list[dict]:
|
|||
if not isinstance(nodes, list) or not isinstance(edges, list):
|
||||
return []
|
||||
|
||||
src_forbidden, tgt_forbidden = _build_type_rules()
|
||||
src_forbidden, tgt_forbidden, min_instances, max_instances = _build_type_rules()
|
||||
nodes_by_id: dict = {}
|
||||
for n in nodes:
|
||||
if isinstance(n, dict) and "id" in n:
|
||||
|
|
@ -57,14 +65,25 @@ def audit_definition(nodes, edges) -> list[dict]:
|
|||
|
||||
violations: list[dict] = []
|
||||
|
||||
# Graph-level: WorkflowGraph._assert_start_node requires exactly one
|
||||
# startCall node. The DTO doesn't enforce this, so legacy or
|
||||
# script-edited rows can land in a state that fails at runtime.
|
||||
start_count = sum(1 for t in nodes_by_id.values() if t == "startCall")
|
||||
if start_count == 0:
|
||||
violations.append(_empty_violation("no_start_node"))
|
||||
elif start_count > 1:
|
||||
violations.append(_empty_violation(f"multiple_start_nodes:{start_count}"))
|
||||
node_counts = Counter(t for t in nodes_by_id.values() if isinstance(t, str))
|
||||
for node_type, min_count in min_instances.items():
|
||||
count = node_counts.get(node_type, 0)
|
||||
if count < min_count:
|
||||
reason = (
|
||||
"no_start_node"
|
||||
if node_type == "startCall" and min_count == 1
|
||||
else f"min_instances_{min_count}:{node_type}:{count}"
|
||||
)
|
||||
violations.append(_empty_violation(reason))
|
||||
for node_type, max_count in max_instances.items():
|
||||
count = node_counts.get(node_type, 0)
|
||||
if count > max_count:
|
||||
reason = (
|
||||
f"multiple_start_nodes:{count}"
|
||||
if node_type == "startCall" and max_count == 1
|
||||
else f"max_instances_{max_count}:{node_type}:{count}"
|
||||
)
|
||||
violations.append(_empty_violation(reason))
|
||||
for e in edges:
|
||||
if not isinstance(e, dict):
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -196,7 +196,12 @@ class _ToolDocumentRefsMixin(BaseModel):
|
|||
},
|
||||
)
|
||||
],
|
||||
graph_constraints=GraphConstraints(min_incoming=0, max_incoming=0),
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
min_instances=1,
|
||||
max_instances=1,
|
||||
),
|
||||
property_order=(
|
||||
"name",
|
||||
"greeting_type",
|
||||
|
|
@ -539,6 +544,7 @@ class EndCallNodeData(
|
|||
max_incoming=0,
|
||||
min_outgoing=0,
|
||||
max_outgoing=0,
|
||||
max_instances=1,
|
||||
),
|
||||
property_order=("name", "prompt"),
|
||||
field_overrides={
|
||||
|
|
@ -597,7 +603,11 @@ class GlobalNodeData(BaseNodeData, _PromptedNodeDataMixin):
|
|||
examples=[
|
||||
NodeExample(name="default", data={"name": "Inbound Trigger", "enabled": True})
|
||||
],
|
||||
graph_constraints=GraphConstraints(min_incoming=0, max_incoming=0),
|
||||
graph_constraints=GraphConstraints(
|
||||
min_incoming=0,
|
||||
max_incoming=0,
|
||||
max_instances=1,
|
||||
),
|
||||
property_order=("name", "enabled", "trigger_path"),
|
||||
field_overrides={
|
||||
"name": {
|
||||
|
|
|
|||
|
|
@ -243,6 +243,8 @@ class GraphConstraints(BaseModel):
|
|||
max_incoming: Optional[int] = None
|
||||
min_outgoing: Optional[int] = None
|
||||
max_outgoing: Optional[int] = None
|
||||
min_instances: Optional[int] = None
|
||||
max_instances: Optional[int] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
|
|
|||
|
|
@ -458,7 +458,8 @@ async def execute_text_chat_pending_turn(
|
|||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json),
|
||||
skip_instance_constraints_for={"trigger"},
|
||||
)
|
||||
base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)
|
||||
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ async def execute_pending_text_chat_turn(
|
|||
error_message=str(e),
|
||||
)
|
||||
raise TextChatSessionExecutionError(
|
||||
"Failed to execute text chat assistant turn"
|
||||
f"Failed to execute text chat assistant turn: {e}"
|
||||
) from e
|
||||
|
||||
completed_session_data = normalize_text_chat_session_data(text_session.session_data)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Set
|
|||
from api.services.workflow.dto import EdgeDataDTO, NodeType, ReactFlowDTO
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
from api.services.workflow.node_specs import get_spec
|
||||
from api.services.workflow.node_specs import all_specs, get_spec
|
||||
|
||||
# Regex for matching {{ variable }} template placeholders.
|
||||
# Captures: group(1) = variable path, group(2) = filter name, group(3) = filter value.
|
||||
|
|
@ -68,10 +68,11 @@ class Node:
|
|||
self.out: Dict[str, "Node"] = {} # forward nodes
|
||||
self.out_edges: List[Edge] = [] # forward edges with properties
|
||||
|
||||
# name/is_start/is_end live on every per-type data class (base).
|
||||
# Start/end semantics are defined by node type. The persisted
|
||||
# data flags are legacy UI/runtime state and may be stale.
|
||||
self.name = data.name
|
||||
self.is_start = data.is_start
|
||||
self.is_end = data.is_end
|
||||
self.is_start = node_type == NodeType.startNode.value
|
||||
self.is_end = node_type == NodeType.endNode.value
|
||||
|
||||
# Type-specific fields — read with getattr so this works for every
|
||||
# node variant in the discriminated union.
|
||||
|
|
@ -98,13 +99,89 @@ class Node:
|
|||
self.data = data
|
||||
|
||||
|
||||
def _instance_constraint_message(
|
||||
label: str,
|
||||
count: int,
|
||||
*,
|
||||
min_count: int | None = None,
|
||||
max_count: int | None = None,
|
||||
) -> str:
|
||||
if max_count is not None and count > max_count:
|
||||
if max_count == 1:
|
||||
return f"Workflow can have at most one {label}"
|
||||
return f"Workflow can have at most {max_count} {label} nodes"
|
||||
if min_count is not None and count < min_count:
|
||||
if min_count == 1:
|
||||
return f"Workflow must have at least one {label}"
|
||||
return f"Workflow must have at least {min_count} {label} nodes"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_node_instance_constraints(
|
||||
node_types: list[str],
|
||||
*,
|
||||
enforce_min_instances: bool = True,
|
||||
skip_types: Set[str] | None = None,
|
||||
) -> list[WorkflowError]:
|
||||
"""Validate workflow-level node type counts from NodeSpec.graph_constraints."""
|
||||
errors: list[WorkflowError] = []
|
||||
skip_types = skip_types or set()
|
||||
counts = Counter(node_types)
|
||||
|
||||
for spec in all_specs():
|
||||
if spec.name in skip_types:
|
||||
continue
|
||||
gc = spec.graph_constraints
|
||||
if gc is None:
|
||||
continue
|
||||
|
||||
count = counts.get(spec.name, 0)
|
||||
if gc.max_instances is not None and count > gc.max_instances:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message=_instance_constraint_message(
|
||||
spec.display_name,
|
||||
count,
|
||||
max_count=gc.max_instances,
|
||||
),
|
||||
)
|
||||
)
|
||||
if (
|
||||
enforce_min_instances
|
||||
and gc.min_instances is not None
|
||||
and count < gc.min_instances
|
||||
):
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message=_instance_constraint_message(
|
||||
spec.display_name,
|
||||
count,
|
||||
min_count=gc.min_instances,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class WorkflowGraph:
|
||||
"""
|
||||
*All* business invariants (acyclic, cardinality, etc.) are verified here.
|
||||
The constructor accepts a validated ReactFlowDTO.
|
||||
"""
|
||||
|
||||
def __init__(self, dto: ReactFlowDTO):
|
||||
def __init__(
|
||||
self,
|
||||
dto: ReactFlowDTO,
|
||||
*,
|
||||
skip_instance_constraints_for: Set[str] | None = None,
|
||||
):
|
||||
# Build adjacency list from validated DTO nodes. Core node comparisons
|
||||
# still use NodeType string enums; integration nodes remain plain
|
||||
# strings and resolve constraints through node specs.
|
||||
|
|
@ -131,10 +208,12 @@ class WorkflowGraph:
|
|||
# Set up the node references for backward compatibility
|
||||
source_node.out[target_node.id] = target_node
|
||||
|
||||
self._validate_graph()
|
||||
self._validate_graph(skip_instance_constraints_for or set())
|
||||
|
||||
# Get a reference to the start node
|
||||
self.start_node_id = [n.id for n in dto.nodes if n.data.is_start][0]
|
||||
self.start_node_id = [
|
||||
n.id for n in dto.nodes if n.type == NodeType.startNode.value
|
||||
][0]
|
||||
|
||||
# Get a reference to the global node
|
||||
try:
|
||||
|
|
@ -185,7 +264,7 @@ class WorkflowGraph:
|
|||
# -----------------------------------------------------------
|
||||
# validators
|
||||
# -----------------------------------------------------------
|
||||
def _validate_graph(self) -> None:
|
||||
def _validate_graph(self, skip_instance_constraints_for: Set[str]) -> None:
|
||||
errors: list[WorkflowError] = []
|
||||
|
||||
# TODO: Figure out what kind of cyclic contraints can be applied, since there can be a cycle in the graph
|
||||
|
|
@ -198,9 +277,13 @@ class WorkflowGraph:
|
|||
# )
|
||||
# )
|
||||
|
||||
errors.extend(self._assert_start_node())
|
||||
errors.extend(
|
||||
validate_node_instance_constraints(
|
||||
[n.node_type for n in self.nodes.values()],
|
||||
skip_types=skip_instance_constraints_for,
|
||||
)
|
||||
)
|
||||
errors.extend(self._assert_connection_counts())
|
||||
errors.extend(self._assert_global_node())
|
||||
errors.extend(self._assert_node_configs())
|
||||
if errors:
|
||||
raise ValueError(errors)
|
||||
|
|
@ -220,48 +303,6 @@ class WorkflowGraph:
|
|||
for n in self.nodes.values():
|
||||
dfs(n)
|
||||
|
||||
def _assert_start_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
start_nodes = [n for n in self.nodes.values() if n.data.is_start]
|
||||
if not start_nodes:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow has no start node — exactly one is required",
|
||||
)
|
||||
)
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message=(
|
||||
f"Workflow has {len(start_nodes)} start nodes — "
|
||||
f"exactly one is required"
|
||||
),
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_global_node(self):
|
||||
errors: list[WorkflowError] = []
|
||||
global_node = [
|
||||
n for n in self.nodes.values() if n.node_type == NodeType.globalNode.value
|
||||
]
|
||||
if not len(global_node) <= 1:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow must have at most one global node",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _assert_connection_counts(self):
|
||||
"""Enforce per-type incoming/outgoing edge constraints.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue