fix: fix projection to TS when fetching agnet in MCP

This commit is contained in:
Abhishek Kumar 2026-05-23 14:45:50 +05:30
parent 3892b58486
commit bbb4f91a27
12 changed files with 392 additions and 63 deletions

View file

@ -0,0 +1,55 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Literal
from api.db import db_client
from api.mcp_server.ts_bridge import generate_code
@dataclass(frozen=True)
class WorkflowProjectionSource:
payload: dict[str, Any] | None
version: Literal["draft", "published", "legacy"]
version_number: int | None
async def select_workflow_projection_source(workflow: Any) -> WorkflowProjectionSource:
"""Choose the same working copy across read and save MCP tools.
Draft wins over published because that's what a human editor would
be mutating. Legacy `workflow_definition` is the final fallback for
older rows that predate versioned definitions.
"""
draft = await db_client.get_draft_version(workflow.id)
if draft is not None and draft.workflow_json:
return WorkflowProjectionSource(
payload=draft.workflow_json,
version="draft",
version_number=draft.version_number,
)
released = workflow.released_definition
if released is not None and released.workflow_json:
return WorkflowProjectionSource(
payload=released.workflow_json,
version="published",
version_number=released.version_number,
)
return WorkflowProjectionSource(
payload=workflow.workflow_definition or None,
version="legacy",
version_number=None,
)
async def project_workflow_to_sdk_view(workflow: Any) -> dict[str, Any]:
source = await select_workflow_projection_source(workflow)
code = await generate_code(source.payload or {}, workflow_name=workflow.name or "")
return {
"name": workflow.name or "",
"version": source.version,
"version_number": source.version_number,
"code": code,
}

View file

@ -18,8 +18,9 @@ from fastapi import HTTPException
from api.db import db_client
from api.mcp_server.auth import authenticate_mcp_request
from api.mcp_server.tools._workflow_projection import project_workflow_to_sdk_view
from api.mcp_server.tracing import traced_tool
from api.mcp_server.ts_bridge import TsBridgeError, generate_code
from api.mcp_server.ts_bridge import TsBridgeError
@traced_tool
@ -39,31 +40,14 @@ async def get_workflow_code(workflow_id: int) -> dict[str, Any]:
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow {workflow_id} not found")
# Draft wins over published — editing a draft is the normal flow.
# `current_definition` (is_current=True) is the published row, so we
# fetch the draft explicitly. If the latest draft was just published,
# no draft row exists and we fall through to `released_definition`.
draft = await db_client.get_draft_version(workflow_id)
released = workflow.released_definition
if draft is not None and draft.workflow_json:
payload = draft.workflow_json
source = "draft"
elif released is not None and released.workflow_json:
payload = released.workflow_json
source = "published"
else:
payload = workflow.workflow_definition or {}
source = "legacy"
try:
code = await generate_code(payload, workflow_name=workflow.name or "")
view = await project_workflow_to_sdk_view(workflow)
except TsBridgeError as e:
raise HTTPException(status_code=500, detail=f"Failed to generate code: {e}")
return {
"workflow_id": workflow_id,
"name": workflow.name or "",
"version": source,
"code": code,
"name": view["name"],
"version": view["version"],
"code": view["code"],
}

View file

@ -28,6 +28,9 @@ from pydantic import ValidationError as PydanticValidationError
from api.db import db_client
from api.mcp_server.auth import authenticate_mcp_request
from api.mcp_server.tools._workflow_projection import (
select_workflow_projection_source,
)
from api.mcp_server.tracing import traced_tool
from api.mcp_server.ts_bridge import TsBridgeError, parse_code
from api.services.workflow.dto import ReactFlowDTO
@ -37,20 +40,9 @@ from api.services.workflow.workflow_graph import WorkflowGraph
async def _previous_workflow_json(workflow: Any) -> dict[str, Any] | None:
"""Same selection priority as `get_workflow_code` — the version the
LLM saw is the version we reconcile against.
`current_definition` (is_current=True) is the published row, so the
draft must be fetched explicitly. If no draft exists (e.g. the last
draft was just published), fall through to `released_definition`.
"""
draft = await db_client.get_draft_version(workflow.id)
if draft is not None and draft.workflow_json:
return draft.workflow_json
released = workflow.released_definition
if released is not None and released.workflow_json:
return released.workflow_json
return workflow.workflow_definition or None
"""Match the agent-facing read tools' source selection."""
source = await select_workflow_projection_source(workflow)
return source.payload
def _error_result(code: str, message: str, **extra: Any) -> dict[str, Any]:

