mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: fix projection to TS when fetching agnet in MCP
This commit is contained in:
parent
3892b58486
commit
bbb4f91a27
12 changed files with 392 additions and 63 deletions
55
api/mcp_server/tools/_workflow_projection.py
Normal file
55
api/mcp_server/tools/_workflow_projection.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from api.db import db_client
|
||||
from api.mcp_server.ts_bridge import generate_code
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WorkflowProjectionSource:
|
||||
payload: dict[str, Any] | None
|
||||
version: Literal["draft", "published", "legacy"]
|
||||
version_number: int | None
|
||||
|
||||
|
||||
async def select_workflow_projection_source(workflow: Any) -> WorkflowProjectionSource:
|
||||
"""Choose the same working copy across read and save MCP tools.
|
||||
|
||||
Draft wins over published because that's what a human editor would
|
||||
be mutating. Legacy `workflow_definition` is the final fallback for
|
||||
older rows that predate versioned definitions.
|
||||
"""
|
||||
draft = await db_client.get_draft_version(workflow.id)
|
||||
if draft is not None and draft.workflow_json:
|
||||
return WorkflowProjectionSource(
|
||||
payload=draft.workflow_json,
|
||||
version="draft",
|
||||
version_number=draft.version_number,
|
||||
)
|
||||
|
||||
released = workflow.released_definition
|
||||
if released is not None and released.workflow_json:
|
||||
return WorkflowProjectionSource(
|
||||
payload=released.workflow_json,
|
||||
version="published",
|
||||
version_number=released.version_number,
|
||||
)
|
||||
|
||||
return WorkflowProjectionSource(
|
||||
payload=workflow.workflow_definition or None,
|
||||
version="legacy",
|
||||
version_number=None,
|
||||
)
|
||||
|
||||
|
||||
async def project_workflow_to_sdk_view(workflow: Any) -> dict[str, Any]:
|
||||
source = await select_workflow_projection_source(workflow)
|
||||
code = await generate_code(source.payload or {}, workflow_name=workflow.name or "")
|
||||
return {
|
||||
"name": workflow.name or "",
|
||||
"version": source.version,
|
||||
"version_number": source.version_number,
|
||||
"code": code,
|
||||
}
|
||||
|
|
@ -18,8 +18,9 @@ from fastapi import HTTPException
|
|||
|
||||
from api.db import db_client
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.tools._workflow_projection import project_workflow_to_sdk_view
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
from api.mcp_server.ts_bridge import TsBridgeError, generate_code
|
||||
from api.mcp_server.ts_bridge import TsBridgeError
|
||||
|
||||
|
||||
@traced_tool
|
||||
|
|
@ -39,31 +40,14 @@ async def get_workflow_code(workflow_id: int) -> dict[str, Any]:
|
|||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail=f"Workflow {workflow_id} not found")
|
||||
|
||||
# Draft wins over published — editing a draft is the normal flow.
|
||||
# `current_definition` (is_current=True) is the published row, so we
|
||||
# fetch the draft explicitly. If the latest draft was just published,
|
||||
# no draft row exists and we fall through to `released_definition`.
|
||||
draft = await db_client.get_draft_version(workflow_id)
|
||||
released = workflow.released_definition
|
||||
|
||||
if draft is not None and draft.workflow_json:
|
||||
payload = draft.workflow_json
|
||||
source = "draft"
|
||||
elif released is not None and released.workflow_json:
|
||||
payload = released.workflow_json
|
||||
source = "published"
|
||||
else:
|
||||
payload = workflow.workflow_definition or {}
|
||||
source = "legacy"
|
||||
|
||||
try:
|
||||
code = await generate_code(payload, workflow_name=workflow.name or "")
|
||||
view = await project_workflow_to_sdk_view(workflow)
|
||||
except TsBridgeError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to generate code: {e}")
|
||||
|
||||
return {
|
||||
"workflow_id": workflow_id,
|
||||
"name": workflow.name or "",
|
||||
"version": source,
|
||||
"code": code,
|
||||
"name": view["name"],
|
||||
"version": view["version"],
|
||||
"code": view["code"],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ from pydantic import ValidationError as PydanticValidationError
|
|||
|
||||
from api.db import db_client
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.tools._workflow_projection import (
|
||||
select_workflow_projection_source,
|
||||
)
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
from api.mcp_server.ts_bridge import TsBridgeError, parse_code
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
|
|
@ -37,20 +40,9 @@ from api.services.workflow.workflow_graph import WorkflowGraph
|
|||
|
||||
|
||||
async def _previous_workflow_json(workflow: Any) -> dict[str, Any] | None:
|
||||
"""Same selection priority as `get_workflow_code` — the version the
|
||||
LLM saw is the version we reconcile against.
|
||||
|
||||
`current_definition` (is_current=True) is the published row, so the
|
||||
draft must be fetched explicitly. If no draft exists (e.g. the last
|
||||
draft was just published), fall through to `released_definition`.
|
||||
"""
|
||||
draft = await db_client.get_draft_version(workflow.id)
|
||||
if draft is not None and draft.workflow_json:
|
||||
return draft.workflow_json
|
||||
released = workflow.released_definition
|
||||
if released is not None and released.workflow_json:
|
||||
return released.workflow_json
|
||||
return workflow.workflow_definition or None
|
||||
"""Match the agent-facing read tools' source selection."""
|
||||
source = await select_workflow_projection_source(workflow)
|
||||
return source.payload
|
||||
|
||||
|
||||
def _error_result(code: str, message: str, **extra: Any) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from fastapi import HTTPException
|
|||
|
||||
from api.db import db_client
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.tools._workflow_projection import project_workflow_to_sdk_view
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
from api.mcp_server.ts_bridge import TsBridgeError
|
||||
|
||||
|
||||
@traced_tool
|
||||
|
|
@ -10,9 +12,9 @@ async def list_workflows(status: str | None = "active") -> list[dict]:
|
|||
"""List agents (workflows) in the caller's organization.
|
||||
|
||||
Returns id, name, status, and created_at for each agent. Use
|
||||
`get_workflow` to fetch a single agent's full definition. Defaults
|
||||
to active agents; pass `status="archived"` to list archived agents,
|
||||
or `status=None` to list all.
|
||||
`get_workflow` to fetch a single agent's current SDK view and
|
||||
metadata. Defaults to active agents; pass `status="archived"` to
|
||||
list archived agents, or `status=None` to list all.
|
||||
"""
|
||||
user = await authenticate_mcp_request()
|
||||
workflows = await db_client.get_all_workflows_for_listing(
|
||||
|
|
@ -32,7 +34,11 @@ async def list_workflows(status: str | None = "active") -> list[dict]:
|
|||
|
||||
@traced_tool
|
||||
async def get_workflow(workflow_id: int) -> dict:
|
||||
"""Fetch a single agent by id, including its current published definition."""
|
||||
"""Fetch a single agent by id, projected into the SDK code view.
|
||||
|
||||
Output shape:
|
||||
{"id": int, "name": str, "status": str, "version": "draft" | "published" | "legacy", "version_number": int | None, "code": "<TS source>"}
|
||||
"""
|
||||
user = await authenticate_mcp_request()
|
||||
workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
|
|
@ -40,11 +46,16 @@ async def get_workflow(workflow_id: int) -> dict:
|
|||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail=f"Workflow {workflow_id} not found")
|
||||
|
||||
current = workflow.current_definition
|
||||
try:
|
||||
view = await project_workflow_to_sdk_view(workflow)
|
||||
except TsBridgeError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to generate code: {e}")
|
||||
|
||||
return {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"name": view["name"],
|
||||
"status": workflow.status,
|
||||
"definition": current.workflow_json if current else None,
|
||||
"version_number": current.version_number if current else None,
|
||||
"version": view["version"],
|
||||
"version_number": view["version_number"],
|
||||
"code": view["code"],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import json
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from api.services.workflow.dto import EdgeDataDTO
|
||||
from api.services.workflow.node_specs import all_specs
|
||||
|
||||
_VALIDATOR_ENTRY = Path(__file__).resolve().parent / "ts_validator" / "src" / "index.ts"
|
||||
|
|
@ -31,6 +32,10 @@ def _specs_payload() -> list[dict[str, Any]]:
|
|||
return [s.model_dump(mode="json") for s in all_specs()]
|
||||
|
||||
|
||||
def _edge_field_names() -> list[str]:
|
||||
return list(EdgeDataDTO.model_fields.keys())
|
||||
|
||||
|
||||
async def _invoke(request: dict[str, Any]) -> dict[str, Any]:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"node",
|
||||
|
|
@ -65,6 +70,7 @@ async def generate_code(workflow: dict[str, Any], *, workflow_name: str = "") ->
|
|||
"command": "generate",
|
||||
"workflow": workflow,
|
||||
"specs": _specs_payload(),
|
||||
"edgeFieldNames": _edge_field_names(),
|
||||
"workflowName": workflow_name,
|
||||
}
|
||||
)
|
||||
|
|
@ -89,5 +95,6 @@ async def parse_code(code: str) -> dict[str, Any]:
|
|||
"command": "parse",
|
||||
"code": code,
|
||||
"specs": _specs_payload(),
|
||||
"edgeFieldNames": _edge_field_names(),
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,9 +14,18 @@ import type {
|
|||
export function generateCode(
|
||||
workflow: WireWorkflow,
|
||||
specs: NodeSpec[],
|
||||
opts: { workflowName?: string } = {},
|
||||
opts: { workflowName?: string; edgeFieldNames?: string[] } = {},
|
||||
): GenerateResult {
|
||||
const specByName = new Map(specs.map((s) => [s.name, s]));
|
||||
const edgeFieldNames = new Set(
|
||||
opts.edgeFieldNames ?? [
|
||||
"label",
|
||||
"condition",
|
||||
"transition_speech",
|
||||
"transition_speech_type",
|
||||
"transition_speech_recording_id",
|
||||
],
|
||||
);
|
||||
|
||||
// Catch unknown node types up-front — otherwise we'd emit an import
|
||||
// line for a factory that doesn't exist.
|
||||
|
|
@ -97,7 +106,7 @@ export function generateCode(
|
|||
],
|
||||
};
|
||||
}
|
||||
const cleanedEdge = pickEdgeFields(edge.data);
|
||||
const cleanedEdge = pickEdgeFields(edge.data, edgeFieldNames);
|
||||
const edgeOpts = renderObject(cleanedEdge, 0);
|
||||
lines.push(`wf.edge(${src}, ${tgt}, ${edgeOpts});`);
|
||||
}
|
||||
|
|
@ -210,22 +219,13 @@ function stripUnknown(
|
|||
return out;
|
||||
}
|
||||
|
||||
// Edge schema is fixed (no NodeSpec for edges). Mirrors the allowed
|
||||
// fields on `Workflow.edge(...)` in both SDKs.
|
||||
const KNOWN_EDGE_FIELDS = new Set([
|
||||
"label",
|
||||
"condition",
|
||||
"transition_speech",
|
||||
"transition_speech_type",
|
||||
"transition_speech_recording_id",
|
||||
]);
|
||||
|
||||
function pickEdgeFields(
|
||||
data: Record<string, unknown>,
|
||||
knownEdgeFields: Set<string>,
|
||||
): Record<string, unknown> {
|
||||
const out: Record<string, unknown> = {};
|
||||
for (const [k, v] of Object.entries(data)) {
|
||||
if (KNOWN_EDGE_FIELDS.has(k)) out[k] = v;
|
||||
if (knownEdgeFields.has(k)) out[k] = v;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ interface GenerateRequest {
|
|||
command: "generate";
|
||||
workflow: WireWorkflow;
|
||||
specs: NodeSpec[];
|
||||
edgeFieldNames: string[];
|
||||
workflowName?: string;
|
||||
}
|
||||
|
||||
|
|
@ -18,6 +19,7 @@ interface ParseRequest {
|
|||
command: "parse";
|
||||
code: string;
|
||||
specs: NodeSpec[];
|
||||
edgeFieldNames: string[];
|
||||
}
|
||||
|
||||
type Request = GenerateRequest | ParseRequest;
|
||||
|
|
@ -49,11 +51,16 @@ async function main(): Promise<void> {
|
|||
}
|
||||
|
||||
if (req.command === "generate") {
|
||||
writeResult(generateCode(req.workflow, req.specs, { workflowName: req.workflowName }));
|
||||
writeResult(
|
||||
generateCode(req.workflow, req.specs, {
|
||||
workflowName: req.workflowName,
|
||||
edgeFieldNames: req.edgeFieldNames,
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (req.command === "parse") {
|
||||
writeResult(parseCode(req.code, req.specs));
|
||||
writeResult(parseCode(req.code, req.specs, req.edgeFieldNames));
|
||||
return;
|
||||
}
|
||||
writeResult({
|
||||
|
|
|
|||
|
|
@ -25,8 +25,19 @@ import type {
|
|||
WireNode,
|
||||
} from "./types.ts";
|
||||
|
||||
export function parseCode(code: string, specs: NodeSpec[]): ParseResult {
|
||||
export function parseCode(
|
||||
code: string,
|
||||
specs: NodeSpec[],
|
||||
edgeFieldNames: string[] = [
|
||||
"label",
|
||||
"condition",
|
||||
"transition_speech",
|
||||
"transition_speech_type",
|
||||
"transition_speech_recording_id",
|
||||
],
|
||||
): ParseResult {
|
||||
const specByName = new Map(specs.map((s) => [s.name, s]));
|
||||
const allowedEdgeFieldNames = new Set(edgeFieldNames);
|
||||
const sourceFile = ts.createSourceFile(
|
||||
"workflow.ts",
|
||||
code,
|
||||
|
|
@ -335,6 +346,12 @@ export function parseCode(code: string, specs: NodeSpec[]): ParseResult {
|
|||
addError(stmt, "`edge` requires a non-empty `condition` string.");
|
||||
return;
|
||||
}
|
||||
for (const key of Object.keys(optsObj)) {
|
||||
if (!allowedEdgeFieldNames.has(key)) {
|
||||
addError(stmt, `Unknown edge field: \`${key}\`.`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
edges.push({
|
||||
id: `${src.id}-${tgt.id}`,
|
||||
source: src.id,
|
||||
|
|
|
|||
99
api/tests/test_mcp_get_workflow.py
Normal file
99
api/tests/test_mcp_get_workflow.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.mcp_server.tools.workflows import get_workflow
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authed_user() -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.selected_organization_id = 1
|
||||
return user
|
||||
|
||||
|
||||
def _workflow() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=7,
|
||||
name="Support Agent",
|
||||
status="active",
|
||||
released_definition=SimpleNamespace(
|
||||
workflow_json={"nodes": [{"id": "published"}], "edges": []},
|
||||
version_number=3,
|
||||
),
|
||||
workflow_definition={"nodes": [{"id": "legacy"}], "edges": []},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workflow_returns_draft_sdk_view(authed_user: MagicMock):
|
||||
workflow = _workflow()
|
||||
draft = SimpleNamespace(
|
||||
workflow_json={"nodes": [{"id": "draft"}], "edges": []},
|
||||
version_number=4,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.workflows.authenticate_mcp_request",
|
||||
AsyncMock(return_value=authed_user),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools.workflows.db_client.get_workflow",
|
||||
AsyncMock(return_value=workflow),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools._workflow_projection.db_client.get_draft_version",
|
||||
AsyncMock(return_value=draft),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools._workflow_projection.generate_code",
|
||||
AsyncMock(
|
||||
return_value='const wf = new Workflow({ name: "Support Agent" });'
|
||||
),
|
||||
) as generate_code_mock,
|
||||
):
|
||||
result = await get_workflow(workflow_id=workflow.id)
|
||||
|
||||
assert result == {
|
||||
"id": 7,
|
||||
"name": "Support Agent",
|
||||
"status": "active",
|
||||
"version": "draft",
|
||||
"version_number": 4,
|
||||
"code": 'const wf = new Workflow({ name: "Support Agent" });',
|
||||
}
|
||||
generate_code_mock.assert_awaited_once_with(
|
||||
draft.workflow_json, workflow_name="Support Agent"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workflow_falls_back_to_published_sdk_view(authed_user: MagicMock):
|
||||
workflow = _workflow()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.workflows.authenticate_mcp_request",
|
||||
AsyncMock(return_value=authed_user),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools.workflows.db_client.get_workflow",
|
||||
AsyncMock(return_value=workflow),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools._workflow_projection.db_client.get_draft_version",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools._workflow_projection.generate_code",
|
||||
AsyncMock(
|
||||
return_value='const wf = new Workflow({ name: "Support Agent" });'
|
||||
),
|
||||
),
|
||||
):
|
||||
result = await get_workflow(workflow_id=workflow.id)
|
||||
|
||||
assert result["version"] == "published"
|
||||
assert result["version_number"] == 3
|
||||
|
|
@ -7,10 +7,19 @@ that code → JSON and JSON → code round-trip losslessly.
|
|||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from types import NoneType
|
||||
from typing import Any, get_args
|
||||
|
||||
import pytest
|
||||
|
||||
from api.mcp_server.ts_bridge import TsBridgeError, generate_code, parse_code
|
||||
from api.services.workflow.dto import EdgeDataDTO
|
||||
from api.services.workflow.node_specs import (
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
all_specs,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
shutil.which("node") is None, reason="node binary not available"
|
||||
|
|
@ -81,6 +90,102 @@ def _normalize(wf: dict) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def _strip_optional(annotation: Any) -> Any:
|
||||
args = tuple(arg for arg in get_args(annotation) if arg is not NoneType)
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return annotation
|
||||
|
||||
|
||||
def _pick_option_value(prop: PropertySpec) -> Any:
|
||||
assert prop.options, f"{prop.name} has no options"
|
||||
default = prop.default
|
||||
for option in prop.options:
|
||||
if option.value != default:
|
||||
return option.value
|
||||
return prop.options[0].value
|
||||
|
||||
|
||||
def _sample_number(prop: PropertySpec) -> int | float:
|
||||
candidates: list[int | float] = [1, 2, 3, 0.5, 4.5, 10]
|
||||
for candidate in candidates:
|
||||
if prop.min_value is not None and candidate < prop.min_value:
|
||||
continue
|
||||
if prop.max_value is not None and candidate > prop.max_value:
|
||||
continue
|
||||
if prop.default is not None and candidate == prop.default:
|
||||
continue
|
||||
return candidate
|
||||
raise AssertionError(f"No valid sample number found for {prop.name}")
|
||||
|
||||
|
||||
def _sample_property_value(prop: PropertySpec, *, path: str) -> Any:
|
||||
slug = path.replace(".", "_")
|
||||
|
||||
if prop.type == PropertyType.string:
|
||||
return f"{slug}_value"
|
||||
if prop.type == PropertyType.mention_textarea:
|
||||
return f"{slug} prompt with {{name}}"
|
||||
if prop.type == PropertyType.url:
|
||||
return f"https://example.com/{slug}"
|
||||
if prop.type == PropertyType.recording_ref:
|
||||
return f"recording_{slug}"
|
||||
if prop.type == PropertyType.credential_ref:
|
||||
return f"credential_{slug}"
|
||||
if prop.type == PropertyType.number:
|
||||
return _sample_number(prop)
|
||||
if prop.type == PropertyType.boolean:
|
||||
return not prop.default if isinstance(prop.default, bool) else True
|
||||
if prop.type == PropertyType.options:
|
||||
return _pick_option_value(prop)
|
||||
if prop.type == PropertyType.multi_options:
|
||||
return [_pick_option_value(prop)]
|
||||
if prop.type == PropertyType.tool_refs:
|
||||
return [f"tool_{slug}"]
|
||||
if prop.type == PropertyType.document_refs:
|
||||
return [f"document_{slug}"]
|
||||
if prop.type == PropertyType.json:
|
||||
return {"kind": slug, "enabled": True}
|
||||
if prop.type == PropertyType.fixed_collection:
|
||||
assert prop.properties, f"{prop.name} fixed_collection has no sub-properties"
|
||||
return [
|
||||
{
|
||||
sub_prop.name: _sample_property_value(
|
||||
sub_prop, path=f"{path}.{sub_prop.name}"
|
||||
)
|
||||
for sub_prop in prop.properties
|
||||
}
|
||||
]
|
||||
raise AssertionError(f"Unhandled PropertyType in TS bridge test: {prop.type}")
|
||||
|
||||
|
||||
def _sample_node_data(spec: NodeSpec) -> dict[str, Any]:
|
||||
return {
|
||||
prop.name: _sample_property_value(prop, path=f"{spec.name}.{prop.name}")
|
||||
for prop in spec.properties
|
||||
}
|
||||
|
||||
|
||||
def _sample_edge_value(field_name: str, annotation: Any) -> Any:
|
||||
inner = _strip_optional(annotation)
|
||||
if inner is str:
|
||||
return f"{field_name}_value"
|
||||
if inner is bool:
|
||||
return True
|
||||
if inner in (int, float):
|
||||
return 1
|
||||
raise AssertionError(
|
||||
f"Unhandled edge field annotation in TS bridge test: {field_name} -> {annotation!r}"
|
||||
)
|
||||
|
||||
|
||||
def _sample_edge_data() -> dict[str, Any]:
|
||||
return {
|
||||
field_name: _sample_edge_value(field_name, field.annotation)
|
||||
for field_name, field in EdgeDataDTO.model_fields.items()
|
||||
}
|
||||
|
||||
|
||||
# ─── generate_code ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -154,6 +259,19 @@ async def test_generate_strips_unknown_edge_fields():
|
|||
assert "validationMessage" not in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_preserves_all_edge_dto_fields():
|
||||
wf = _minimal_workflow()
|
||||
edge_data = _sample_edge_data()
|
||||
wf["edges"][0]["data"] = edge_data
|
||||
|
||||
code = await generate_code(wf)
|
||||
result = await parse_code(code)
|
||||
|
||||
assert result["ok"] is True, result
|
||||
assert result["workflow"]["edges"][0]["data"] == edge_data
|
||||
|
||||
|
||||
# ─── parse_code ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -229,6 +347,21 @@ wf.edge(a, b, { label: "", condition: "c" });
|
|||
assert result["stage"] == "parse"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_rejects_unknown_edge_field():
|
||||
code = """import { Workflow } from "@dograh/sdk";
|
||||
import { startCall, endCall } from "@dograh/sdk/typed";
|
||||
const wf = new Workflow({ name: "x" });
|
||||
const a = wf.addTyped(startCall({ name: "g", prompt: "hi" }));
|
||||
const b = wf.addTyped(endCall({ name: "d", prompt: "bye" }));
|
||||
wf.edge(a, b, { label: "done", condition: "wrapped", bogus: "x" });
|
||||
"""
|
||||
result = await parse_code(code)
|
||||
assert result["ok"] is False
|
||||
assert result["stage"] == "parse"
|
||||
assert any("Unknown edge field" in e["message"] for e in result["errors"])
|
||||
|
||||
|
||||
# ─── Round-trip ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -257,6 +390,30 @@ async def test_round_trip_minimal():
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda spec: spec.name)
|
||||
async def test_round_trip_preserves_all_node_spec_fields(spec: NodeSpec):
|
||||
data = _sample_node_data(spec)
|
||||
wf = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": spec.name,
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": data,
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
"viewport": {"x": 0, "y": 0, "zoom": 1},
|
||||
}
|
||||
|
||||
code = await generate_code(wf, workflow_name=f"{spec.name}_rt")
|
||||
result = await parse_code(code)
|
||||
|
||||
assert result["ok"] is True, result
|
||||
assert result["workflow"]["nodes"][0]["data"] == data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_fails_on_unknown_type():
|
||||
bad = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue