fix: fix projection to TS when fetching agnet in MCP

This commit is contained in:
Abhishek Kumar 2026-05-23 14:45:50 +05:30
parent 3892b58486
commit bbb4f91a27
12 changed files with 392 additions and 63 deletions

View file

@ -0,0 +1,99 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.mcp_server.tools.workflows import get_workflow
@pytest.fixture
def authed_user() -> MagicMock:
user = MagicMock()
user.selected_organization_id = 1
return user
def _workflow() -> SimpleNamespace:
return SimpleNamespace(
id=7,
name="Support Agent",
status="active",
released_definition=SimpleNamespace(
workflow_json={"nodes": [{"id": "published"}], "edges": []},
version_number=3,
),
workflow_definition={"nodes": [{"id": "legacy"}], "edges": []},
)
@pytest.mark.asyncio
async def test_get_workflow_returns_draft_sdk_view(authed_user: MagicMock):
workflow = _workflow()
draft = SimpleNamespace(
workflow_json={"nodes": [{"id": "draft"}], "edges": []},
version_number=4,
)
with (
patch(
"api.mcp_server.tools.workflows.authenticate_mcp_request",
AsyncMock(return_value=authed_user),
),
patch(
"api.mcp_server.tools.workflows.db_client.get_workflow",
AsyncMock(return_value=workflow),
),
patch(
"api.mcp_server.tools._workflow_projection.db_client.get_draft_version",
AsyncMock(return_value=draft),
),
patch(
"api.mcp_server.tools._workflow_projection.generate_code",
AsyncMock(
return_value='const wf = new Workflow({ name: "Support Agent" });'
),
) as generate_code_mock,
):
result = await get_workflow(workflow_id=workflow.id)
assert result == {
"id": 7,
"name": "Support Agent",
"status": "active",
"version": "draft",
"version_number": 4,
"code": 'const wf = new Workflow({ name: "Support Agent" });',
}
generate_code_mock.assert_awaited_once_with(
draft.workflow_json, workflow_name="Support Agent"
)
@pytest.mark.asyncio
async def test_get_workflow_falls_back_to_published_sdk_view(authed_user: MagicMock):
workflow = _workflow()
with (
patch(
"api.mcp_server.tools.workflows.authenticate_mcp_request",
AsyncMock(return_value=authed_user),
),
patch(
"api.mcp_server.tools.workflows.db_client.get_workflow",
AsyncMock(return_value=workflow),
),
patch(
"api.mcp_server.tools._workflow_projection.db_client.get_draft_version",
AsyncMock(return_value=None),
),
patch(
"api.mcp_server.tools._workflow_projection.generate_code",
AsyncMock(
return_value='const wf = new Workflow({ name: "Support Agent" });'
),
),
):
result = await get_workflow(workflow_id=workflow.id)
assert result["version"] == "published"
assert result["version_number"] == 3

View file