View file

@ -2,7 +2,9 @@ from fastapi import HTTPException
from api.db import db_client
from api.mcp_server.auth import authenticate_mcp_request
from api.mcp_server.tools._workflow_projection import project_workflow_to_sdk_view
from api.mcp_server.tracing import traced_tool
from api.mcp_server.ts_bridge import TsBridgeError
@traced_tool
@ -10,9 +12,9 @@ async def list_workflows(status: str | None = "active") -> list[dict]:
"""List agents (workflows) in the caller's organization.
Returns id, name, status, and created_at for each agent. Use
`get_workflow` to fetch a single agent's full definition. Defaults
to active agents; pass `status="archived"` to list archived agents,
or `status=None` to list all.
`get_workflow` to fetch a single agent's current SDK view and
metadata. Defaults to active agents; pass `status="archived"` to
list archived agents, or `status=None` to list all.
"""
user = await authenticate_mcp_request()
workflows = await db_client.get_all_workflows_for_listing(
@ -32,7 +34,11 @@ async def list_workflows(status: str | None = "active") -> list[dict]:
@traced_tool
async def get_workflow(workflow_id: int) -> dict:
"""Fetch a single agent by id, including its current published definition."""
"""Fetch a single agent by id, projected into the SDK code view.
Output shape:
{"id": int, "name": str, "status": str, "version": "draft" | "published" | "legacy", "version_number": int | None, "code": "<TS source>"}
"""
user = await authenticate_mcp_request()
workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
@ -40,11 +46,16 @@ async def get_workflow(workflow_id: int) -> dict:
if not workflow:
raise HTTPException(status_code=404, detail=f"Workflow {workflow_id} not found")
current = workflow.current_definition
try:
view = await project_workflow_to_sdk_view(workflow)
except TsBridgeError as e:
raise HTTPException(status_code=500, detail=f"Failed to generate code: {e}")
return {
"id": workflow.id,
"name": workflow.name,
"name": view["name"],
"status": workflow.status,
"definition": current.workflow_json if current else None,
"version_number": current.version_number if current else None,
"version": view["version"],
"version_number": view["version_number"],
"code": view["code"],
}

View file

@ -18,6 +18,7 @@ import json
from pathlib import Path
from typing import Any
from api.services.workflow.dto import EdgeDataDTO
from api.services.workflow.node_specs import all_specs
_VALIDATOR_ENTRY = Path(__file__).resolve().parent / "ts_validator" / "src" / "index.ts"
@ -31,6 +32,10 @@ def _specs_payload() -> list[dict[str, Any]]:
return [s.model_dump(mode="json") for s in all_specs()]
def _edge_field_names() -> list[str]:
return list(EdgeDataDTO.model_fields.keys())
async def _invoke(request: dict[str, Any]) -> dict[str, Any]:
proc = await asyncio.create_subprocess_exec(
"node",
@ -65,6 +70,7 @@ async def generate_code(workflow: dict[str, Any], *, workflow_name: str = "") ->
"command": "generate",
"workflow": workflow,
"specs": _specs_payload(),
"edgeFieldNames": _edge_field_names(),
"workflowName": workflow_name,
}
)
@ -89,5 +95,6 @@ async def parse_code(code: str) -> dict[str, Any]:
"command": "parse",
"code": code,
"specs": _specs_payload(),
"edgeFieldNames": _edge_field_names(),
}
)

View file

@ -14,9 +14,18 @@ import type {
export function generateCode(
workflow: WireWorkflow,
specs: NodeSpec[],
opts: { workflowName?: string } = {},
opts: { workflowName?: string; edgeFieldNames?: string[] } = {},
): GenerateResult {
const specByName = new Map(specs.map((s) => [s.name, s]));
const edgeFieldNames = new Set(
opts.edgeFieldNames ?? [
"label",
"condition",
"transition_speech",
"transition_speech_type",
"transition_speech_recording_id",
],
);
// Catch unknown node types up-front — otherwise we'd emit an import
// line for a factory that doesn't exist.
@ -97,7 +106,7 @@ export function generateCode(
],
};
}
const cleanedEdge = pickEdgeFields(edge.data);
const cleanedEdge = pickEdgeFields(edge.data, edgeFieldNames);
const edgeOpts = renderObject(cleanedEdge, 0);
lines.push(`wf.edge(${src}, ${tgt}, ${edgeOpts});`);
}
@ -210,22 +219,13 @@ function stripUnknown(
return out;
}
// Edge schema is fixed (no NodeSpec for edges). Mirrors the allowed
// fields on `Workflow.edge(...)` in both SDKs.
const KNOWN_EDGE_FIELDS = new Set([
"label",
"condition",
"transition_speech",
"transition_speech_type",
"transition_speech_recording_id",
]);
function pickEdgeFields(
data: Record<string, unknown>,
knownEdgeFields: Set<string>,
): Record<string, unknown> {
const out: Record<string, unknown> = {};
for (const [k, v] of Object.entries(data)) {
if (KNOWN_EDGE_FIELDS.has(k)) out[k] = v;
if (knownEdgeFields.has(k)) out[k] = v;
}
return out;
}

View file