@ -7,10 +7,19 @@ that code → JSON and JSON → code round-trip losslessly.
from __future__ import annotations
import shutil
from types import NoneType
from typing import Any, get_args
import pytest
from api.mcp_server.ts_bridge import TsBridgeError, generate_code, parse_code
from api.services.workflow.dto import EdgeDataDTO
from api.services.workflow.node_specs import (
NodeSpec,
PropertySpec,
PropertyType,
all_specs,
)
pytestmark = pytest.mark.skipif(
shutil.which("node") is None, reason="node binary not available"
@ -81,6 +90,102 @@ def _normalize(wf: dict) -> dict:
}
def _strip_optional(annotation: Any) -> Any:
args = tuple(arg for arg in get_args(annotation) if arg is not NoneType)
if len(args) == 1:
return args[0]
return annotation
def _pick_option_value(prop: PropertySpec) -> Any:
assert prop.options, f"{prop.name} has no options"
default = prop.default
for option in prop.options:
if option.value != default:
return option.value
return prop.options[0].value
def _sample_number(prop: PropertySpec) -> int | float:
candidates: list[int | float] = [1, 2, 3, 0.5, 4.5, 10]
for candidate in candidates:
if prop.min_value is not None and candidate < prop.min_value:
continue
if prop.max_value is not None and candidate > prop.max_value:
continue
if prop.default is not None and candidate == prop.default:
continue
return candidate
raise AssertionError(f"No valid sample number found for {prop.name}")
def _sample_property_value(prop: PropertySpec, *, path: str) -> Any:
slug = path.replace(".", "_")
if prop.type == PropertyType.string:
return f"{slug}_value"
if prop.type == PropertyType.mention_textarea:
return f"{slug} prompt with {{name}}"
if prop.type == PropertyType.url:
return f"https://example.com/{slug}"
if prop.type == PropertyType.recording_ref:
return f"recording_{slug}"
if prop.type == PropertyType.credential_ref:
return f"credential_{slug}"
if prop.type == PropertyType.number:
return _sample_number(prop)
if prop.type == PropertyType.boolean:
return not prop.default if isinstance(prop.default, bool) else True
if prop.type == PropertyType.options:
return _pick_option_value(prop)
if prop.type == PropertyType.multi_options:
return [_pick_option_value(prop)]
if prop.type == PropertyType.tool_refs:
return [f"tool_{slug}"]
if prop.type == PropertyType.document_refs:
return [f"document_{slug}"]
if prop.type == PropertyType.json:
return {"kind": slug, "enabled": True}
if prop.type == PropertyType.fixed_collection:
assert prop.properties, f"{prop.name} fixed_collection has no sub-properties"
return [
{
sub_prop.name: _sample_property_value(
sub_prop, path=f"{path}.{sub_prop.name}"
)
for sub_prop in prop.properties
}
]
raise AssertionError(f"Unhandled PropertyType in TS bridge test: {prop.type}")
def _sample_node_data(spec: NodeSpec) -> dict[str, Any]:
return {
prop.name: _sample_property_value(prop, path=f"{spec.name}.{prop.name}")
for prop in spec.properties
}
def _sample_edge_value(field_name: str, annotation: Any) -> Any:
inner = _strip_optional(annotation)
if inner is str:
return f"{field_name}_value"
if inner is bool:
return True
if inner in (int, float):
return 1
raise AssertionError(
f"Unhandled edge field annotation in TS bridge test: {field_name} -> {annotation!r}"
)
def _sample_edge_data() -> dict[str, Any]:
return {
field_name: _sample_edge_value(field_name, field.annotation)
for field_name, field in EdgeDataDTO.model_fields.items()
}
# ─── generate_code ───────────────────────────────────────────────────────
@ -154,6 +259,19 @@ async def test_generate_strips_unknown_edge_fields():
assert "validationMessage" not in code
@pytest.mark.asyncio
async def test_generate_preserves_all_edge_dto_fields():
wf = _minimal_workflow()
edge_data = _sample_edge_data()
wf["edges"][0]["data"] = edge_data
code = await generate_code(wf)
result = await parse_code(code)
assert result["ok"] is True, result
assert result["workflow"]["edges"][0]["data"] == edge_data
# ─── parse_code ──────────────────────────────────────────────────────────
@ -229,6 +347,21 @@ wf.edge(a, b, { label: "", condition: "c" });
assert result["stage"] == "parse"
@pytest.mark.asyncio
async def test_parse_rejects_unknown_edge_field():
code = """import { Workflow } from "@dograh/sdk";
import { startCall, endCall } from "@dograh/sdk/typed";
const wf = new Workflow({ name: "x" });
const a = wf.addTyped(startCall({ name: "g", prompt: "hi" }));
const b = wf.addTyped(endCall({ name: "d", prompt: "bye" }));
wf.edge(a, b, { label: "done", condition: "wrapped", bogus: "x" });
"""
result = await parse_code(code)
assert result["ok"] is False
assert result["stage"] == "parse"
assert any("Unknown edge field" in e["message"] for e in result["errors"])
# ─── Round-trip ──────────────────────────────────────────────────────────
@ -257,6 +390,30 @@ async def test_round_trip_minimal():
]
@pytest.mark.asyncio
@pytest.mark.parametrize("spec", all_specs(), ids=lambda spec: spec.name)
async def test_round_trip_preserves_all_node_spec_fields(spec: NodeSpec):
data = _sample_node_data(spec)
wf = {
"nodes": [
{
"id": "1",
"type": spec.name,
"position": {"x": 0, "y": 0},
"data": data,
}
],
"edges": [],
"viewport": {"x": 0, "y": 0, "zoom": 1},
}
code = await generate_code(wf, workflow_name=f"{spec.name}_rt")
result = await parse_code(code)
assert result["ok"] is True, result
assert result["workflow"]["nodes"][0]["data"] == data
@pytest.mark.asyncio
async def test_generate_fails_on_unknown_type():
bad = {