@ -11,6 +11,7 @@ interface GenerateRequest {
command: "generate";
workflow: WireWorkflow;
specs: NodeSpec[];
edgeFieldNames: string[];
workflowName?: string;
}
@ -18,6 +19,7 @@ interface ParseRequest {
command: "parse";
code: string;
specs: NodeSpec[];
edgeFieldNames: string[];
}
type Request = GenerateRequest | ParseRequest;
@ -49,11 +51,16 @@ async function main(): Promise<void> {
}
if (req.command === "generate") {
writeResult(generateCode(req.workflow, req.specs, { workflowName: req.workflowName }));
writeResult(
generateCode(req.workflow, req.specs, {
workflowName: req.workflowName,
edgeFieldNames: req.edgeFieldNames,
}),
);
return;
}
if (req.command === "parse") {
writeResult(parseCode(req.code, req.specs));
writeResult(parseCode(req.code, req.specs, req.edgeFieldNames));
return;
}
writeResult({

View file

@ -25,8 +25,19 @@ import type {
WireNode,
} from "./types.ts";
export function parseCode(code: string, specs: NodeSpec[]): ParseResult {
export function parseCode(
code: string,
specs: NodeSpec[],
edgeFieldNames: string[] = [
"label",
"condition",
"transition_speech",
"transition_speech_type",
"transition_speech_recording_id",
],
): ParseResult {
const specByName = new Map(specs.map((s) => [s.name, s]));
const allowedEdgeFieldNames = new Set(edgeFieldNames);
const sourceFile = ts.createSourceFile(
"workflow.ts",
code,
@ -335,6 +346,12 @@ export function parseCode(code: string, specs: NodeSpec[]): ParseResult {
addError(stmt, "`edge` requires a non-empty `condition` string.");
return;
}
for (const key of Object.keys(optsObj)) {
if (!allowedEdgeFieldNames.has(key)) {
addError(stmt, `Unknown edge field: \`${key}\`.`);
return;
}
}
edges.push({
id: `${src.id}-${tgt.id}`,
source: src.id,

View file

@ -0,0 +1,99 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.mcp_server.tools.workflows import get_workflow
@pytest.fixture
def authed_user() -> MagicMock:
user = MagicMock()
user.selected_organization_id = 1
return user
def _workflow() -> SimpleNamespace:
return SimpleNamespace(
id=7,
name="Support Agent",
status="active",
released_definition=SimpleNamespace(
workflow_json={"nodes": [{"id": "published"}], "edges": []},
version_number=3,
),
workflow_definition={"nodes": [{"id": "legacy"}], "edges": []},
)
@pytest.mark.asyncio
async def test_get_workflow_returns_draft_sdk_view(authed_user: MagicMock):
workflow = _workflow()
draft = SimpleNamespace(
workflow_json={"nodes": [{"id": "draft"}], "edges": []},
version_number=4,
)
with (
patch(
"api.mcp_server.tools.workflows.authenticate_mcp_request",
AsyncMock(return_value=authed_user),
),
patch(
"api.mcp_server.tools.workflows.db_client.get_workflow",
AsyncMock(return_value=workflow),
),
patch(
"api.mcp_server.tools._workflow_projection.db_client.get_draft_version",
AsyncMock(return_value=draft),
),
patch(
"api.mcp_server.tools._workflow_projection.generate_code",
AsyncMock(
return_value='const wf = new Workflow({ name: "Support Agent" });'
),
) as generate_code_mock,
):
result = await get_workflow(workflow_id=workflow.id)
assert result == {
"id": 7,
"name": "Support Agent",
"status": "active",
"version": "draft",
"version_number": 4,
"code": 'const wf = new Workflow({ name: "Support Agent" });',
}
generate_code_mock.assert_awaited_once_with(
draft.workflow_json, workflow_name="Support Agent"
)
@pytest.mark.asyncio
async def test_get_workflow_falls_back_to_published_sdk_view(authed_user: MagicMock):
workflow = _workflow()
with (
patch(
"api.mcp_server.tools.workflows.authenticate_mcp_request",
AsyncMock(return_value=authed_user),
),
patch(
"api.mcp_server.tools.workflows.db_client.get_workflow",
AsyncMock(return_value=workflow),
),
patch(
"api.mcp_server.tools._workflow_projection.db_client.get_draft_version",
AsyncMock(return_value=None),
),
patch(
"api.mcp_server.tools._workflow_projection.generate_code",
AsyncMock(
return_value='const wf = new Workflow({ name: "Support Agent" });'
),
),
):
result = await get_workflow(workflow_id=workflow.id)
assert result["version"] == "published"
assert result["version_number"] == 3

View file

@ -7,10 +7,19 @@ that code → JSON and JSON → code round-trip losslessly.
from __future__ import annotations
import shutil
from types import NoneType
from typing import Any, get_args
import pytest
from api.mcp_server.ts_bridge import TsBridgeError, generate_code, parse_code
from api.services.workflow.dto import EdgeDataDTO
from api.services.workflow.node_specs import (
NodeSpec,
PropertySpec,
PropertyType,
all_specs,
)
pytestmark = pytest.mark.skipif(
shutil.which("node") is None, reason="node binary not available"
@ -81,6 +90,102 @@ def _normalize(wf: dict) -> dict:
}
def _strip_optional(annotation: Any) -> Any:
args = tuple(arg for arg in get_args(annotation) if arg is not NoneType)
if len(args) == 1:
return args[0]
return annotation
def _pick_option_value(prop: PropertySpec) -> Any:
assert prop.options, f"{prop.name} has no options"
default = prop.default
for option in prop.options:
if option.value != default:
return option.value
return prop.options[0].value
def _sample_number(prop: PropertySpec) -> int | float:
candidates: list[int | float] = [1, 2, 3, 0.5, 4.5, 10]
for candidate in candidates:
if prop.min_value is not None and candidate < prop.min_value:
continue
if prop.max_value is not None and candidate > prop.max_value:
continue
if prop.default is not None and candidate == prop.default:
continue
return candidate
raise AssertionError(f"No valid sample number found for {prop.name}")
def _sample_property_value(prop: PropertySpec, *, path: str) -> Any:
slug = path.replace(".", "_")
if prop.type == PropertyType.string:
return f"{slug}_value"
if prop.type == PropertyType.mention_textarea:
return f"{slug} prompt with {{name}}"
if prop.type == PropertyType.url:
return f"https://example.com/{slug}"
if prop.type == PropertyType.recording_ref:
return f"recording_{slug}"
if prop.type == PropertyType.credential_ref:
return f"credential_{slug}"
if prop.type == PropertyType.number:
return _sample_number(prop)
if prop.type == PropertyType.boolean:
return not prop.default if isinstance(prop.default, bool) else True
if prop.type == PropertyType.options:
return _pick_option_value(prop)
if prop.type == PropertyType.multi_options:
return [_pick_option_value(prop)]
if prop.type == PropertyType.tool_refs:
return [f"tool_{slug}"]
if prop.type == PropertyType.document_refs:
return [f"document_{slug}"]
if prop.type == PropertyType.json:
return {"kind": slug, "enabled": True}
if prop.type == PropertyType.fixed_collection:
assert prop.properties, f"{prop.name} fixed_collection has no sub-properties"
return [
{
sub_prop.name: _sample_property_value(
sub_prop, path=f"{path}.{sub_prop.name}"
)
for sub_prop in prop.properties
}
]
raise AssertionError(f"Unhandled PropertyType in TS bridge test: {prop.type}")
def _sample_node_data(spec: NodeSpec) -> dict[str, Any]:
return {
prop.name: _sample_property_value(prop, path=f"{spec.name}.{prop.name}")
for prop in spec.properties
}
def _sample_edge_value(field_name: str, annotation: Any) -> Any:
inner = _strip_optional(annotation)
if inner is str:
return f"{field_name}_value"
if inner is bool:
return True
if inner in (int, float):
return 1
raise AssertionError(
f"Unhandled edge field annotation in TS bridge test: {field_name} -> {annotation!r}"
)
def _sample_edge_data() -> dict[str, Any]:
return {
field_name: _sample_edge_value(field_name, field.annotation)
for field_name, field in EdgeDataDTO.model_fields.items()
}
# ─── generate_code ───────────────────────────────────────────────────────
@ -154,6 +259,19 @@ async def test_generate_strips_unknown_edge_fields():
assert "validationMessage" not in code
@pytest.mark.asyncio
async def test_generate_preserves_all_edge_dto_fields():
wf = _minimal_workflow()
edge_data = _sample_edge_data()
wf["edges"][0]["data"] = edge_data
code = await generate_code(wf)
result = await parse_code(code)
assert result["ok"] is True, result
assert result["workflow"]["edges"][0]["data"] == edge_data
# ─── parse_code ──────────────────────────────────────────────────────────
@ -229,6 +347,21 @@ wf.edge(a, b, { label: "", condition: "c" });
assert result["stage"] == "parse"
@pytest.mark.asyncio
async def test_parse_rejects_unknown_edge_field():
code = """import { Workflow } from "@dograh/sdk";
import { startCall, endCall } from "@dograh/sdk/typed";
const wf = new Workflow({ name: "x" });
const a = wf.addTyped(startCall({ name: "g", prompt: "hi" }));
const b = wf.addTyped(endCall({ name: "d", prompt: "bye" }));
wf.edge(a, b, { label: "done", condition: "wrapped", bogus: "x" });
"""
result = await parse_code(code)
assert result["ok"] is False
assert result["stage"] == "parse"
assert any("Unknown edge field" in e["message"] for e in result["errors"])
# ─── Round-trip ──────────────────────────────────────────────────────────
@ -257,6 +390,30 @@ async def test_round_trip_minimal():
]
@pytest.mark.asyncio
@pytest.mark.parametrize("spec", all_specs(), ids=lambda spec: spec.name)
async def test_round_trip_preserves_all_node_spec_fields(spec: NodeSpec):
data = _sample_node_data(spec)
wf = {
"nodes": [
{
"id": "1",
"type": spec.name,
"position": {"x": 0, "y": 0},
"data": data,
}
],
"edges": [],
"viewport": {"x": 0, "y": 0, "zoom": 1},
}
code = await generate_code(wf, workflow_name=f"{spec.name}_rt")
result = await parse_code(code)
assert result["ok"] is True, result
assert result["workflow"]["nodes"][0]["data"] == data
@pytest.mark.asyncio
async def test_generate_fails_on_unknown_type():
bad = {

File diff suppressed because one or more lines are too long

View file

@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: dograh-openapi-XXXXXX.json.ahnZ2z2E21
# timestamp: 2026-05-23T03:32:29+00:00
# filename: dograh-openapi-XXXXXX.json.N8gRI5v3bD
# timestamp: 2026-05-23T09:14:22+00:00
from __future__ import annotations