mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: refactor node spec and add mcp tools (#244)
* refactor: carve out extraction panel * refactor: create spec versions for node types * refactor: create a GenericNode and remove custom nodes * feat: add python and typescript sdk * add dograh sdk * fix: fetch draft workflow definition over published one * fix: fix routes of SDKs to use code gen * chore: remove doclink dependency to reduce image size * chore: format files * chore: bump pipecat * feat: let mcp fetch archived workflows on demand * chore: fix tests * feat: add sdk documentation * chore: change banner and add badge
This commit is contained in:
parent
0a61ef295f
commit
00a1a22b74
162 changed files with 14355 additions and 3554 deletions
|
|
@ -13,11 +13,6 @@ RUN apt-get update && apt-get install -y \
|
|||
# Copy and install requirements
|
||||
COPY api/requirements.txt .
|
||||
|
||||
# Install CPU-only PyTorch FIRST to prevent CUDA/NVIDIA dependencies
|
||||
# This satisfies torch dependency before other packages try to pull GPU version
|
||||
RUN pip install --user --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu && \
|
||||
rm -rf /root/.cache/pip
|
||||
|
||||
# Install dependencies to user directory for easy copying
|
||||
RUN pip install --user --no-cache-dir -r requirements.txt && \
|
||||
# Clean up pip cache after installation
|
||||
|
|
@ -25,27 +20,54 @@ RUN pip install --user --no-cache-dir -r requirements.txt && \
|
|||
|
||||
# Copy and install pipecat from local submodule
|
||||
COPY pipecat /tmp/pipecat
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,local-smart-turn-v3,speechmatics,openrouter,camb]' && \
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]' && \
|
||||
# Swap opencv-python (pulled by pipecat[webrtc]) for opencv-python-headless
|
||||
# to drop X11/Qt dependencies that otherwise require libxcb etc. in runner.
|
||||
pip uninstall -y opencv-python && \
|
||||
pip install --user --no-cache-dir opencv-python-headless && \
|
||||
# Pre-download NLTK punkt_tab tokenizer data (required by pipecat at runtime)
|
||||
python -c "import nltk; nltk.download('punkt_tab', quiet=True)" && \
|
||||
# Clean up pip cache and temporary pipecat directory
|
||||
rm -rf /root/.cache/pip /tmp/pipecat
|
||||
|
||||
# Remove unnecessary Python cache files from installed packages
|
||||
# Strip cache files, test/example dirs, and type stubs from installed packages
|
||||
RUN find /root/.local -type f -name '*.pyc' -delete && \
|
||||
find /root/.local -type d -name '__pycache__' -delete && \
|
||||
find /root/.local -type f -name '*.pyo' -delete
|
||||
find /root/.local -type d -name '__pycache__' -prune -exec rm -rf {} + && \
|
||||
find /root/.local -type f -name '*.pyo' -delete && \
|
||||
find /root/.local -type d \( -name tests -o -name test -o -name examples \) -prune -exec rm -rf {} + && \
|
||||
find /root/.local -name '*.pyi' -delete
|
||||
|
||||
# Stage 2: Runtime - Minimal image with only runtime dependencies
|
||||
# Stage 2: Node deps for ts_validator (built with full node:22-slim, only
|
||||
# node_modules is copied into the runner).
|
||||
FROM node:22-slim AS ts-deps
|
||||
WORKDIR /ts_validator
|
||||
COPY api/mcp_server/ts_validator/package*.json ./
|
||||
RUN npm ci --omit=dev && npm cache clean --force
|
||||
|
||||
# Stage 3: Static ffmpeg binary (avoids apt ffmpeg pulling mesa/libllvm for
|
||||
# hardware acceleration we don't use server-side).
|
||||
FROM debian:trixie-slim AS ffmpeg-static
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl ca-certificates xz-utils \
|
||||
&& curl -fsSL -o /tmp/ffmpeg.tar.xz https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz \
|
||||
&& mkdir -p /tmp/ffmpeg \
|
||||
&& tar -xJf /tmp/ffmpeg.tar.xz -C /tmp/ffmpeg --strip-components=1 \
|
||||
&& mv /tmp/ffmpeg/ffmpeg /tmp/ffmpeg/ffprobe /usr/local/bin/ \
|
||||
&& chmod +x /usr/local/bin/ffmpeg /usr/local/bin/ffprobe
|
||||
|
||||
# Stage 4: Runtime - Minimal image with only runtime dependencies
|
||||
FROM python:3.12-slim AS runner
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Only install ffmpeg (runtime dependency)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
# Static ffmpeg + ffprobe (used by audio_converter, audio_file_cache, etc.)
|
||||
COPY --from=ffmpeg-static /usr/local/bin/ffmpeg /usr/local/bin/ffmpeg
|
||||
COPY --from=ffmpeg-static /usr/local/bin/ffprobe /usr/local/bin/ffprobe
|
||||
|
||||
# Node.js 22 binary only (ts_validator subprocess needs node >=22.6 for
|
||||
# native TypeScript stripping; see api/mcp_server/ts_bridge.py). python:3.12-slim
|
||||
# already provides libstdc++6, libgcc-s1, and ca-certificates that node needs.
|
||||
COPY --from=node:22-slim /usr/local/bin/node /usr/local/bin/node
|
||||
|
||||
# Copy Python packages from builder stage
|
||||
COPY --from=builder /root/.local /root/.local
|
||||
|
|
@ -65,6 +87,10 @@ ENV PYTHONUNBUFFERED=1
|
|||
COPY ./api ./api
|
||||
COPY ./scripts/start_services_dev.sh ./scripts/start_services_dev.sh
|
||||
|
||||
# ts_validator Node deps (built in ts-deps stage with full node:22-slim image).
|
||||
# The validator runs as a short-lived subprocess from api/mcp_server/ts_bridge.py.
|
||||
COPY --from=ts-deps /ts_validator/node_modules ./api/mcp_server/ts_validator/node_modules
|
||||
|
||||
# Product documentation — read at runtime by the MCP docs tools
|
||||
# (search_dograh_docs / fetch_dograh_doc) so agents can learn Dograh.
|
||||
COPY ./docs ./docs
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from loguru import logger
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
from api.mcp import mcp
|
||||
from api.mcp_server import mcp
|
||||
from api.routes.main import router as main_router
|
||||
from api.services.pipecat.tracing_config import (
|
||||
handle_langfuse_sync,
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
from api.mcp.server import mcp
|
||||
|
||||
__all__ = ["mcp"]
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
from fastapi import HTTPException
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import _handle_api_key_auth
|
||||
|
||||
|
||||
async def authenticate_mcp_request() -> UserModel:
    """Identify the Dograh user behind the current MCP tool call.

    The API key is taken from the `X-API-Key` header when present and
    non-empty; otherwise from an `Authorization: Bearer <key>` header.
    The key is then handed to `_handle_api_key_auth`, whose result is
    returned unchanged.

    Raises:
        HTTPException: 401 when neither header yields a non-empty key.
    """
    request_headers = get_http_headers()
    key = request_headers.get("x-api-key")

    if not key:
        # Fall back to the bearer scheme; the scheme name is matched
        # case-insensitively and the remainder is trimmed.
        authorization = request_headers.get("authorization", "")
        uses_bearer = authorization.lower().startswith("bearer ")
        if uses_bearer:
            key = authorization.split(" ", 1)[1].strip()

    if key:
        return await _handle_api_key_auth(key)

    raise HTTPException(
        status_code=401,
        detail="Missing API key — send X-API-Key or Authorization: Bearer <key>",
    )
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
from fastmcp import FastMCP
|
||||
|
||||
mcp = FastMCP("dograh")
|
||||
|
||||
from api.mcp.tools import docs as _docs # noqa: E402, F401
|
||||
from api.mcp.tools import workflows as _workflows # noqa: E402, F401
|
||||
|
|
@ -1,115 +0,0 @@
|
|||
import re
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import HTTPException
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from api.mcp.server import mcp
|
||||
|
||||
DOCS_ROOT = Path(__file__).resolve().parents[3] / "docs"
|
||||
|
||||
_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
_FRONTMATTER_RE = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
|
||||
_TITLE_RE = re.compile(r"^title:\s*['\"]?(.+?)['\"]?\s*$", re.MULTILINE)
|
||||
_H1_RE = re.compile(r"^#\s+(.+?)\s*$", re.MULTILINE)
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
    """Split *text* into the lowercase tokens matched by `_TOKEN_RE`."""
    return list(map(str.lower, _TOKEN_RE.findall(text)))
|
||||
|
||||
|
||||
def _extract_title(path: Path, body: str) -> str:
    """Pick a display title for a docs page.

    Preference order: the `title:` entry of a frontmatter block at the
    very start of *body*, then the first Markdown H1 found in *body*,
    then the file stem with hyphens turned into spaces and title-cased.
    """
    frontmatter = _FRONTMATTER_RE.match(body)
    declared = _TITLE_RE.search(frontmatter.group(1)) if frontmatter else None
    if declared is not None:
        return declared.group(1).strip()

    heading = _H1_RE.search(body)
    if heading is not None:
        return heading.group(1).strip()

    # Last resort: derive something readable from the filename.
    return path.stem.replace("-", " ").title()
|
||||
|
||||
|
||||
def _strip_frontmatter(body: str) -> str:
    """Return *body* without its leading frontmatter block, if it has one."""
    # `_FRONTMATTER_RE` is anchored with `^` (no MULTILINE), so it can only
    # match at offset 0; slicing past the match equals a single substitution.
    found = _FRONTMATTER_RE.match(body)
    return body[found.end():] if found else body
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _load_index() -> tuple[list[dict], BM25Okapi]:
    """Build the in-memory docs index from every `.mdx` file under DOCS_ROOT.

    Each file contributes one entry `{path, title, content}` where `path`
    is POSIX-style and relative to DOCS_ROOT and `content` has its
    frontmatter removed. The BM25 corpus row for an entry is the tokenized
    `"<title> <content>"` string, in the same order as the entries.

    The result is memoized by `lru_cache`, so the files are read only on
    the first call in a process.
    """
    entries: list[dict] = []
    tokenized: list[list[str]] = []

    # Sorted so entry order (and therefore score/entry pairing) is stable.
    for mdx in sorted(DOCS_ROOT.rglob("*.mdx")):
        raw = mdx.read_text(encoding="utf-8")
        heading = _extract_title(mdx, raw)
        text = _strip_frontmatter(raw)
        entries.append(
            {
                "path": mdx.relative_to(DOCS_ROOT).as_posix(),
                "title": heading,
                "content": text,
            }
        )
        tokenized.append(_tokenize(f"{heading} {text}"))

    # NOTE(review): if DOCS_ROOT holds no .mdx files the corpus is empty;
    # confirm how BM25Okapi behaves for an empty corpus.
    return entries, BM25Okapi(tokenized)
|
||||
|
||||
|
||||
def _snippet(content: str, query_tokens: list[str], width: int = 240) -> str:
|
||||
lowered = content.lower()
|
||||
for tok in query_tokens:
|
||||
idx = lowered.find(tok)
|
||||
if idx >= 0:
|
||||
start = max(0, idx - width // 2)
|
||||
end = min(len(content), start + width)
|
||||
return (
|
||||
("…" if start > 0 else "")
|
||||
+ content[start:end].strip()
|
||||
+ ("…" if end < len(content) else "")
|
||||
)
|
||||
return content[:width].strip() + ("…" if len(content) > width else "")
|
||||
|
||||
|
||||
@mcp.tool
async def search_dograh_docs(query: str, limit: int = 5) -> list[dict]:
    """Search Dograh's product documentation.

    Returns the top matches as {path, title, snippet, score}. Pass the
    returned `path` to `fetch_dograh_doc` to read the full page. Use this
    first when you need to learn how a Dograh feature works before building
    against it.
    """
    docs, bm25 = _load_index()
    tokens = _tokenize(query)
    # A non-positive limit means "no results". Without this guard a negative
    # limit would slice from the end of the ranking (dropping only the worst
    # entries) instead of returning nothing.
    if not tokens or limit <= 0:
        return []

    scores = bm25.get_scores(tokens)
    # Scores are positionally aligned with `docs`; rank best-first and cut.
    ranked = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)[:limit]

    return [
        {
            "path": doc["path"],
            "title": doc["title"],
            "snippet": _snippet(doc["content"], tokens),
            "score": round(float(score), 3),
        }
        for score, doc in ranked
        # Entries the query did not score against are not matches.
        if score > 0
    ]
|
||||
|
||||
|
||||
@mcp.tool
async def fetch_dograh_doc(path: str) -> dict:
    """Fetch the full content of a Dograh docs page by its path
    (e.g. `core-concepts/workflows.mdx`), as returned by `search_dograh_docs`.
    """
    # Exact-match lookup against the cached index; the first entry whose
    # `path` equals the argument wins, otherwise the caller gets a 404.
    entries, _bm25 = _load_index()
    found = next((entry for entry in entries if entry["path"] == path), None)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Doc not found: {path}")
    return {key: found[key] for key in ("path", "title", "content")}
|
||||
3
api/mcp_server/__init__.py
Normal file
3
api/mcp_server/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from api.mcp_server.server import mcp
|
||||
|
||||
__all__ = ["mcp"]
|
||||
46
api/mcp_server/auth.py
Normal file
46
api/mcp_server/auth.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from fastapi import HTTPException
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
from opentelemetry import trace
|
||||
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import _handle_api_key_auth
|
||||
|
||||
|
||||
async def authenticate_mcp_request() -> UserModel:
    """Resolve the authenticated Dograh user for an MCP tool invocation.

    Accepts either `X-API-Key: <key>` or `Authorization: Bearer <key>`,
    reusing the API-key flow from `api.services.auth.depends`.

    When the currently-active OTel span is recording, it is tagged with
    `mcp.org_id` (if the user has a selected organization), `mcp.user_id`
    and `langfuse.user.id` so the span can be filtered by organization
    and user. `dograh.org_id` is deliberately NOT set here; the inline
    comment in the body explains why.

    Returns:
        The user resolved by `_handle_api_key_auth`.

    Raises:
        HTTPException: 401 when neither header supplies a non-empty key.
    """
    headers = get_http_headers()
    api_key = headers.get("x-api-key")
    if not api_key:
        auth = headers.get("authorization", "")
        if auth.lower().startswith("bearer "):
            api_key = auth.split(" ", 1)[1].strip()
    if not api_key:
        raise HTTPException(
            status_code=401,
            detail="Missing API key — send X-API-Key or Authorization: Bearer <key>",
        )
    user = await _handle_api_key_auth(api_key)

    span = trace.get_current_span()
    if span.is_recording():
        org_id = user.selected_organization_id
        # Intentionally NOT `dograh.org_id` — that attribute triggers the
        # per-org Langfuse routing for pipeline spans, and MCP traffic
        # should land in the default (developer-facing) project only.
        # Exposed under `mcp.org_id` for Langfuse UI filtering without
        # affecting the router.
        # Skip the attribute when no organization is selected so the span
        # is not tagged with the literal string "None".
        if org_id is not None:
            span.set_attribute("mcp.org_id", str(org_id))
        span.set_attribute("mcp.user_id", str(user.id))
        span.set_attribute("langfuse.user.id", str(user.id))

    return user
|
||||
124
api/mcp_server/instructions.py
Normal file
124
api/mcp_server/instructions.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Top-level orchestration guide surfaced to every MCP session.
|
||||
|
||||
Sent to the client via `FastMCP(instructions=...)` — the client bakes
|
||||
this into its system prompt, so every LLM session sees it before the
|
||||
first tool call. Prefer procedural orchestration here (call order, error
|
||||
handling, hard constraints). Design-level per-field guidance belongs in
|
||||
each `PropertySpec.llm_hint`; it flows out through `get_node_type` and
|
||||
doesn't need to be repeated here.
|
||||
|
||||
Extend based on real LLM failures — every bullet below ideally maps to a
|
||||
mistake the system has seen at least once.
|
||||
"""
|
||||
|
||||
DOGRAH_MCP_INSTRUCTIONS = """\
|
||||
You build and edit Dograh voice-AI workflows by emitting TypeScript that
|
||||
uses the `@dograh/sdk` package. Workflows are stored as JSON; this server
|
||||
projects them to TypeScript for editing and parses them back on save.
|
||||
|
||||
## Call order
|
||||
|
||||
1. `list_workflows` — locate the target workflow.
|
||||
2. `get_workflow_code(workflow_id)` — fetch the current source (draft if
|
||||
one exists, otherwise published).
|
||||
3. (optional) `list_node_types` / `get_node_type(name)` — consult before
|
||||
adding or editing a node type whose fields aren't already visible in
|
||||
the current code.
|
||||
4. Mutate the code in place. Preserve existing nodes, edges, and variable
|
||||
names unless the task requires removing or renaming them.
|
||||
5. `save_workflow(workflow_id, code)` — persist as a new draft. The
|
||||
published version is untouched.
|
||||
|
||||
## Allowed source shape
|
||||
|
||||
The parser is AST-only and rejects anything outside this grammar. At the
|
||||
top level, only three statement forms are accepted:
|
||||
|
||||
import ... from "..."; // any import
|
||||
const <var> = <initializer>; // bindings (see below)
|
||||
wf.edge(<src>, <tgt>, { label, condition }); // bare edge calls
|
||||
|
||||
`<initializer>` is one of:
|
||||
new Workflow({ name: "..." })
|
||||
wf.addTyped(<factory>({ ...fields }) [, { position: [x, y] }])
|
||||
wf.add({ type: "<nodeType>", ...fields [, position: [x, y]] })
|
||||
|
||||
No functions, arrow fns, loops, conditionals, ternaries, spreads,
|
||||
destructuring, template interpolation, `export`, or `.map`/`.forEach`.
|
||||
Data-position values must be plain literals (strings, numbers, booleans,
|
||||
null, arrays/objects of same). A single `new Workflow(...)` per file —
|
||||
the `name` you pass there is the workflow's display name and is applied
|
||||
on save (renames propagate immediately; definition changes go to draft).
|
||||
|
||||
## Adding edges — explicit syntax
|
||||
|
||||
wf.edge(source, target, { label: "...", condition: "..." });
|
||||
|
||||
Rules:
|
||||
- `source` and `target` are the **bare variable identifiers** bound by
|
||||
`wf.addTyped(...)` / `wf.add(...)` — not strings, not `.id`, not inline
|
||||
factories. Both must be declared earlier in the file.
|
||||
- `label` is a short tag (≤4 words) shown in call logs to identify the
|
||||
branch: `"qualified"`, `"wrap up"`, `"retry"`.
|
||||
- `condition` is a full natural-language predicate the runtime evaluates
|
||||
against the live conversation: `"caller confirmed interest in a demo"`,
|
||||
not `"interested"`. Condition clarity determines routing accuracy.
|
||||
- Both fields are required and must be non-empty strings.
|
||||
- Edges are directional; emit one `wf.edge(...)` per outgoing branch.
|
||||
- Place all edges after all node bindings; group by source node.
|
||||
|
||||
Example:
|
||||
|
||||
const greet = wf.addTyped(startCall({ name: "Greet", prompt: "Hi!" }));
|
||||
const done = wf.addTyped(endCall({ name: "Done", prompt: "Bye." }));
|
||||
wf.edge(greet, done, {
|
||||
label: "wrap up",
|
||||
condition: "user acknowledged the greeting and is ready to end"
|
||||
});
|
||||
|
||||
## Hard graph constraints
|
||||
|
||||
- Exactly one `startCall` node per workflow; no incoming edges.
|
||||
- `endCall` nodes have no outgoing edges.
|
||||
- `globalNode` has no incoming or outgoing edges; its prompt is prepended
|
||||
to every other node's prompt at runtime when that node sets
|
||||
`add_global_prompt=true`.
|
||||
- Every non-global node must be reachable from `startCall`.
|
||||
|
||||
## Iterating on errors
|
||||
|
||||
`save_workflow` returns one of:
|
||||
- `parse_error` Disallowed construct (see grammar above) or
|
||||
malformed TypeScript.
|
||||
- `validation_error` Node data failed spec validation (unknown field,
|
||||
missing required, wrong type, bad `options` value).
|
||||
- `graph_validation` Structural rule broken (missing startCall,
|
||||
unreachable node, edge to/from wrong node type).
|
||||
- `bridge_error` Internal — retry once, then surface to the user.
|
||||
|
||||
Every error carries `line` and `column`. Fix at that location and
|
||||
resubmit the **complete source** — this tool does not accept patches.
|
||||
|
||||
## Field conventions
|
||||
|
||||
- `data.name` is the canonical identifier. Pick a descriptive name
|
||||
(`"Qualify Budget"`, not `"Node1"`) — the generated code uses it as
|
||||
the variable name and call logs reference it.
|
||||
- Reference fields take UUIDs, not human names:
|
||||
`tool_refs`, `document_refs` → from `list_tools`, `list_documents`
|
||||
`credential_ref` → from `list_credentials`
|
||||
`recording_ref` → from `list_recordings`
|
||||
- `mention_textarea` fields (prompts, greetings, etc.) accept
|
||||
`{{template_variables}}` — values resolved at runtime from
|
||||
`pre_call_fetch`, caller context, or earlier extraction passes.
|
||||
|
||||
## Style
|
||||
|
||||
- Prefer `wf.addTyped(factory({ ... }))` over `wf.add({ type, ... })`.
|
||||
- Only include fields whose values differ from the spec default — the
|
||||
parser re-applies defaults on save, so extras are noise.
|
||||
- Omit `position`; the server reconciles positions against the previous
|
||||
saved workflow and lays out new nodes automatically.
|
||||
- Add nodes in call-flow order (start → intermediate → end) so the
|
||||
generated code reads top-to-bottom, with all edges after all nodes.
|
||||
"""
|
||||
11
api/mcp_server/server.py
Normal file
11
api/mcp_server/server.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from fastmcp import FastMCP
|
||||
|
||||
from api.mcp_server.instructions import DOGRAH_MCP_INSTRUCTIONS
|
||||
|
||||
mcp = FastMCP("dograh", instructions=DOGRAH_MCP_INSTRUCTIONS)
|
||||
|
||||
from api.mcp_server.tools import catalog as _catalog # noqa: E402, F401
|
||||
from api.mcp_server.tools import get_workflow_code as _get_workflow_code # noqa: E402, F401
|
||||
from api.mcp_server.tools import node_types as _node_types # noqa: E402, F401
|
||||
from api.mcp_server.tools import save_workflow as _save_workflow # noqa: E402, F401
|
||||
from api.mcp_server.tools import workflows as _workflows # noqa: E402, F401
|
||||
113
api/mcp_server/tools/catalog.py
Normal file
113
api/mcp_server/tools/catalog.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""MCP discovery tools for the reference catalogs.
|
||||
|
||||
Node properties of type `tool_refs`, `document_refs`, `recording_ref`, and
|
||||
`credential_ref` carry UUIDs that resolve against these catalogs. LLMs must
|
||||
list the catalog before populating those fields with real UUIDs.
|
||||
"""
|
||||
|
||||
from api.db import db_client
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def list_tools(status: str | None = "active") -> list[dict]:
    """List tools the agent can invoke during a call.

    Returns each tool's `tool_uuid` (use this in node `tool_uuids` properties),
    `name`, `description`, and `category`. Pass `status=None` to include
    archived tools.
    """
    # Results are scoped to the caller's selected organization; a missing
    # description is reported as an empty string.
    caller = await authenticate_mcp_request()
    org_tools = await db_client.get_tools_for_organization(
        organization_id=caller.selected_organization_id,
        status=status,
    )
    return [
        dict(
            tool_uuid=tool.tool_uuid,
            name=tool.name,
            description=tool.description or "",
            category=tool.category,
        )
        for tool in org_tools
    ]
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def list_documents() -> list[dict]:
    """List knowledge-base documents the agent can reference during a call.

    Returns each document's `document_uuid` (use this in node
    `document_uuids` properties), `filename`, and `processing_status`.
    """
    # Results are scoped to the caller's selected organization; each row
    # also carries the document's `total_chunks`.
    caller = await authenticate_mcp_request()
    org_documents = await db_client.get_documents_for_organization(
        organization_id=caller.selected_organization_id,
    )
    return [
        dict(
            document_uuid=document.document_uuid,
            filename=document.filename,
            processing_status=document.processing_status,
            total_chunks=document.total_chunks,
        )
        for document in org_documents
    ]
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def list_credentials() -> list[dict]:
    """List external credentials available for webhook auth and pre-call fetch.

    Returns each credential's `credential_uuid` (use this in node
    `credential_uuid` / `pre_call_fetch_credential_uuid` properties), `name`,
    `description`, and `credential_type`.
    """
    # Results are scoped to the caller's selected organization; only the
    # four identifying fields below are exposed, and a missing description
    # is reported as an empty string.
    caller = await authenticate_mcp_request()
    org_credentials = await db_client.get_credentials_for_organization(
        organization_id=caller.selected_organization_id,
    )
    return [
        dict(
            credential_uuid=credential.credential_uuid,
            name=credential.name,
            description=credential.description or "",
            credential_type=credential.credential_type,
        )
        for credential in org_credentials
    ]
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def list_recordings(workflow_id: int | None = None) -> list[dict]:
    """List pre-recorded audio files available for greetings and edge
    transition speech.

    Returns each recording's `recording_id` (use this in
    `greeting_recording_id` / `transition_speech_recording_id` properties),
    `transcript`, and TTS metadata. Pass `workflow_id` to filter to one
    workflow's recordings.
    """
    # Results are scoped to the caller's selected organization; the
    # `workflow_id` argument is forwarded to the query unchanged.
    caller = await authenticate_mcp_request()
    found = await db_client.get_recordings(
        organization_id=caller.selected_organization_id,
        workflow_id=workflow_id,
    )
    return [
        dict(
            id=recording.id,
            recording_id=recording.recording_id,
            workflow_id=recording.workflow_id,
            transcript=recording.transcript,
            tts_provider=recording.tts_provider,
            tts_model=recording.tts_model,
            tts_voice_id=recording.tts_voice_id,
        )
        for recording in found
    ]
|
||||
71
api/mcp_server/tools/get_workflow_code.py
Normal file
71
api/mcp_server/tools/get_workflow_code.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""MCP tool that returns a workflow as SDK TypeScript code.
|
||||
|
||||
Companion to `save_workflow`: the LLM calls `get_workflow_code` to see
|
||||
the current state of a workflow as editable code, mutates it, and calls
|
||||
`save_workflow` with the new code. Storage stays JSON; the TS form is
|
||||
an ephemeral projection for the LLM edit loop.
|
||||
|
||||
Selection priority: latest draft → latest published → legacy
|
||||
`workflow.workflow_definition`. That matches the UI's "whichever is the
|
||||
working copy" behavior so the LLM sees what a human editor would see.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.db import db_client
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
from api.mcp_server.ts_bridge import TsBridgeError, generate_code
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def get_workflow_code(workflow_id: int) -> dict[str, Any]:
    """Return the workflow as SDK TypeScript code the LLM can edit.

    Output shape:
        {"workflow_id": int, "name": str, "version": "draft" | "published" | "legacy", "code": "<TS source>"}

    The LLM edits `code`, then calls `save_workflow(workflow_id, code)`.
    """
    user = await authenticate_mcp_request()

    # The lookup is scoped to the caller's selected organization; a miss
    # is reported as 404.
    workflow = await db_client.get_workflow(
        workflow_id, organization_id=user.selected_organization_id
    )
    if not workflow:
        raise HTTPException(status_code=404, detail=f"Workflow {workflow_id} not found")

    # Draft wins over published — editing a draft is the normal flow.
    # `current_definition` (is_current=True) is the published row, so we
    # fetch the draft explicitly. If the latest draft was just published,
    # no draft row exists and we fall through to `released_definition`.
    draft = await db_client.get_draft_version(workflow_id)
    released = workflow.released_definition

    if draft is not None and draft.workflow_json:
        payload = draft.workflow_json
        source = "draft"
    elif released is not None and released.workflow_json:
        payload = released.workflow_json
        source = "published"
    else:
        # Neither versioned row has content: use the legacy column.
        payload = workflow.workflow_definition or {}
        source = "legacy"

    try:
        code = await generate_code(payload, workflow_name=workflow.name or "")
    except TsBridgeError as e:
        # Chain the bridge failure so the original cause stays in the traceback.
        raise HTTPException(
            status_code=500, detail=f"Failed to generate code: {e}"
        ) from e

    return {
        "workflow_id": workflow_id,
        "name": workflow.name or "",
        "version": source,
        "code": code,
    }
|
||||
57
api/mcp_server/tools/node_types.py
Normal file
57
api/mcp_server/tools/node_types.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""MCP discovery tools for node specifications.
|
||||
|
||||
LLMs call these tools first to learn the available node-type catalog and
|
||||
each node's property schema before composing or modifying a workflow.
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
from api.services.workflow.node_specs import SPEC_VERSION, all_specs, get_spec
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def list_node_types() -> dict:
    """List every available node type with a brief summary.

    Use this first to discover what nodes exist, then call `get_node_type`
    for the full schema of any node you intend to use.

    Returns:
        A dict with `spec_version` (pin against this in any generated workflow
        code) and `node_types` (list of {name, display_name, description,
        category}).
    """
    # Authentication is enforced even though the catalog itself does not
    # depend on the caller; the returned user is not needed.
    await authenticate_mcp_request()

    summaries = [
        dict(
            name=node_spec.name,
            display_name=node_spec.display_name,
            description=node_spec.description,
            category=node_spec.category.value,
        )
        for node_spec in all_specs()
    ]
    return {"spec_version": SPEC_VERSION, "node_types": summaries}
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def get_node_type(name: str) -> dict:
    """Fetch the full schema for a node type, including every property's
    type, default, conditional visibility rules, and LLM-readable
    description, plus worked examples.

    Use the property `description` and the `examples` list to understand
    semantics — types alone are not enough.
    """
    # Authenticate first; the spec lookup itself is caller-independent.
    await authenticate_mcp_request()

    node_spec = get_spec(name)
    if node_spec is not None:
        # JSON mode so the dumped spec is serializable as-is.
        return node_spec.model_dump(mode="json")
    raise HTTPException(status_code=404, detail=f"Unknown node type: {name!r}")
|
||||
168
api/mcp_server/tools/save_workflow.py
Normal file
168
api/mcp_server/tools/save_workflow.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""MCP tool that accepts LLM-authored SDK TypeScript and saves it as a draft.
|
||||
|
||||
Execution flow:
|
||||
1. Parse via the Node TS validator — AST-only, never executes the code.
|
||||
Returns either a workflow JSON or per-location parse/validate errors.
|
||||
2. Pydantic validation via `ReactFlowDTO.model_validate` (defence in
|
||||
depth; the parser is already spec-driven, but the DTO layer is the
|
||||
authoritative wire-format gate).
|
||||
3. Graph validation via `WorkflowGraph`.
|
||||
4. Save as a new draft via `db_client.save_workflow_draft` — the
|
||||
published version stays intact, so edits are rollback-safe.
|
||||
|
||||
Error codes surfaced to the LLM:
|
||||
parse_error — TS parse failed or a disallowed construct was used
|
||||
validation_error — node data failed spec validation (unknown field,
|
||||
missing required, wrong type, option out of range)
|
||||
schema_validation — ReactFlowDTO Pydantic rejection (rare; parser bug)
|
||||
graph_validation — semantic graph rule broken (e.g. no start node)
|
||||
bridge_error — Node subprocess failed before returning JSON
|
||||
|
||||
All LLM-facing errors include file:line:column where available so the
|
||||
LLM can correct its code directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError as PydanticValidationError
|
||||
|
||||
from api.db import db_client
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
from api.mcp_server.ts_bridge import TsBridgeError, parse_code
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.layout import reconcile_positions
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
async def _previous_workflow_json(workflow: Any) -> dict[str, Any] | None:
    """Return the stored workflow JSON that node positions are reconciled against.

    Uses the same selection priority as `get_workflow_code`, so the
    version the LLM saw is the version we reconcile against.

    `current_definition` (is_current=True) is the published row, which is
    why the draft is fetched explicitly. Candidates are tried in order
    and the first non-empty one wins:

    1. the draft returned by `db_client.get_draft_version`,
    2. `workflow.released_definition` (covers the case where the last
       draft was just published and no draft exists),
    3. `workflow.workflow_definition`.

    Returns `None` when every candidate is missing or empty.
    """
    draft_version = await db_client.get_draft_version(workflow.id)
    draft_json = draft_version.workflow_json if draft_version is not None else None
    if draft_json:
        return draft_json

    released_version = workflow.released_definition
    released_json = (
        released_version.workflow_json if released_version is not None else None
    )
    if released_json:
        return released_json

    fallback_json = workflow.workflow_definition
    return fallback_json if fallback_json else None
|
||||
|
||||
|
||||
def _error_result(code: str, message: str, **extra: Any) -> dict[str, Any]:
|
||||
return {"saved": False, "error_code": code, "error": message, **extra}
|
||||
|
||||
|
||||
def _format_errors(errors: list[dict[str, Any]]) -> str:
|
||||
parts: list[str] = []
|
||||
for e in errors:
|
||||
loc = ""
|
||||
line = e.get("line")
|
||||
col = e.get("column")
|
||||
if line is not None:
|
||||
loc = f" (line {line}" + (f", col {col}" if col is not None else "") + ")"
|
||||
parts.append(f"{e.get('message', '')}{loc}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
@mcp.tool
@traced_tool
async def save_workflow(workflow_id: int, code: str) -> dict[str, Any]:
    """Parse SDK TypeScript and save the resulting workflow as a draft.

    `code` is TypeScript source using `@dograh/sdk`. Fetch the current
    code first via `get_workflow_code(workflow_id)`, edit it, then pass
    the full updated source here.

    Example code:
        import { Workflow } from "@dograh/sdk";
        import { startCall, endCall } from "@dograh/sdk/typed";

        const wf = new Workflow({ name: "lead_qualification" });
        const greeting = wf.addTyped(startCall({ name: "Greeting", prompt: "Hi!" }));
        const done = wf.addTyped(endCall({ name: "Done", prompt: "Bye." }));
        wf.edge(greeting, done, { label: "done", condition: "conversation complete" });

    On success the draft version is saved; the published version is
    untouched.
    """
    # NOTE(review): the docstring text above is kept as-is on purpose —
    # presumably the MCP framework publishes it as the tool description
    # the LLM reads; confirm before rewording it.
    #
    # Return shape:
    #   success -> {"saved": True, "workflow_id", "version_number",
    #               "status", "node_count", "edge_count", "name", "renamed"}
    #   failure -> `_error_result(...)`, i.e. {"saved": False,
    #               "error_code", "error", ...extras}
    # A missing workflow raises HTTPException(404) instead of returning.
    user = await authenticate_mcp_request()

    workflow = await db_client.get_workflow(
        workflow_id, organization_id=user.selected_organization_id
    )
    if not workflow:
        raise HTTPException(status_code=404, detail=f"Workflow {workflow_id} not found")

    # 1. Parse + spec-validate via the Node TS validator.
    try:
        parsed = await parse_code(code)
    except TsBridgeError as e:
        logger.warning(f"ts_bridge failure: {e}")
        return _error_result("bridge_error", str(e))

    if not parsed.get("ok"):
        stage = parsed.get("stage", "parse")
        errs = parsed.get("errors") or []
        # Never hand the LLM an empty error string: fall back to a
        # generic description when the validator sent no error items.
        message = (
            _format_errors(errs)
            or f"ts_validator reported a {stage} failure without details"
        )
        if stage == "internal":
            # The validator writes `stage: "internal"` to stdout when it
            # crashes (bad request JSON, unhandled exception). The bridge
            # only raises when stdout is empty, so that response arrives
            # here as a normal dict. It is a bridge failure, not an
            # authoring mistake, so report it under the matching code.
            logger.warning(f"ts_validator internal failure: {message}")
            return _error_result("bridge_error", message, errors=errs)
        code_key = "parse_error" if stage == "parse" else "validation_error"
        return _error_result(code_key, message, errors=errs)

    payload = parsed["workflow"]
    new_name = (parsed.get("workflowName") or "").strip()

    # 1b. Reconcile node positions against the previously-stored workflow.
    # The parser drops positions by design (LLMs don't place nodes well);
    # here we fill them back in from what was there before, and pick
    # approximate placements for newly-introduced nodes.
    payload = reconcile_positions(payload, await _previous_workflow_json(workflow))

    # 2. Pydantic shape check (defence in depth — parser is spec-driven).
    try:
        dto = ReactFlowDTO.model_validate(payload)
    except PydanticValidationError as e:
        return _error_result("schema_validation", str(e))

    # 3. Graph-level semantic validation (start-node count, edge shape).
    # WorkflowGraph raises ValueError for rule violations; the catch is
    # deliberately broad so any construction failure is surfaced to the
    # LLM as a graph error rather than aborting the tool call.
    try:
        WorkflowGraph(dto)
    except Exception as e:
        return _error_result("graph_validation", str(e))

    # 4a. If the `new Workflow({ name })` in the edited source differs from
    # the stored name, rename the workflow. Name is a workflow-level field
    # (not versioned), so this takes effect immediately.
    name_changed = bool(new_name) and new_name != workflow.name
    if name_changed:
        await db_client.update_workflow(
            workflow_id=workflow_id,
            name=new_name,
            workflow_definition=None,
            template_context_variables=None,
            workflow_configurations=None,
            organization_id=user.selected_organization_id,
        )

    # 4b. Save as a new draft (existing published version stays intact).
    draft = await db_client.save_workflow_draft(
        workflow_id=workflow_id,
        workflow_definition=payload,
    )

    return {
        "saved": True,
        "workflow_id": workflow_id,
        "version_number": draft.version_number,
        "status": draft.status,
        "node_count": len(payload["nodes"]),
        "edge_count": len(payload["edges"]),
        "name": new_name or workflow.name,
        "renamed": name_changed,
    }
|
||||
|
|
@ -1,17 +1,20 @@
|
|||
from fastapi import HTTPException
|
||||
|
||||
from api.db import db_client
|
||||
from api.mcp.auth import authenticate_mcp_request
|
||||
from api.mcp.server import mcp
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def list_workflows(status: str | None = None) -> list[dict]:
|
||||
@traced_tool
|
||||
async def list_workflows(status: str | None = "active") -> list[dict]:
|
||||
"""List agents (workflows) in the caller's organization.
|
||||
|
||||
Returns id, name, status, and created_at for each agent. Use
|
||||
`get_workflow` to fetch a single agent's full definition. Pass
|
||||
`status="active"` or `status="archived"` to filter.
|
||||
`get_workflow` to fetch a single agent's full definition. Defaults
|
||||
to active agents; pass `status="archived"` to list archived agents,
|
||||
or `status=None` to list all.
|
||||
"""
|
||||
user = await authenticate_mcp_request()
|
||||
workflows = await db_client.get_all_workflows_for_listing(
|
||||
|
|
@ -30,6 +33,7 @@ async def list_workflows(status: str | None = None) -> list[dict]:
|
|||
|
||||
|
||||
@mcp.tool
|
||||
@traced_tool
|
||||
async def get_workflow(workflow_id: int) -> dict:
|
||||
"""Fetch a single agent by id, including its current published definition."""
|
||||
user = await authenticate_mcp_request()
|
||||
87
api/mcp_server/tracing.py
Normal file
87
api/mcp_server/tracing.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""OTel tracing for MCP tool invocations.
|
||||
|
||||
The project-wide tracing setup in
|
||||
`api/services/pipecat/tracing_config.py` already routes spans to
|
||||
per-organization Langfuse projects based on the `dograh.org_id` span
|
||||
attribute. This module plugs MCP tool calls into that pipeline:
|
||||
|
||||
@mcp.tool
|
||||
@traced_tool
|
||||
async def my_tool(...): ...
|
||||
|
||||
Each decorated invocation produces one span named `mcp.<tool_name>` with
|
||||
Langfuse-rendered input/output. Organization and user attributes are
|
||||
stamped separately by `authenticate_mcp_request` when it runs inside
|
||||
the tool body — the decorator's span is the `current_span` at that
|
||||
point, so the attributes land on the right span and the router export
|
||||
dispatches to the correct Langfuse project.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
_TRACER = trace.get_tracer("dograh.mcp")
|
||||
# Langfuse truncates long payloads anyway; cap here to keep span size
|
||||
# bounded. Tune up if you find tool outputs consistently clipped.
|
||||
_MAX_ATTR_LEN = 8000
|
||||
|
||||
|
||||
def _safe_json(value: Any) -> str:
|
||||
try:
|
||||
return json.dumps(value, default=str, ensure_ascii=False)
|
||||
except Exception: # noqa: BLE001
|
||||
return str(value)
|
||||
|
||||
|
||||
def traced_tool(fn: Callable[..., Awaitable[R]]) -> Callable[..., Awaitable[R]]:
    """Decorate an async MCP tool so every call is recorded as a span.

    The span is named `mcp.<function name>` and carries the tool name,
    the JSON-rendered keyword arguments, and the JSON-rendered result
    (both cut to `_MAX_ATTR_LEN` characters). Exceptions are recorded on
    the span, mark it as an error, and are re-raised unchanged.

    Apply it beneath `@mcp.tool`: `functools.wraps` keeps the wrapped
    function's metadata so the framework still introspects the original
    signature when building the tool schema.
    """

    @wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> R:
        tool_name = fn.__name__
        span_name = f"mcp.{tool_name}"
        # A fresh, empty `Context()` makes this span the root of its own
        # trace instead of a child of whatever span is currently active
        # (e.g. the FastAPI request span or a client-propagated context).
        # One trace per tool call keeps Langfuse diffing and per-org
        # filtering clean.
        with _TRACER.start_as_current_span(span_name, context=Context()) as span:
            span.set_attribute("mcp.tool.name", tool_name)
            # Explicit trace-name override so the Langfuse UI shows
            # `mcp.<tool>` at the top of the trace regardless of how the
            # framework names the root span.
            span.set_attribute("langfuse.trace.name", span_name)
            rendered_input = _safe_json(kwargs)[:_MAX_ATTR_LEN]
            span.set_attribute("langfuse.observation.input", rendered_input)
            try:
                outcome = await fn(*args, **kwargs)
            except Exception as exc:
                span.record_exception(exc)
                span.set_status(Status(StatusCode.ERROR, str(exc)))
                raise
            else:
                rendered_output = _safe_json(outcome)[:_MAX_ATTR_LEN]
                span.set_attribute("langfuse.observation.output", rendered_output)
                return outcome

    return wrapper
|
||||
93
api/mcp_server/ts_bridge.py
Normal file
93
api/mcp_server/ts_bridge.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Python-side bridge to the Node TS validator.
|
||||
|
||||
Spawns `node api/mcp_server/ts_validator/src/index.ts` as a short-lived
|
||||
subprocess per call, streams a JSON request on stdin, reads a JSON
|
||||
response from stdout. The validator never executes LLM code — it either
|
||||
emits TypeScript from a workflow JSON (`generate`) or parses LLM-authored
|
||||
TS back into a workflow JSON via AST walking (`parse`).
|
||||
|
||||
The subprocess startup cost is ~100-200ms per call. Fine for MCP tool
|
||||
rates; if it ever matters, the validator can be promoted to a long-lived
|
||||
worker over a unix socket without changing this interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from api.services.workflow.node_specs import all_specs
|
||||
|
||||
_VALIDATOR_ENTRY = Path(__file__).resolve().parent / "ts_validator" / "src" / "index.ts"
|
||||
|
||||
|
||||
class TsBridgeError(Exception):
    """The Node subprocess failed before producing a JSON response.

    Raised by `_invoke` when the validator exits non-zero without
    writing anything to stdout, or writes something that is not JSON.
    `generate_code` also raises it when the validator's JSON response
    reports `ok` as false.
    """
|
||||
|
||||
|
||||
def _specs_payload() -> list[dict[str, Any]]:
    """Dump every node spec from `all_specs()` in JSON mode.

    The resulting list is sent to the validator as the `specs` field of
    each request.
    """
    dumped: list[dict[str, Any]] = []
    for spec in all_specs():
        dumped.append(spec.model_dump(mode="json"))
    return dumped
|
||||
|
||||
|
||||
async def _invoke(request: dict[str, Any]) -> dict[str, Any]:
    """Run one request through the Node validator and return its JSON reply.

    Spawns `node <_VALIDATOR_ENTRY>`, writes `request` as UTF-8 JSON to
    the child's stdin, waits for it to exit, and decodes stdout as JSON.

    Args:
        request: JSON-serializable request body (`command` plus its
            command-specific fields).

    Returns:
        The decoded response. A non-zero exit that still produced output
        on stdout is returned as-is, so callers must inspect `ok` and
        `stage` themselves.

    Raises:
        TsBridgeError: the process could not be started, exited non-zero
            with empty stdout, or wrote output that is not valid UTF-8
            JSON.
    """
    try:
        proc = await asyncio.create_subprocess_exec(
            "node",
            str(_VALIDATOR_ENTRY),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except OSError as e:
        # E.g. `node` is not installed or not on PATH. Wrap it so callers
        # that handle TsBridgeError see every bridge failure the same way.
        raise TsBridgeError(f"could not start ts_validator: {e}") from e

    stdout, stderr = await proc.communicate(json.dumps(request).encode("utf-8"))
    if proc.returncode != 0 and not stdout:
        raise TsBridgeError(
            f"ts_validator exited {proc.returncode}: "
            f"{stderr.decode('utf-8', errors='replace')}"
        )
    try:
        return json.loads(stdout.decode("utf-8"))
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError, so malformed JSON and undecodable bytes land here.
        raise TsBridgeError(
            f"ts_validator emitted non-JSON: {stdout!r} (stderr: {stderr!r})"
        ) from e
|
||||
|
||||
|
||||
async def generate_code(workflow: dict[str, Any], *, workflow_name: str = "") -> str:
    """Ask the validator to emit SDK TypeScript for a workflow JSON payload.

    Args:
        workflow: Workflow JSON to render.
        workflow_name: Sent to the validator as `workflowName`.

    Returns:
        The `code` field of the validator response.

    Raises:
        TsBridgeError: the validator could not produce code (unknown
            node type, dangling edge reference, etc.). These are bugs at
            the caller layer, not user input, so the failure is loud.
    """
    request = {
        "command": "generate",
        "workflow": workflow,
        "specs": _specs_payload(),
        "workflowName": workflow_name,
    }
    response = await _invoke(request)
    if response.get("ok"):
        return response["code"]

    failures = response.get("errors") or [{"message": "unknown failure"}]
    joined = "; ".join(item.get("message", "") for item in failures)
    raise TsBridgeError("generate_code failed: " + joined)
|
||||
|
||||
|
||||
async def parse_code(code: str) -> dict[str, Any]:
    """Send LLM-authored TypeScript to the validator's `parse` command.

    Returns the raw validator response: `{"ok": True, "workflow": {...}}`
    on success, `{"ok": False, "stage": "parse" | "validate",
    "errors": [...]}` on author-side failure. Author-side failures are
    meant to be surfaced to the LLM verbatim so it can iterate; callers
    should not re-wrap them.
    """
    request = {
        "command": "parse",
        "code": code,
        "specs": _specs_payload(),
    }
    response = await _invoke(request)
    return response
|
||||
1
api/mcp_server/ts_validator/.gitignore
vendored
Normal file
1
api/mcp_server/ts_validator/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
node_modules/
|
||||
31
api/mcp_server/ts_validator/package-lock.json
generated
Normal file
31
api/mcp_server/ts_validator/package-lock.json
generated
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
{
|
||||
"name": "dograh-ts-validator",
|
||||
"version": "0.0.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "dograh-ts-validator",
|
||||
"version": "0.0.0",
|
||||
"dependencies": {
|
||||
"typescript": "^5.6.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=22.6"
|
||||
}
|
||||
},
|
||||
"node_modules/typescript": {
|
||||
"version": "5.9.3",
|
||||
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz",
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.17"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
13
api/mcp_server/ts_validator/package.json
Normal file
13
api/mcp_server/ts_validator/package.json
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"name": "dograh-ts-validator",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"description": "Node helper invoked by the Python MCP server. Converts workflow JSON to SDK TypeScript code (generate) and parses LLM-authored TS back into a validated workflow JSON (parse). Runs as a short-lived subprocess over stdin/stdout.",
|
||||
"dependencies": {
|
||||
"typescript": "^5.6.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=22.6"
|
||||
}
|
||||
}
|
||||
304
api/mcp_server/ts_validator/src/generate.ts
Normal file
304
api/mcp_server/ts_validator/src/generate.ts
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
// JSON → TypeScript source. Emits flat code the LLM can read and edit:
|
||||
// imports, a `Workflow` construction, one `addTyped` per node, one `edge`
|
||||
// per edge. Variable names are derived from `data.name` (falling back to
|
||||
// the node id) and deduplicated so the AST round-trips back through
|
||||
// `parse.ts` into the same workflow JSON.
|
||||
|
||||
import type {
|
||||
GenerateResult,
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
WireWorkflow,
|
||||
} from "./types.ts";
|
||||
|
||||
/**
 * Render a workflow JSON as flat SDK TypeScript source.
 *
 * Emits the imports, one `new Workflow(...)`, one `wf.addTyped(...)`
 * binding per node and one `wf.edge(...)` call per edge.
 *
 * @param workflow Workflow to render (nodes and edges).
 * @param specs    Node spec catalog; every node type must appear in it.
 * @param opts     `workflowName` is written into the `Workflow` constructor
 *                 (empty string when omitted).
 * @returns `{ ok: true, code }` on success, or `{ ok: false, errors }` when
 *          a node has an unknown type or an edge references a missing node.
 */
export function generateCode(
  workflow: WireWorkflow,
  specs: NodeSpec[],
  opts: { workflowName?: string } = {},
): GenerateResult {
  const specByName = new Map(specs.map((s) => [s.name, s]));

  // Catch unknown node types up-front — otherwise we'd emit an import
  // line for a factory that doesn't exist.
  for (const n of workflow.nodes) {
    if (!specByName.has(n.type)) {
      return {
        ok: false,
        errors: [
          {
            message: `Unknown node type in workflow: "${n.type}"`,
          },
        ],
      };
    }
  }

  const factoryNames = [
    ...new Set(workflow.nodes.map((n) => n.type)),
  ].sort();
  const nodeVarById = new Map<string, string>();
  // Seed the taken names with the imported factory identifiers. Variable
  // names are lowercased, so a node named e.g. "Webhook" whose type is an
  // all-lowercase factory would otherwise be emitted as
  // `const webhook = wf.addTyped(webhook(...))`, shadowing the import.
  const usedNames = new Set<string>(factoryNames);

  const lines: string[] = [];
  lines.push(`import { Workflow } from "@dograh/sdk";`);
  if (factoryNames.length > 0) {
    lines.push(
      `import { ${factoryNames.join(", ")} } from "@dograh/sdk/typed";`,
    );
  }
  lines.push("");
  const wfName = opts.workflowName ?? "";
  lines.push(
    `const wf = new Workflow(${renderObject({ name: wfName }, 0)});`,
  );
  lines.push("");

  for (const node of workflow.nodes) {
    const varName = pickVarName(node, usedNames);
    nodeVarById.set(node.id, varName);

    const spec = specByName.get(node.type)!;
    // Strip legacy/UI-state fields the spec doesn't know about
    // (e.g. `invalid`, `selected`, `dragging`, `is_start`,
    // `validationMessage`). They accumulated in stored workflow
    // data before the parser enforced spec validation, and are
    // pure noise from the LLM's perspective — dropping them keeps
    // the editing surface clean and avoids a pointless save-time
    // rejection round-trip.
    const knownOnly = stripUnknown(node.data, spec);
    const data = stripDefaults(knownOnly, spec);
    const factoryArg = renderObject(data, 0);

    // Positions are intentionally omitted — LLMs don't place nodes
    // sensibly, so we let a downstream auto-layout pass (future
    // enhancement) assign coordinates on save. Existing positions
    // in the DB are preserved by `parse.ts` defaulting to {0,0}
    // and the save path leaving pre-existing node positions alone.
    lines.push(
      `const ${varName} = wf.addTyped(${node.type}(${factoryArg}));`,
    );
  }

  if (workflow.edges.length > 0) {
    lines.push("");
  }
  for (const edge of workflow.edges) {
    const src = nodeVarById.get(edge.source);
    const tgt = nodeVarById.get(edge.target);
    if (!src || !tgt) {
      return {
        ok: false,
        errors: [
          {
            message:
              `Edge ${edge.id} references unknown node ` +
              `(source=${edge.source}, target=${edge.target}).`,
          },
        ],
      };
    }
    const cleanedEdge = pickEdgeFields(edge.data);
    const edgeOpts = renderObject(cleanedEdge, 0);
    lines.push(`wf.edge(${src}, ${tgt}, ${edgeOpts});`);
  }

  return { ok: true, code: lines.join("\n") + "\n" };
}
|
||||
|
||||
// ─── helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
// Choose a unique variable name for `node` and record it in `used`.
// The seed is the node's non-blank `data.name`, else `node_<id>`; it is
// sanitized into an identifier and, while it collides with a taken or
// reserved name, suffixed with `_2`, `_3`, ...
function pickVarName(
  node: { id: string; data: Record<string, unknown> },
  used: Set<string>,
): string {
  const rawName = node.data["name"];
  const hasUsableName = typeof rawName === "string" && rawName.trim() !== "";
  const seed = hasUsableName ? (rawName as string) : `node_${node.id}`;
  const base = sanitizeIdentifier(seed);

  let candidate = base;
  let suffix = 2;
  for (;;) {
    if (!used.has(candidate) && !RESERVED.has(candidate)) break;
    candidate = `${base}_${suffix}`;
    suffix += 1;
  }
  used.add(candidate);
  return candidate;
}
|
||||
|
||||
// Turn arbitrary text into a lowercase identifier: runs of characters
// outside [a-zA-Z0-9_] become a single `_`, leading/trailing underscores
// are dropped, an empty result becomes "node", and a leading digit gets
// an `n_` prefix.
function sanitizeIdentifier(raw: string): string {
  const collapsed = raw.trim().replace(/[^a-zA-Z0-9_]+/g, "_");
  const unpadded = collapsed.replace(/^_+|_+$/g, "");
  const lowered = unpadded.toLowerCase();
  if (lowered.length === 0) return "node";
  return /^[0-9]/.test(lowered) ? `n_${lowered}` : lowered;
}
|
||||
|
||||
// Names a generated node variable must never take: the workflow binding
// (`wf`), the imported `Workflow` class, and words that cannot be used as
// a `const` binding name in module (strict-mode) TypeScript. Variable
// names are lowercased before this check, so a node called "Enum",
// "Static" or "Package" would otherwise produce invalid source.
const RESERVED = new Set([
  "wf",
  "const",
  "let",
  "var",
  "new",
  "function",
  "class",
  "import",
  "export",
  "return",
  "if",
  "else",
  "for",
  "while",
  "do",
  "switch",
  "case",
  "break",
  "continue",
  "default",
  "throw",
  "try",
  "catch",
  "finally",
  "await",
  "async",
  "true",
  "false",
  "null",
  "undefined",
  "this",
  "super",
  "in",
  "of",
  "typeof",
  "instanceof",
  "delete",
  "void",
  "yield",
  "Workflow",
  // Remaining ECMAScript reserved words.
  "enum",
  "extends",
  "with",
  "debugger",
  // Reserved in strict mode (ES modules are always strict).
  "implements",
  "interface",
  "package",
  "private",
  "protected",
  "public",
  "static",
  // Not allowed as binding names in strict mode.
  "arguments",
  "eval",
]);
|
||||
|
||||
// Drop keys not declared in the spec. Handles nested `fixed_collection`
|
||||
// rows by recursing through sub-property specs. Anything that isn't in
|
||||
// the spec is legacy/UI state and should never reach the LLM.
|
||||
// Return a copy of `data` holding only the keys that `spec.properties`
// declares. For a `fixed_collection` property whose value is an array,
// each plain-object row is filtered recursively against the property's
// own sub-properties; non-object rows are passed through untouched.
function stripUnknown(
  data: Record<string, unknown>,
  spec: NodeSpec,
): Record<string, unknown> {
  const propsByName = new Map<string, PropertySpec>(
    (spec.properties ?? []).map((p): [string, PropertySpec] => [p.name, p]),
  );

  const kept: Record<string, unknown> = {};
  for (const key of Object.keys(data)) {
    const prop = propsByName.get(key);
    if (prop === undefined) continue; // not in the spec: drop it

    const value = data[key];
    if (prop.type !== "fixed_collection" || !Array.isArray(value)) {
      kept[key] = value;
      continue;
    }

    // Treat the collection's sub-properties as a spec of their own.
    const rowSpec: NodeSpec = {
      name: prop.name,
      properties: prop.properties ?? [],
    };
    kept[key] = value.map((row) => {
      const isPlainObject =
        row !== null && typeof row === "object" && !Array.isArray(row);
      return isPlainObject
        ? stripUnknown(row as Record<string, unknown>, rowSpec)
        : row;
    });
  }
  return kept;
}
|
||||
|
||||
// Edge schema is fixed (no NodeSpec for edges). Mirrors the allowed
|
||||
// fields on `Workflow.edge(...)` in both SDKs.
|
||||
const KNOWN_EDGE_FIELDS = new Set<string>([
  "label",
  "condition",
  "transition_speech",
  "transition_speech_type",
  "transition_speech_recording_id",
]);

// Copy only the edge fields listed in KNOWN_EDGE_FIELDS, keeping the
// order in which they appear in `data`.
function pickEdgeFields(
  data: Record<string, unknown>,
): Record<string, unknown> {
  const picked: Record<string, unknown> = {};
  Object.keys(data)
    .filter((field) => KNOWN_EDGE_FIELDS.has(field))
    .forEach((field) => {
      picked[field] = data[field];
    });
  return picked;
}
|
||||
|
||||
// Drop keys whose value equals the spec default — keeps emitted code tight.
|
||||
// Return a copy of `data` without the keys whose value deep-equals the
// default declared for that property in `spec`. Properties whose default
// is `undefined` or `null` are never stripped.
function stripDefaults(
  data: Record<string, unknown>,
  spec: NodeSpec,
): Record<string, unknown> {
  const defaultByName = new Map<string, unknown>();
  (spec.properties ?? []).forEach((prop) => {
    if (prop.default === undefined || prop.default === null) return;
    defaultByName.set(prop.name, prop.default);
  });

  const kept: Record<string, unknown> = {};
  for (const key of Object.keys(data)) {
    const value = data[key];
    const matchesDefault =
      defaultByName.has(key) && deepEqual(defaultByName.get(key), value);
    if (!matchesDefault) {
      kept[key] = value;
    }
  }
  return kept;
}
|
||||
|
||||
// Structural equality for JSON-like values (primitives, arrays, plain
// objects). Arrays compare element-wise in order; objects compare by
// their own enumerable keys regardless of key order. An array never
// equals a plain object, even when both are empty.
function deepEqual(a: unknown, b: unknown): boolean {
  if (a === b) return true;
  if (typeof a !== typeof b) return false;
  if (a === null || b === null) return false;

  const aIsArray = Array.isArray(a);
  const bIsArray = Array.isArray(b);
  // `typeof []` is "object" as well, so without this check `[]` and `{}`
  // would fall through to the key comparison below and compare equal.
  if (aIsArray !== bIsArray) return false;
  if (aIsArray && bIsArray) {
    const left = a as unknown[];
    const right = b as unknown[];
    if (left.length !== right.length) return false;
    return left.every((el, i) => deepEqual(el, right[i]));
  }

  if (typeof a === "object" && typeof b === "object") {
    const left = a as Record<string, unknown>;
    const right = b as Record<string, unknown>;
    const ak = Object.keys(left).sort();
    const bk = Object.keys(right).sort();
    if (ak.length !== bk.length) return false;
    if (ak.some((k, i) => k !== bk[i])) return false;
    return ak.every((k) => deepEqual(left[k], right[k]));
  }
  return false;
}
|
||||
|
||||
// Object renderer biased for readability — strings use single-line JSON,
|
||||
// nested objects/arrays indent one level per depth.
|
||||
// Render `obj` as a multi-line object literal, one `key: value` per line
// with a trailing comma. `depth` is the nesting level of the opening
// brace and controls the indentation. An empty object renders as `{}`.
function renderObject(obj: Record<string, unknown>, depth: number): string {
  const keys = Object.keys(obj);
  if (keys.length === 0) return "{}";
  // NOTE(review): indent unit reproduced exactly as seen in this copy of
  // the file; confirm the intended width against the upstream source.
  const pad = " ".repeat(depth + 1);
  const closingPad = " ".repeat(depth);
  const parts = keys.map((k) => {
    const v = renderValue(obj[k], depth + 1);
    return `${pad}${renderKey(k)}: ${v}`;
  });
  return `{\n${parts.join(",\n")},\n${closingPad}}`;
}

// Emit a property name: bare when it is a plain identifier, otherwise as
// a quoted string literal so keys such as "content-type" or "2fa" still
// produce syntactically valid TypeScript.
function renderKey(key: string): string {
  return /^[A-Za-z_$][A-Za-z0-9_$]*$/.test(key) ? key : JSON.stringify(key);
}

// Render any JSON-like value as TypeScript source. `null` and `undefined`
// both become `null`; strings are JSON-quoted; arrays and objects are
// rendered multi-line, one level deeper per nesting.
function renderValue(v: unknown, depth: number): string {
  if (v === null || v === undefined) return "null";
  if (typeof v === "string") return JSON.stringify(v);
  if (typeof v === "number" || typeof v === "boolean") return String(v);
  if (Array.isArray(v)) {
    if (v.length === 0) return "[]";
    const pad = " ".repeat(depth + 1);
    const closingPad = " ".repeat(depth);
    const items = v.map((el) => `${pad}${renderValue(el, depth + 1)}`);
    return `[\n${items.join(",\n")},\n${closingPad}]`;
  }
  if (typeof v === "object") {
    return renderObject(v as Record<string, unknown>, depth);
  }
  return JSON.stringify(v);
}
|
||||
74
api/mcp_server/ts_validator/src/index.ts
Normal file
74
api/mcp_server/ts_validator/src/index.ts
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// Stdin/stdout dispatch. Reads a single JSON request, routes to
|
||||
// generate or parse, writes a single JSON response. Exits 0 on request
|
||||
// success (including validation failures — those are in the JSON), and
|
||||
// exits 1 only on internal errors (bad input JSON, unhandled exception).
|
||||
|
||||
import { generateCode } from "./generate.ts";
|
||||
import { parseCode } from "./parse.ts";
|
||||
import type { NodeSpec, WireWorkflow } from "./types.ts";
|
||||
|
||||
// Request body for the `generate` command: render `workflow` as SDK
// TypeScript using the supplied spec catalog.
interface GenerateRequest {
  command: "generate";
  workflow: WireWorkflow;
  specs: NodeSpec[];
  // Passed through to `generateCode` as `opts.workflowName`.
  workflowName?: string;
}

// Request body for the `parse` command: turn `code` back into a workflow
// JSON, validating against the supplied spec catalog.
interface ParseRequest {
  command: "parse";
  code: string;
  specs: NodeSpec[];
}

// Every request read from stdin; `command` is the discriminant.
type Request = GenerateRequest | ParseRequest;
|
||||
|
||||
// Collect everything written to stdin until it closes and return it as
// a single UTF-8 string.
async function readStdin(): Promise<string> {
  const pieces: Buffer[] = [];
  for await (const piece of process.stdin) {
    pieces.push(piece as Buffer);
  }
  const whole = Buffer.concat(pieces);
  return whole.toString("utf-8");
}
|
||||
|
||||
// Serialize `payload` as JSON and write it to stdout (no trailing newline).
function writeResult(payload: unknown): void {
  const serialized = JSON.stringify(payload);
  process.stdout.write(serialized);
}
|
||||
|
||||
// Read one JSON request from stdin, dispatch on `command`, and write one
// JSON response to stdout. Unparseable input and unknown commands are
// reported with `stage: "internal"` and exit code 1; `generate` and
// `parse` results are written as returned, whatever their `ok` value.
async function main(): Promise<void> {
  const raw = await readStdin();

  let req: Request;
  try {
    req = JSON.parse(raw) as Request;
  } catch (e) {
    writeResult({
      ok: false,
      stage: "internal",
      errors: [{ message: `Invalid JSON on stdin: ${(e as Error).message}` }],
    });
    process.exit(1);
  }

  switch (req.command) {
    case "generate": {
      const generated = generateCode(req.workflow, req.specs, {
        workflowName: req.workflowName,
      });
      writeResult(generated);
      return;
    }
    case "parse": {
      const parsed = parseCode(req.code, req.specs);
      writeResult(parsed);
      return;
    }
    default: {
      writeResult({
        ok: false,
        stage: "internal",
        errors: [{ message: `Unknown command: ${(req as { command?: unknown }).command}` }],
      });
      process.exit(1);
    }
  }
}
|
||||
|
||||
// Script entry point. Any rejection that escapes `main` is reported on
// stdout as an `internal` failure (stack trace when available, else the
// stringified error) and the process exits with code 1.
main().catch((err: unknown) => {
  writeResult({
    ok: false,
    stage: "internal",
    errors: [{ message: (err as Error).stack ?? String(err) }],
  });
  process.exit(1);
});
|
||||
612
api/mcp_server/ts_validator/src/parse.ts
Normal file
612
api/mcp_server/ts_validator/src/parse.ts
Normal file
|
|
@ -0,0 +1,612 @@
|
|||
// TypeScript → workflow JSON.
|
||||
//
|
||||
// Parses LLM-authored SDK code with the TypeScript compiler, walks the
|
||||
// AST statement by statement, and builds up a workflow JSON from the
|
||||
// recognized SDK patterns:
|
||||
//
|
||||
// const wf = new Workflow({ name: "..." });
|
||||
// const X = wf.addTyped(startCall({ ...fields }));
|
||||
// const Y = wf.add({ type: "endCall", ...fields });
|
||||
// wf.edge(X, Y, { label: "...", condition: "..." });
|
||||
//
|
||||
// No code is executed. Any top-level statement that doesn't match one
|
||||
// of the recognized shapes is a parse error with a file:line:col pointer
|
||||
// so the LLM can iterate. Node data is validated against the spec
|
||||
// catalog before returning.
|
||||
|
||||
import ts from "typescript";
|
||||
|
||||
import type {
|
||||
NodeSpec,
|
||||
ParseErrorItem,
|
||||
ParseResult,
|
||||
PropertySpec,
|
||||
WireEdge,
|
||||
WireNode,
|
||||
} from "./types.ts";
|
||||
|
||||
export function parseCode(code: string, specs: NodeSpec[]): ParseResult {
|
||||
const specByName = new Map(specs.map((s) => [s.name, s]));
|
||||
const sourceFile = ts.createSourceFile(
|
||||
"workflow.ts",
|
||||
code,
|
||||
ts.ScriptTarget.ESNext,
|
||||
true,
|
||||
ts.ScriptKind.TS,
|
||||
);
|
||||
|
||||
const errors: ParseErrorItem[] = [];
|
||||
const nodes: WireNode[] = [];
|
||||
const edges: WireEdge[] = [];
|
||||
const nodeRefs = new Map<string, WireNode>();
|
||||
let workflowVar: string | null = null;
|
||||
let workflowName = "";
|
||||
let nextId = 1;
|
||||
|
||||
const addError = (node: ts.Node, message: string): void => {
|
||||
const pos = sourceFile.getLineAndCharacterOfPosition(node.getStart());
|
||||
errors.push({
|
||||
message,
|
||||
line: pos.line + 1,
|
||||
column: pos.character + 1,
|
||||
});
|
||||
};
|
||||
|
||||
for (const stmt of sourceFile.statements) {
|
||||
if (ts.isImportDeclaration(stmt)) continue; // imports are harmless
|
||||
if (
|
||||
ts.isExportAssignment(stmt) ||
|
||||
stmt.kind === ts.SyntaxKind.EmptyStatement
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// `const X = ...;` or `wf.edge(...);`
|
||||
if (ts.isVariableStatement(stmt)) {
|
||||
handleVariableStatement(stmt);
|
||||
continue;
|
||||
}
|
||||
if (ts.isExpressionStatement(stmt)) {
|
||||
handleExpressionStatement(stmt);
|
||||
continue;
|
||||
}
|
||||
addError(
|
||||
stmt,
|
||||
`Only imports, \`const X = ...\` bindings, and \`wf.edge(...)\` calls are allowed at the top level. Found: ${ts.SyntaxKind[stmt.kind]}.`,
|
||||
);
|
||||
}
|
||||
|
||||
function handleVariableStatement(stmt: ts.VariableStatement): void {
|
||||
const modifiers = ts.getModifiers(stmt);
|
||||
if (modifiers && modifiers.some((m) => m.kind === ts.SyntaxKind.ExportKeyword)) {
|
||||
addError(stmt, "`export` is not allowed on workflow bindings.");
|
||||
return;
|
||||
}
|
||||
if ((stmt.declarationList.flags & ts.NodeFlags.Const) === 0) {
|
||||
addError(stmt, "Use `const` for all bindings.");
|
||||
return;
|
||||
}
|
||||
for (const decl of stmt.declarationList.declarations) {
|
||||
if (!ts.isIdentifier(decl.name)) {
|
||||
addError(decl, "Destructuring is not allowed — use a single identifier.");
|
||||
continue;
|
||||
}
|
||||
if (!decl.initializer) {
|
||||
addError(decl, "Bindings must have an initializer.");
|
||||
continue;
|
||||
}
|
||||
const varName = decl.name.text;
|
||||
handleBinding(varName, decl.initializer, decl);
|
||||
}
|
||||
}
|
||||
|
||||
// Interpret the initializer of one `const` binding. Two shapes are accepted:
//   * `new Workflow({ name: "..." })` — records the workflow variable and,
//     when the first argument is an object literal with a string `name`,
//     the workflow name;
//   * `<wf>.addTyped(...)` / `<wf>.add(...)` — dispatched to the matching handler.
// A leading `await` is stripped first. Everything else is reported via `addError`.
function handleBinding(
  varName: string,
  initializer: ts.Expression,
  origin: ts.Node,
): void {
  const expr = unwrapAwait(initializer);

  // `const wf = new Workflow({ name: "..." })`
  if (ts.isNewExpression(expr)) {
    const ctor = expr.expression;
    if (!(ts.isIdentifier(ctor) && ctor.text === "Workflow")) {
      addError(origin, "Only `new Workflow(...)` is supported for object construction.");
      return;
    }
    if (workflowVar) {
      addError(origin, `A Workflow is already bound (as \`${workflowVar}\`). Only one Workflow is allowed.`);
      return;
    }
    // `new Workflow` without parentheses has no argument list at all.
    const ctorArgs = expr.arguments;
    if (ctorArgs !== undefined && ctorArgs.length > 0) {
      const config = literalToJs(ctorArgs[0]!, addError);
      if (config && typeof config === "object" && !Array.isArray(config)) {
        const declaredName = (config as Record<string, unknown>)["name"];
        if (typeof declaredName === "string") {
          workflowName = declaredName;
        }
      }
    }
    workflowVar = varName;
    return;
  }

  if (!ts.isCallExpression(expr)) {
    addError(
      origin,
      "Only `new Workflow(...)`, `wf.addTyped(...)`, and `wf.add(...)` are allowed as binding initializers.",
    );
    return;
  }

  // `const X = wf.addTyped(factory({...}))` or `wf.add({ type: "...", ... })`
  const callee = expr.expression;

  // Must be `wf.XYZ(...)` — property access off the workflow var. When no
  // workflow has been bound yet, any identifier receiver passes this check
  // and is rejected by the next one instead.
  if (
    !(
      ts.isPropertyAccessExpression(callee) &&
      ts.isIdentifier(callee.expression) &&
      (workflowVar === null || callee.expression.text === workflowVar)
    )
  ) {
    addError(
      origin,
      `Expected \`${workflowVar ?? "wf"}.addTyped(...)\` or \`${workflowVar ?? "wf"}.add(...)\`.`,
    );
    return;
  }
  if (!workflowVar) {
    addError(origin, "Workflow must be constructed before adding nodes.");
    return;
  }

  const method = callee.name.text;
  switch (method) {
    case "addTyped":
      handleAddTyped(varName, expr, origin);
      break;
    case "add":
      handleAddGeneric(varName, expr, origin);
      break;
    default:
      addError(
        origin,
        `Unsupported method \`${method}\`. Use \`addTyped\` or \`add\`.`,
      );
  }
}
|
||||
|
||||
// Handle `wf.addTyped(factory({ ...data }), { position: [x, y] }?)`.
// The first argument must be a direct call to a factory whose identifier is
// a key of `specByName`; its optional first argument supplies the node data.
// The optional second argument is parsed by `extractPositionArg`.
function handleAddTyped(
  varName: string,
  call: ts.CallExpression,
  origin: ts.Node,
): void {
  const argCount = call.arguments.length;
  if (argCount !== 1 && argCount !== 2) {
    addError(origin, "`addTyped` takes 1 or 2 arguments.");
    return;
  }

  const factoryCall = call.arguments[0]!;
  if (!(ts.isCallExpression(factoryCall) && ts.isIdentifier(factoryCall.expression))) {
    addError(
      origin,
      "`addTyped` must be called with a factory invocation, e.g. `addTyped(startCall({ ... }))`.",
    );
    return;
  }

  const factoryName = factoryCall.expression.text;
  if (!specByName.has(factoryName)) {
    addError(
      origin,
      `Unknown node type: \`${factoryName}\`. Check the list of registered node types.`,
    );
    return;
  }

  // A factory called with no argument (or one that yields `undefined`)
  // produces an empty data object.
  let data: Record<string, unknown> = {};
  if (factoryCall.arguments.length > 0) {
    const parsed = literalToJs(factoryCall.arguments[0]!, addError);
    if (parsed !== undefined) {
      const isPlainObject =
        typeof parsed === "object" && parsed !== null && !Array.isArray(parsed);
      if (!isPlainObject) {
        addError(factoryCall, "Factory argument must be an object literal.");
        return;
      }
      data = parsed as Record<string, unknown>;
    }
  }

  // Optional position arg
  const position = extractPositionArg(call.arguments[1], addError);
  bindNode(varName, factoryName, data, position, origin);
}
|
||||
|
||||
// Handle `wf.add({ type: "...", position?: [x, y], ...data })`.
// `type` must be a string key of `specByName`; `position`, when present,
// must be a two-number array. Both are removed from the object before the
// remainder is bound as the node's data.
function handleAddGeneric(
  varName: string,
  call: ts.CallExpression,
  origin: ts.Node,
): void {
  if (call.arguments.length !== 1) {
    addError(origin, "`add` takes exactly 1 object argument.");
    return;
  }

  const parsed = literalToJs(call.arguments[0]!, addError);
  if (parsed === undefined) {
    return;
  }
  const isPlainObject =
    typeof parsed === "object" && parsed !== null && !Array.isArray(parsed);
  if (!isPlainObject) {
    addError(origin, "`add` argument must be an object literal.");
    return;
  }

  const obj = parsed as Record<string, unknown>;
  const { type: nodeType, position: rawPosition, ...nodeData } = obj;

  if (typeof nodeType !== "string") {
    addError(origin, "`add({ type, ... })` requires a string `type` field.");
    return;
  }
  if (!specByName.has(nodeType)) {
    addError(origin, `Unknown node type: \`${nodeType}\`.`);
    return;
  }

  let position: { x: number; y: number } | undefined;
  if (rawPosition !== undefined) {
    const isNumberPair =
      Array.isArray(rawPosition) &&
      rawPosition.length === 2 &&
      typeof rawPosition[0] === "number" &&
      typeof rawPosition[1] === "number";
    if (!isNumberPair) {
      addError(
        origin,
        "`position` must be a [x, y] tuple of numbers.",
      );
      return;
    }
    const [x, y] = rawPosition as [number, number];
    position = { x, y };
  }

  bindNode(varName, nodeType, nodeData, position, origin);
}
|
||||
|
||||
// Create a wire node for `varName` and register it in `nodes` / `nodeRefs`.
// Ids are the stringified value of the running `nextId` counter; a missing
// position falls back to the origin. Re-binding a variable name is an error
// and leaves the counter untouched.
function bindNode(
  varName: string,
  type: string,
  data: Record<string, unknown>,
  position: { x: number; y: number } | undefined,
  origin: ts.Node,
): void {
  if (nodeRefs.has(varName)) {
    addError(origin, `Variable \`${varName}\` is already bound.`);
    return;
  }

  const id = String(nextId);
  nextId += 1;

  const created: WireNode = {
    id,
    type,
    position: position ?? { x: 0, y: 0 },
    data,
  };
  nodes.push(created);
  nodeRefs.set(varName, created);
}
|
||||
|
||||
// Handle a bare top-level statement, which must be
// `<wf>.edge(source, target, { label, condition, ... })`.
// `source` / `target` must be identifiers previously bound through
// `bindNode`; the options object is stored unchanged as the edge's `data`
// after `label` and `condition` are checked to be non-blank strings.
function handleExpressionStatement(stmt: ts.ExpressionStatement): void {
  const expr = unwrapAwait(stmt.expression);
  if (!ts.isCallExpression(expr)) {
    addError(stmt, "Only `wf.edge(...)` calls are allowed as bare statements.");
    return;
  }

  const callee = expr.expression;
  const isEdgeCall =
    ts.isPropertyAccessExpression(callee) &&
    ts.isIdentifier(callee.expression) &&
    (workflowVar === null || callee.expression.text === workflowVar) &&
    callee.name.text === "edge";
  if (!isEdgeCall) {
    addError(stmt, "Only `wf.edge(source, target, opts)` is allowed as a bare statement.");
    return;
  }

  if (expr.arguments.length !== 3) {
    addError(stmt, "`edge` takes exactly 3 arguments: (source, target, opts).");
    return;
  }
  const srcArg = expr.arguments[0]!;
  const tgtArg = expr.arguments[1]!;
  const optsArg = expr.arguments[2]!;

  if (!(ts.isIdentifier(srcArg) && ts.isIdentifier(tgtArg))) {
    addError(stmt, "`edge` source and target must be variable identifiers bound by `addTyped`/`add`.");
    return;
  }

  // Both lookups happen before either is checked; the source is reported first.
  const src = nodeRefs.get(srcArg.text);
  const tgt = nodeRefs.get(tgtArg.text);
  if (!src) {
    addError(srcArg, `Unknown node variable: \`${srcArg.text}\`.`);
    return;
  }
  if (!tgt) {
    addError(tgtArg, `Unknown node variable: \`${tgtArg.text}\`.`);
    return;
  }

  const opts = literalToJs(optsArg, addError);
  if (opts === undefined) {
    return;
  }
  if (typeof opts !== "object" || opts === null || Array.isArray(opts)) {
    addError(stmt, "`edge` options must be an object literal.");
    return;
  }
  const optsObj = opts as Record<string, unknown>;

  const isNonBlankString = (candidate: unknown): boolean =>
    typeof candidate === "string" && candidate.trim() !== "";

  if (!isNonBlankString(optsObj["label"])) {
    addError(stmt, "`edge` requires a non-empty `label` string.");
    return;
  }
  if (!isNonBlankString(optsObj["condition"])) {
    addError(stmt, "`edge` requires a non-empty `condition` string.");
    return;
  }

  edges.push({
    id: `${src.id}-${tgt.id}`,
    source: src.id,
    target: tgt.id,
    data: optsObj,
  });
}
|
||||
|
||||
// ── terminate early on parse errors ──────────────────────────────
|
||||
if (errors.length > 0) {
|
||||
return { ok: false, stage: "parse", errors };
|
||||
}
|
||||
|
||||
if (!workflowVar) {
|
||||
return {
|
||||
ok: false,
|
||||
stage: "parse",
|
||||
errors: [
|
||||
{
|
||||
message:
|
||||
"No Workflow construction found. Expected `const wf = new Workflow({ name: \"...\" });`.",
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
// ── spec-driven node validation ─────────────────────────────────
|
||||
const validationErrors: ParseErrorItem[] = [];
|
||||
for (const node of nodes) {
|
||||
const spec = specByName.get(node.type)!;
|
||||
const validated = validateNodeData(
|
||||
spec,
|
||||
node.data,
|
||||
(msg) => validationErrors.push({ message: `[${node.type}] ${msg}` }),
|
||||
);
|
||||
if (validated !== null) node.data = validated;
|
||||
}
|
||||
if (validationErrors.length > 0) {
|
||||
return { ok: false, stage: "validate", errors: validationErrors };
|
||||
}
|
||||
|
||||
return {
|
||||
ok: true,
|
||||
workflow: {
|
||||
nodes,
|
||||
edges,
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
},
|
||||
workflowName,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
// Strip a single leading `await` from `expr`; any other expression is
// returned unchanged.
function unwrapAwait(expr: ts.Expression): ts.Expression {
  if (ts.isAwaitExpression(expr)) {
    return expr.expression;
  }
  return expr;
}
|
||||
|
||||
// Parse the optional second argument of `addTyped`, expected to be
// `{ position: [x, y] }` with numeric coordinates.
// Returns `undefined` without reporting anything when the argument is
// absent or evaluates to `undefined` / `null`; any other shape is reported
// through `addError` and also yields `undefined`.
function extractPositionArg(
  arg: ts.Expression | undefined,
  addError: (n: ts.Node, m: string) => void,
): { x: number; y: number } | undefined {
  if (!arg) {
    return undefined;
  }

  const parsed = literalToJs(arg, addError);
  if (parsed === undefined || parsed === null) {
    return undefined;
  }

  if (typeof parsed === "object" && !Array.isArray(parsed)) {
    const candidate = (parsed as Record<string, unknown>)["position"];
    if (Array.isArray(candidate) && candidate.length === 2) {
      const [x, y] = candidate as unknown[];
      if (typeof x === "number" && typeof y === "number") {
        return { x, y };
      }
    }
  }

  addError(arg, "Optional second arg must be `{ position: [x, y] }`.");
  return undefined;
}
|
||||
|
||||
// Convert an expression to a plain JS value. Accepts: string, number,
// boolean, null, undefined (→ undefined), array/object literals of the
// same. Rejects any expression with runtime semantics (identifiers other
// than `true/false/null/undefined`, function calls, arrow fns, etc.).
//
// Every rejection is reported through `addError` and yields `undefined`.
// Inside object literals a property whose value is `undefined` — either
// written explicitly or because its conversion failed (already reported) —
// is omitted from the result, so callers never see `key: undefined`.
function literalToJs(
  expr: ts.Expression,
  addError: (n: ts.Node, m: string) => void,
): unknown {
  if (ts.isStringLiteral(expr) || ts.isNoSubstitutionTemplateLiteral(expr)) {
    return expr.text;
  }
  if (ts.isNumericLiteral(expr)) return Number(expr.text);
  if (expr.kind === ts.SyntaxKind.TrueKeyword) return true;
  if (expr.kind === ts.SyntaxKind.FalseKeyword) return false;
  if (expr.kind === ts.SyntaxKind.NullKeyword) return null;
  if (ts.isIdentifier(expr) && expr.text === "undefined") return undefined;

  if (ts.isPrefixUnaryExpression(expr)) {
    // Only `-<number>` and `+<number>` are accepted.
    if (expr.operator === ts.SyntaxKind.MinusToken) {
      const inner = literalToJs(expr.operand, addError);
      if (typeof inner === "number") return -inner;
    }
    if (expr.operator === ts.SyntaxKind.PlusToken) {
      const inner = literalToJs(expr.operand, addError);
      if (typeof inner === "number") return inner;
    }
    addError(expr, "Unsupported unary operator; only numeric negation is allowed.");
    return undefined;
  }

  if (ts.isArrayLiteralExpression(expr)) {
    const out: unknown[] = [];
    for (const el of expr.elements) {
      if (el.kind === ts.SyntaxKind.OmittedExpression) {
        addError(el, "Sparse arrays are not allowed.");
        return undefined;
      }
      if (ts.isSpreadElement(el)) {
        addError(el, "Spread elements are not allowed in array literals.");
        return undefined;
      }
      const v = literalToJs(el, addError);
      // A non-identifier element that produced `undefined` failed to
      // convert (error already reported): abandon the whole array.
      if (v === undefined && el.kind !== ts.SyntaxKind.Identifier) {
        return undefined;
      }
      out.push(v);
    }
    return out;
  }

  if (ts.isObjectLiteralExpression(expr)) {
    const out: Record<string, unknown> = {};
    for (const prop of expr.properties) {
      if (!ts.isPropertyAssignment(prop)) {
        addError(prop, "Only plain `key: value` properties are allowed (no methods, shorthand, spread, or computed keys).");
        return undefined;
      }
      let key: string;
      if (ts.isIdentifier(prop.name) || ts.isStringLiteral(prop.name)) {
        key = prop.name.text;
      } else {
        addError(prop.name, "Property keys must be identifiers or string literals.");
        return undefined;
      }
      const val = literalToJs(prop.initializer, addError);
      if (val === undefined) {
        // Explicit `undefined` is treated as omission; a failed conversion
        // has already been reported, so the key is dropped as well and the
        // remaining properties are still checked.
        continue;
      }
      // Define an own data property rather than assigning: plain assignment
      // to the key `__proto__` would replace the object's prototype instead
      // of storing the value, hiding it from `Object.keys`-based checks.
      Object.defineProperty(out, key, {
        value: val,
        writable: true,
        enumerable: true,
        configurable: true,
      });
    }
    return out;
  }

  if (ts.isTemplateExpression(expr)) {
    addError(expr, "Template literals with interpolation are not allowed — use plain strings.");
    return undefined;
  }

  addError(expr, `Unsupported expression (${ts.SyntaxKind[expr.kind]}). Only literals are allowed in data positions.`);
  return undefined;
}
|
||||
|
||||
// Spec-driven validation, mirrors the shape of
// `sdk/python/src/dograh_sdk/_validation.py` but lightweight — applies
// defaults for missing optionals, catches unknown keys, enforces `options`
// membership, and type-shapes the scalar and `fixed_collection` cases.
//
// Returns the normalised data object, or `null` after reporting the first
// problem found through `addError`.
//
// A field counts as "provided" only when it is an own property of `data`
// with a value other than `undefined`. An explicit `undefined` is handled
// like a missing field (default applied / required check), matching how
// `fixed_collection` rows treat `undefined` sub-fields, and inherited
// properties such as `constructor` are never mistaken for user input.
function validateNodeData(
  spec: NodeSpec,
  data: Record<string, unknown>,
  addError: (message: string) => void,
): Record<string, unknown> | null {
  const hasOwn = (obj: Record<string, unknown>, key: string): boolean =>
    Object.prototype.hasOwnProperty.call(obj, key);

  const out: Record<string, unknown> = {};
  const known = new Map<string, PropertySpec>();
  for (const p of spec.properties ?? []) known.set(p.name, p);

  // Reject keys the spec does not declare.
  for (const key of Object.keys(data)) {
    if (!known.has(key)) {
      addError(`Unknown field: \`${key}\`.`);
      return null;
    }
  }

  // Copy provided values, fall back to non-null defaults, and flag missing
  // required fields. Optional fields without a default are simply left out.
  for (const [key, prop] of known) {
    if (hasOwn(data, key) && data[key] !== undefined) {
      out[key] = data[key];
    } else if (prop.default !== undefined && prop.default !== null) {
      out[key] = prop.default;
    } else if (prop.required) {
      addError(`Missing required field: \`${key}\`.`);
      return null;
    }
  }

  // Shape-check everything that ended up in the output, defaults included.
  for (const [key, prop] of known) {
    if (!hasOwn(out, key)) continue;
    const err = checkPropertyShape(prop, out[key]);
    if (err) {
      addError(`Field \`${key}\`: ${err}`);
      return null;
    }
  }

  return out;
}
|
||||
|
||||
// Check that `value` has the runtime shape implied by `prop.type`.
// Returns `null` when acceptable, otherwise a human-readable reason.
// `fixed_collection` recurses into each row using `prop.properties`;
// property types not listed here are accepted as-is.
function checkPropertyShape(prop: PropertySpec, value: unknown): string | null {
  const kind = prop.type;

  if (
    kind === "string" ||
    kind === "mention_textarea" ||
    kind === "url" ||
    kind === "recording_ref" ||
    kind === "credential_ref"
  ) {
    return typeof value === "string" ? null : `expected string, got ${jsTypeOf(value)}.`;
  }

  if (kind === "number") {
    return typeof value === "number" ? null : `expected number, got ${jsTypeOf(value)}.`;
  }

  if (kind === "boolean") {
    return typeof value === "boolean" ? null : `expected boolean, got ${jsTypeOf(value)}.`;
  }

  if (kind === "tool_refs" || kind === "document_refs" || kind === "multi_options") {
    if (!Array.isArray(value)) {
      return `expected array, got ${jsTypeOf(value)}.`;
    }
    for (const item of value) {
      if (kind === "multi_options") {
        // Each selected value must be one of the declared options.
        if (!isInOptions(prop, item)) {
          return `value \`${JSON.stringify(item)}\` is not in the allowed options.`;
        }
      } else if (typeof item !== "string") {
        return `array elements must be strings.`;
      }
    }
    return null;
  }

  if (kind === "options") {
    return isInOptions(prop, value)
      ? null
      : `value \`${JSON.stringify(value)}\` is not in the allowed options.`;
  }

  if (kind === "json") {
    const isPlainObject = typeof value === "object" && value !== null && !Array.isArray(value);
    return isPlainObject ? null : `expected JSON object, got ${jsTypeOf(value)}.`;
  }

  if (kind === "fixed_collection") {
    if (!Array.isArray(value)) {
      return `expected array of rows, got ${jsTypeOf(value)}.`;
    }
    const columns = prop.properties ?? [];
    for (const [i, row] of value.entries()) {
      if (typeof row !== "object" || row === null || Array.isArray(row)) {
        return `row ${i}: expected object, got ${jsTypeOf(row)}.`;
      }
      const cells = row as Record<string, unknown>;
      for (const sub of columns) {
        const subVal = cells[sub.name];
        if (subVal === undefined) {
          // A missing cell is only an error when it is required and the
          // spec offers no usable default for it.
          const hasDefault = sub.default !== undefined && sub.default !== null;
          if (sub.required && !hasDefault) {
            return `row ${i}: missing required field \`${sub.name}\`.`;
          }
          continue;
        }
        const subErr = checkPropertyShape(sub, subVal);
        if (subErr) {
          return `row ${i}, \`${sub.name}\`: ${subErr}`;
        }
      }
    }
    return null;
  }

  // Unknown types pass — forward compat.
  return null;
}
|
||||
|
||||
// True when `value` strictly equals the `value` of one of `prop.options`.
// A property that declares no options accepts every value.
function isInOptions(prop: PropertySpec, value: unknown): boolean {
  const allowed = prop.options;
  if (!allowed) {
    return true;
  }
  for (const option of allowed) {
    if (option.value === value) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Like `typeof`, but distinguishes `null` ("null") and arrays ("array")
// from other objects. Used to build error messages.
function jsTypeOf(v: unknown): string {
  if (v === null) {
    return "null";
  }
  return Array.isArray(v) ? "array" : typeof v;
}
|
||||
57
api/mcp_server/ts_validator/src/types.ts
Normal file
57
api/mcp_server/ts_validator/src/types.ts
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
// Shared shapes used by both generate and parse. Mirror the `ReactFlowDTO`
|
||||
// wire format on the Python side (`api/services/workflow/dto.py`) and the
|
||||
// node-spec JSON served by `/api/v1/node-types` / dumped by
|
||||
// `python -m api.services.workflow.node_specs`.
|
||||
|
||||
/** One selectable choice of an `options` / `multi_options` property. */
export interface PropertyOption {
  /** Value compared (strictly) against what the user supplied. */
  value: string | number | boolean;
  label: string;
}

/** Description of a single field on a node type. */
export interface PropertySpec {
  name: string;
  /** Field kind, e.g. "string", "number", "options", "fixed_collection". */
  type: string;
  required?: boolean;
  /** Used when the field is absent; `undefined` / `null` mean "no default". */
  default?: unknown;
  /** Allowed values; when absent, any value is accepted. */
  options?: Array<PropertyOption>;
  /** Per-row sub-fields of a `fixed_collection` property. */
  properties?: Array<PropertySpec>;
}

/** A registered node type together with its fields. */
export interface NodeSpec {
  name: string;
  properties: Array<PropertySpec>;
}

/** A node in the wire-format workflow graph. */
export interface WireNode {
  id: string;
  type: string;
  position: { x: number; y: number };
  data: Record<string, unknown>;
}

/** A directed edge between two nodes, referenced by node id. */
export interface WireEdge {
  id: string;
  source: string;
  target: string;
  data: Record<string, unknown>;
}

/** Complete wire-format workflow: nodes, edges, and the canvas viewport. */
export interface WireWorkflow {
  nodes: Array<WireNode>;
  edges: Array<WireEdge>;
  viewport: { x: number; y: number; zoom: number };
}

/** A single problem; `line` / `column` are present when a source location is known. */
export interface ParseErrorItem {
  message: string;
  line?: number;
  column?: number;
}

/** Outcome of code generation: the emitted code, or the errors that prevented it. */
export type GenerateResult =
  | { ok: true; code: string }
  | { ok: false; errors: Array<ParseErrorItem> };

/**
 * Outcome of parsing: the workflow and its name, or the errors together
 * with the stage ("parse" or "validate") that produced them.
 */
export type ParseResult =
  | { ok: true; workflow: WireWorkflow; workflowName: string }
  | { ok: false; stage: "parse" | "validate"; errors: Array<ParseErrorItem> };
|
||||
14
api/mcp_server/ts_validator/tsconfig.json
Normal file
14
api/mcp_server/ts_validator/tsconfig.json
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"allowImportingTsExtensions": true,
|
||||
"skipLibCheck": true,
|
||||
"noEmit": true,
|
||||
"isolatedModules": true
|
||||
},
|
||||
"include": ["src/**/*.ts"]
|
||||
}
|
||||
|
|
@ -6,7 +6,7 @@ testpaths = tests
|
|||
python_files = test_*.py *_test.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts = -v --tb=short -s
|
||||
addopts = -v --tb=short -s --import-mode=importlib
|
||||
markers =
|
||||
asyncio: mark test as an async test
|
||||
slow: mark test as slow running
|
||||
|
|
@ -4,4 +4,6 @@ pytest==8.3.5
|
|||
pytest-asyncio==0.26.0
|
||||
pre-commit==4.2.0
|
||||
watchfiles==1.1.0
|
||||
python-dotenv==1.2.1
|
||||
python-dotenv==1.2.1
|
||||
datamodel-code-generator==0.56.1
|
||||
-e ./sdk/python
|
||||
|
|
|
|||
|
|
@ -13,10 +13,8 @@ python-multipart==0.0.20
|
|||
sentry-sdk[fastapi]==2.38.0
|
||||
sqlalchemy[asyncio]==2.0.43
|
||||
msgpack==1.1.2
|
||||
docling[rapidocr]==2.68.0
|
||||
pgvector==0.4.2
|
||||
bcrypt==5.0.0
|
||||
email-validator==2.3.0
|
||||
posthog==7.11.1
|
||||
fastmcp==3.2.4
|
||||
rank-bm25==0.2.2
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from pydantic import BaseModel
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import WebhookCredentialType
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
router = APIRouter(prefix="/credentials")
|
||||
|
|
@ -107,7 +108,13 @@ def build_credential_response(credential) -> CredentialResponse:
|
|||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get(
|
||||
"/",
|
||||
**sdk_expose(
|
||||
method="list_credentials",
|
||||
description="List webhook credentials available to the authenticated organization.",
|
||||
),
|
||||
)
|
||||
async def list_credentials(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> List[CredentialResponse]:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from api.schemas.knowledge_base import (
|
|||
DocumentUploadResponseSchema,
|
||||
ProcessDocumentRequestSchema,
|
||||
)
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.storage import storage_fs
|
||||
|
|
@ -135,6 +136,7 @@ async def process_document(
|
|||
document.id,
|
||||
request.s3_key,
|
||||
user.selected_organization_id,
|
||||
str(user.provider_id),
|
||||
128, # max_tokens (default)
|
||||
request.retrieval_mode,
|
||||
)
|
||||
|
|
@ -190,6 +192,10 @@ async def process_document(
|
|||
"/documents",
|
||||
response_model=DocumentListResponseSchema,
|
||||
summary="List documents",
|
||||
**sdk_expose(
|
||||
method="list_documents",
|
||||
description="List knowledge base documents available to the authenticated organization.",
|
||||
),
|
||||
)
|
||||
async def list_documents(
|
||||
status: Annotated[
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from api.routes.credentials import router as credentials_router
|
|||
from api.routes.integration import router as integration_router
|
||||
from api.routes.knowledge_base import router as knowledge_base_router
|
||||
from api.routes.looptalk import router as looptalk_router
|
||||
from api.routes.node_types import router as node_types_router
|
||||
from api.routes.organization import router as organization_router
|
||||
from api.routes.organization_usage import router as organization_usage_router
|
||||
from api.routes.public_agent import router as public_agent_router
|
||||
|
|
@ -54,6 +55,7 @@ router.include_router(workflow_embed_router)
|
|||
router.include_router(knowledge_base_router)
|
||||
router.include_router(workflow_recording_router)
|
||||
router.include_router(auth_router)
|
||||
router.include_router(node_types_router)
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
|
|
|
|||
67
api/routes/node_types.py
Normal file
67
api/routes/node_types.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""API for the node-spec catalog.
|
||||
|
||||
Exposes the registered NodeSpecs (one per node type) so frontend renderers
|
||||
and the LLM SDK can build forms / typed constructors from a single source
|
||||
of truth.
|
||||
|
||||
Endpoints:
|
||||
GET /node-types → list every registered NodeSpec
|
||||
GET /node-types/{name} → single NodeSpec by name
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db.models import UserModel
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.workflow.node_specs import (
|
||||
SPEC_VERSION,
|
||||
NodeSpec,
|
||||
all_specs,
|
||||
get_spec,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/node-types")
|
||||
|
||||
|
||||
# Response body of the node-type listing endpoint.
class NodeTypesResponse(BaseModel):
    # Version of the spec catalog the listed node types belong to.
    spec_version: str
    # Every registered node spec.
    node_types: list[NodeSpec]
|
||||
|
||||
|
||||
@router.get(
    "",
    response_model=NodeTypesResponse,
    **sdk_expose(
        method="list_node_types",
        description="List every registered node type with its spec. Pinned to spec_version.",
    ),
)
async def list_node_types(
    _user: UserModel = Depends(get_user),
) -> NodeTypesResponse:
    """List every registered NodeSpec.

    SDK clients should pin to `spec_version` and warn if the server reports
    a higher version than what they were generated against.
    """
    # `_user` is not read here; the `get_user` dependency still runs for
    # every request (presumably enforcing authentication — confirm in
    # api.services.auth.depends).
    catalog = all_specs()
    return NodeTypesResponse(spec_version=SPEC_VERSION, node_types=catalog)
|
||||
|
||||
|
||||
@router.get(
    "/{name}",
    response_model=NodeSpec,
    **sdk_expose(
        method="get_node_type",
        description="Fetch a single node spec by name.",
    ),
)
async def get_node_type(
    name: str,
    _user: UserModel = Depends(get_user),
) -> NodeSpec:
    # Return the spec registered under `name`; an unknown name is answered
    # with HTTP 404. `_user` is unused beyond triggering the `get_user`
    # dependency.
    found = get_spec(name)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Unknown node type: {name!r}")
|
||||
|
|
@ -29,6 +29,7 @@ from api.db.workflow_client import WorkflowClient
|
|||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.enums import CallType, OrganizationConfigurationKey, WorkflowRunState
|
||||
from api.errors.telephony_errors import TelephonyError
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
|
||||
|
|
@ -139,7 +140,13 @@ class StatusCallbackRequest(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@router.post("/initiate-call")
|
||||
@router.post(
|
||||
"/initiate-call",
|
||||
**sdk_expose(
|
||||
method="test_phone_call",
|
||||
description="Place a test call from a workflow to a phone number.",
|
||||
),
|
||||
)
|
||||
async def initiate_call(
|
||||
request: InitiateCallRequest, user: UserModel = Depends(get_user)
|
||||
):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent, ToolCategory, ToolStatus
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.posthog_client import capture_event
|
||||
|
||||
|
|
@ -276,7 +277,13 @@ def validate_status(status: str) -> None:
|
|||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get(
|
||||
"/",
|
||||
**sdk_expose(
|
||||
method="list_tools",
|
||||
description="List tools available to the authenticated organization.",
|
||||
),
|
||||
)
|
||||
async def list_tools(
|
||||
status: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from api.db.models import UserModel
|
|||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.campaign.report import generate_workflow_report_csv
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
|
|
@ -27,7 +28,7 @@ from api.services.configuration.resolve import resolve_effective_config
|
|||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.storage import storage_fs
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
|
@ -453,7 +454,13 @@ async def get_workflow_count(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/fetch")
|
||||
@router.get(
|
||||
"/fetch",
|
||||
**sdk_expose(
|
||||
method="list_workflows",
|
||||
description="List all workflows in the authenticated organization.",
|
||||
),
|
||||
)
|
||||
async def get_workflows(
|
||||
user: UserModel = Depends(get_user),
|
||||
status: Optional[str] = Query(
|
||||
|
|
@ -499,7 +506,13 @@ async def get_workflows(
|
|||
]
|
||||
|
||||
|
||||
@router.get("/fetch/{workflow_id}")
|
||||
@router.get(
|
||||
"/fetch/{workflow_id}",
|
||||
**sdk_expose(
|
||||
method="get_workflow",
|
||||
description="Get a single workflow by ID (returns draft if one exists, else published).",
|
||||
),
|
||||
)
|
||||
async def get_workflow(
|
||||
workflow_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
|
|
@ -701,7 +714,13 @@ async def update_workflow_status(
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{workflow_id}")
|
||||
@router.put(
|
||||
"/{workflow_id}",
|
||||
**sdk_expose(
|
||||
method="update_workflow",
|
||||
description="Update a workflow's name and/or definition. Saves as a new draft.",
|
||||
),
|
||||
)
|
||||
async def update_workflow(
|
||||
workflow_id: int,
|
||||
request: UpdateWorkflowRequest,
|
||||
|
|
@ -721,8 +740,10 @@ async def update_workflow(
|
|||
HTTPException: If the workflow is not found or if there's a database error
|
||||
"""
|
||||
try:
|
||||
# Restore real API keys where the incoming definition has masked placeholders
|
||||
workflow_definition = request.workflow_definition
|
||||
# Strip UI runtime-only fields (invalid, validationMessage, etc.) from
|
||||
# node.data / edge.data before anything touches the DB — the UI sends
|
||||
# nodes wholesale from the React Flow store, which carries those.
|
||||
workflow_definition = sanitize_workflow_definition(request.workflow_definition)
|
||||
if workflow_definition:
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from api.schemas.workflow_recording import (
|
|||
RecordingUpdateRequestSchema,
|
||||
RecordingUploadResponseSchema,
|
||||
)
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.storage import storage_fs
|
||||
|
|
@ -165,6 +166,10 @@ async def create_recordings(
|
|||
"/",
|
||||
response_model=RecordingListResponseSchema,
|
||||
summary="List recordings",
|
||||
**sdk_expose(
|
||||
method="list_recordings",
|
||||
description="List workflow recordings available to the authenticated organization.",
|
||||
),
|
||||
)
|
||||
async def list_recordings(
|
||||
workflow_id: Annotated[
|
||||
|
|
|
|||
38
api/sdk_expose.py
Normal file
38
api/sdk_expose.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""Opt-in marker for exposing a FastAPI route through the Dograh SDK.
|
||||
|
||||
The generated SDK client (`sdk/python/src/dograh_sdk/_generated_client.py`
|
||||
and the TypeScript equivalent) is built by walking the backend's OpenAPI
|
||||
schema and picking up any operation tagged with `x-sdk-method`. That
|
||||
means `generate_sdk.sh` stays in sync with the real HTTP paths — no more
|
||||
hand-typed URL strings drifting out of date.
|
||||
|
||||
Usage:
|
||||
|
||||
from api.sdk_expose import sdk_expose
|
||||
|
||||
@router.post("/initiate-call", **sdk_expose(
|
||||
method="test_phone_call",
|
||||
description="Place a test call from a workflow to a phone number.",
|
||||
))
|
||||
async def initiate_call(...): ...
|
||||
|
||||
Anything not wrapped in `sdk_expose` is invisible to the SDK — deliberate,
|
||||
so the SDK surface stays small and auditable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def sdk_expose(*, method: str, description: str = "") -> dict[str, Any]:
    """Build the FastAPI route kwargs that mark an operation for SDK codegen.

    Args:
        method: Stored under the ``x-sdk-method`` OpenAPI extension key; the
            codegen uses it as the SDK method name (snake_case in Python,
            camelCase in TypeScript).
        description: Stored under ``x-sdk-description`` and emitted as the
            method docstring. The key is omitted when the string is empty.

    Returns:
        A ``{"openapi_extra": {...}}`` dict meant to be ``**``-splatted into
        a route decorator.
    """
    if description:
        openapi_extra: dict[str, Any] = {
            "x-sdk-method": method,
            "x-sdk-description": description,
        }
    else:
        openapi_extra = {"x-sdk-method": method}
    return {"openapi_extra": openapi_extra}
|
||||
|
|
@ -1,35 +1,22 @@
|
|||
"""OpenAI embedding service.
|
||||
|
||||
This module provides document processing capabilities using:
|
||||
- OpenAI's text-embedding-3-small for embeddings (1536 dimensions)
|
||||
- Docling for document conversion and chunking
|
||||
- pgvector for vector similarity search
|
||||
Embeds text and performs vector similarity search via the local database.
|
||||
Document conversion and chunking now live in the Model Proxy Service (MPS);
|
||||
this file no longer pulls docling/transformers.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL_ID = "text-embedding-3-small"
|
||||
EMBEDDING_DIMENSION = 1536 # Dimension for text-embedding-3-small
|
||||
|
||||
# For chunking, we'll use the same tokenizer as SentenceTransformer
|
||||
# since OpenAI uses similar tokenization
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
class EmbeddingAPIKeyNotConfiguredError(Exception):
|
||||
"""Raised when OpenAI API key is not configured for embeddings."""
|
||||
|
|
@ -49,24 +36,20 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
db_client: DBClient,
|
||||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
max_tokens: int = 512,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
Args:
|
||||
db_client: Database client for storing documents and chunks
|
||||
db_client: Database client for vector similarity search.
|
||||
api_key: OpenAI API key. If not provided, the client will not be
|
||||
initialized and operations will fail with a clear error.
|
||||
model_id: OpenAI embedding model ID (default: text-embedding-3-small)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 512)
|
||||
base_url: Optional base URL for the API (e.g. for OpenRouter)
|
||||
initialized and operations will fail with a clear error.
|
||||
model_id: OpenAI embedding model ID (default: text-embedding-3-small).
|
||||
base_url: Optional base URL for the API (e.g. for OpenRouter).
|
||||
"""
|
||||
self.db = db_client
|
||||
self.model_id = model_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# Only initialize OpenAI client if API key is provided
|
||||
self._api_key_configured = bool(api_key)
|
||||
if self._api_key_configured:
|
||||
client_kwargs = {"api_key": api_key}
|
||||
|
|
@ -81,35 +64,6 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
"Operations will fail until API key is configured in Model Configurations."
|
||||
)
|
||||
|
||||
# Initialize tokenizer for chunking
|
||||
# We use a HuggingFace tokenizer for consistent chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer for chunking: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
try:
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
TOKENIZER_MODEL,
|
||||
local_files_only=True,
|
||||
),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Loaded tokenizer from cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tokenizer not in cache, downloading: {e}")
|
||||
self.tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
logger.info("Tokenizer downloaded and cached")
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
self.chunker = HybridChunker(tokenizer=self.tokenizer)
|
||||
|
||||
# Initialize document converter
|
||||
self.converter = DocumentConverter()
|
||||
|
||||
def get_model_id(self) -> str:
|
||||
"""Return the model identifier."""
|
||||
return self.model_id
|
||||
|
|
@ -126,28 +80,17 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
async def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a batch of texts using OpenAI API.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (each vector is a list of floats)
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured.
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
try:
|
||||
# OpenAI API call
|
||||
response = await self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model_id,
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
embeddings = [item.embedding for item in response.data]
|
||||
return embeddings
|
||||
|
||||
return [item.embedding for item in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating OpenAI embeddings: {e}")
|
||||
raise
|
||||
|
|
@ -155,14 +98,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
async def embed_query(self, query: str) -> List[float]:
|
||||
"""Embed a single query text using OpenAI API.
|
||||
|
||||
Args:
|
||||
query: Query text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured.
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
embeddings = await self.embed_texts([query])
|
||||
|
|
@ -177,201 +114,17 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar chunks using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
organization_id: Organization ID for scoping
|
||||
limit: Maximum number of results to return
|
||||
document_uuids: Optional list of document UUIDs to filter by
|
||||
|
||||
Returns:
|
||||
List of dictionaries with chunk data and similarity scores
|
||||
|
||||
Raises:
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured
|
||||
EmbeddingAPIKeyNotConfiguredError: If API key is not configured.
|
||||
"""
|
||||
self._ensure_api_key_configured()
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embed_query(query)
|
||||
|
||||
# Perform vector similarity search
|
||||
results = await self.db.search_similar_chunks(
|
||||
return await self.db.search_similar_chunks(
|
||||
query_embedding=query_embedding,
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
document_uuids=document_uuids,
|
||||
embedding_model=self.model_id,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def process_document(
|
||||
self,
|
||||
file_path: str,
|
||||
organization_id: int,
|
||||
created_by: int,
|
||||
custom_metadata: dict = None,
|
||||
):
|
||||
"""Process a document: convert, chunk, embed, and store in database.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
organization_id: Organization ID for scoping
|
||||
created_by: User ID who uploaded the document
|
||||
custom_metadata: Optional custom metadata dictionary
|
||||
|
||||
Returns:
|
||||
The created document record
|
||||
"""
|
||||
try:
|
||||
# Extract file metadata
|
||||
filename = Path(file_path).name
|
||||
file_hash = self.db.compute_file_hash(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
mime_type = self.db.get_mime_type(file_path)
|
||||
|
||||
# Check if document already exists
|
||||
existing_doc = await self.db.get_document_by_hash(
|
||||
file_hash, organization_id
|
||||
)
|
||||
if existing_doc:
|
||||
logger.info(f"Document already exists: {filename} (hash: {file_hash})")
|
||||
return existing_doc
|
||||
|
||||
# Create document record
|
||||
doc_record = await self.db.create_document(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
filename=filename,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Processing document with OpenAI embeddings: {filename}")
|
||||
|
||||
# Update status to processing
|
||||
await self.db.update_document_status(doc_record.id, "processing")
|
||||
|
||||
# Step 1: Convert document using docling
|
||||
logger.info("Converting document with docling...")
|
||||
conversion_result = self.converter.convert(file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
# Store docling metadata
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Step 2: Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={self.max_tokens}...")
|
||||
chunks = list(self.chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Step 3: Process each chunk
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Get chunk text
|
||||
chunk_text = chunk.text
|
||||
|
||||
# Get contextualized text
|
||||
contextualized_text = self.chunker.contextualize(chunk=chunk)
|
||||
|
||||
# Calculate token count
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
self.tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
# Prepare chunk metadata
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
# Create chunk record (without embedding yet)
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=doc_record.id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=self.model_id,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens} tokens")
|
||||
|
||||
# Step 4: Generate embeddings using OpenAI API
|
||||
logger.info(f"Generating embeddings using OpenAI ({self.model_id})...")
|
||||
embeddings = await self.embed_texts(chunk_texts)
|
||||
|
||||
# Step 5: Attach embeddings to chunk records
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 6: Save all chunks in batch
|
||||
logger.info("Storing chunks in database...")
|
||||
await self.db.create_chunks_batch(chunk_records)
|
||||
|
||||
# Update document status to completed
|
||||
await self.db.update_document_status(
|
||||
doc_record.id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed document: {filename}")
|
||||
logger.info(f" - Total chunks: {total_chunks}")
|
||||
logger.info(f" - Embedding model: {self.model_id}")
|
||||
logger.info(f" - Document ID: {doc_record.id}")
|
||||
logger.info(f" - Document UUID: {doc_record.document_uuid}")
|
||||
|
||||
return doc_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document with OpenAI: {e}")
|
||||
|
||||
# Update document status to failed if it exists
|
||||
if "doc_record" in locals():
|
||||
await self.db.update_document_status(
|
||||
doc_record.id, "failed", error_message=str(e)
|
||||
)
|
||||
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -487,6 +487,71 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def process_document(
    self,
    file_path: str,
    filename: str,
    content_type: str,
    retrieval_mode: str = "chunked",
    max_tokens: int = 128,
    chunk_overlap_tokens: int = 0,
    merge_peers: bool = True,
    tokenizer_model: Optional[str] = None,
    correlation_id: Optional[str] = None,
    organization_id: Optional[int] = None,
    created_by: Optional[str] = None,
) -> dict:
    """Convert + chunk a document via MPS /document/process.

    The file at ``file_path`` is uploaded as a multipart ``file`` part;
    the chunking options are sent as string form fields alongside it.

    Args:
        file_path: Local path of the document to upload.
        filename: Filename reported in the multipart part.
        content_type: MIME type reported in the multipart part.
        retrieval_mode: Sent as the ``retrieval_mode`` form field.
        max_tokens: Sent (stringified) as the ``max_tokens`` form field.
        chunk_overlap_tokens: Sent (stringified) as ``chunk_overlap_tokens``.
        merge_peers: Sent as ``"true"`` / ``"false"``.
        tokenizer_model: Sent only when not ``None``.
        correlation_id: Sent only when truthy.
        organization_id: Forwarded to ``self._get_headers``.
        created_by: Forwarded to ``self._get_headers``.

    Returns:
        A dict matching DocumentProcessResponse in MPS:
            {
                "mode": "chunked" | "full_document",
                "docling_metadata": {...},
                "full_text": str | None,  # populated only in full_document mode
                "chunks": [...],          # populated only in chunked mode
            }

    Raises:
        httpx.HTTPStatusError: If MPS answers with any status other than 200.

    Timeout is 300s to match the ALB idle_timeout configured in
    infrastructure/mps/main.tf.
    """
    # Function-scope imports keep this method self-contained.
    import asyncio
    from pathlib import Path

    data = {
        "retrieval_mode": retrieval_mode,
        "max_tokens": str(max_tokens),
        "chunk_overlap_tokens": str(chunk_overlap_tokens),
        "merge_peers": str(merge_peers).lower(),
    }
    if tokenizer_model is not None:
        data["tokenizer_model"] = tokenizer_model
    if correlation_id:
        data["correlation_id"] = correlation_id

    headers = self._get_headers(organization_id, created_by)
    # Remove JSON content-type so httpx sets the correct multipart boundary.
    headers.pop("Content-Type", None)

    # Read the upload in a worker thread: a synchronous read of a large
    # document inside this coroutine would stall the event loop. Doing it
    # before the client is created also avoids holding a client open while
    # waiting on disk.
    file_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
    files = {"file": (filename, file_bytes, content_type)}

    async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
        response = await client.post(
            f"{self.base_url}/api/v1/document/process",
            files=files,
            data=data,
            headers=headers,
        )

        if response.status_code == 200:
            return response.json()

        logger.error(
            f"Failed to process document: {response.status_code} - {response.text}"
        )
        raise httpx.HTTPStatusError(
            f"Failed to process document: {response.text}",
            request=response.request,
            response=response,
        )
|
||||
|
||||
async def call_workflow_api(
|
||||
self,
|
||||
call_type: str,
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ async def play_audio(
|
|||
queue_frame: Callable[[Frame], Awaitable[None]],
|
||||
transcript: Optional[str] = None,
|
||||
append_to_context: bool = False,
|
||||
persist_to_logs: bool = False,
|
||||
) -> None:
|
||||
"""Play raw PCM-16 audio once.
|
||||
|
||||
|
|
@ -115,6 +116,8 @@ async def play_audio(
|
|||
transcript: Optional transcript of the recording.
|
||||
append_to_context: Whether the transcript should be appended to
|
||||
the LLM assistant context. Defaults to False.
|
||||
persist_to_logs: Whether the transcript should be written to the
|
||||
app-level logs buffer by observers. Defaults to False.
|
||||
"""
|
||||
context_id = str(uuid.uuid4())
|
||||
await queue_frame(TTSStartedFrame(context_id=context_id))
|
||||
|
|
@ -123,6 +126,7 @@ async def play_audio(
|
|||
text=transcript, aggregated_by="recording", context_id=context_id
|
||||
)
|
||||
tts_text.append_to_context = append_to_context
|
||||
tts_text.persist_to_logs = persist_to_logs
|
||||
await queue_frame(tts_text)
|
||||
await queue_frame(
|
||||
TTSAudioRawFrame(
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from pipecat.frames.frames import (
|
|||
MetricsFrame,
|
||||
StopFrame,
|
||||
TranscriptionFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSTextFrame,
|
||||
UserMuteStartedFrame,
|
||||
UserMuteStoppedFrame,
|
||||
|
|
@ -230,8 +231,22 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
},
|
||||
}
|
||||
)
|
||||
# Handle engine-queued speech (transition/tool messages) marked for
|
||||
# log persistence. The downstream TTSTextFrame(s) from the TTS service
|
||||
# still stream to WS as normal; we persist the full utterance once here
|
||||
# to avoid word-level log entries from word-timestamp providers.
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
if getattr(frame, "persist_to_logs", False):
|
||||
await self._append_to_buffer(
|
||||
{
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
"payload": {"text": frame.text},
|
||||
}
|
||||
)
|
||||
# Handle bot TTS text - respect pts timing, WebSocket only
|
||||
# Complete turn text is persisted via register_turn_handlers
|
||||
# Complete turn text is persisted via register_turn_handlers,
|
||||
# except for frames explicitly flagged persist_to_logs (e.g. recording
|
||||
# transcripts from play_audio) which bypass the aggregator path.
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
message = {
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
|
|
@ -249,6 +264,9 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
|
||||
await self._ensure_clock_task()
|
||||
await self._clock_queue.put((frame.pts, frame.id, message))
|
||||
elif getattr(frame, "persist_to_logs", False):
|
||||
# No pts + explicit persistence request (recording transcript).
|
||||
await self._send_message(message)
|
||||
else:
|
||||
# No pts, send immediately
|
||||
await self._send_ws(message)
|
||||
|
|
|
|||
|
|
@ -94,6 +94,14 @@ class _OrgRoutingExporter(SpanExporter):
|
|||
org_buckets = {}
|
||||
|
||||
for span in spans:
|
||||
# Drop fastmcp's built-in auto-instrumentation spans
|
||||
# (`tools/call <name>`, etc.) — our `@traced_tool` decorator
|
||||
# in `api/mcp_server/tracing.py` produces the spans we want. Keeping
|
||||
# both would just double every trace.
|
||||
scope = getattr(span, "instrumentation_scope", None)
|
||||
if scope is not None and scope.name == "fastmcp":
|
||||
continue
|
||||
|
||||
org_id = span.attributes.get("dograh.org_id") if span.attributes else None
|
||||
if org_id and str(org_id) in self._org_exporters:
|
||||
org_buckets.setdefault(str(org_id), []).append(span)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||||
|
||||
|
|
@ -42,17 +42,48 @@ class RetryConfigDTO(BaseModel):
|
|||
retry_delay_seconds: int = 5
|
||||
|
||||
|
||||
class NodeDataDTO(BaseModel):
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Per-type node data classes.
|
||||
#
|
||||
# Shared fields are factored out as Pydantic mixins; per-type classes
|
||||
# inherit only the mixins they need so mistyped fields raise at validation
|
||||
# time and downstream consumers get accurate types. `is_start` / `is_end`
|
||||
# live on every variant so the WorkflowGraph can identify boundary nodes
|
||||
# without dispatching on type.
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _NodeDataBase(BaseModel):
|
||||
name: str = Field(..., min_length=1)
|
||||
prompt: Optional[str] = Field(default=None)
|
||||
is_static: bool = False
|
||||
is_start: bool = False
|
||||
is_end: bool = False
|
||||
|
||||
|
||||
class _PromptedNodeDataMixin(BaseModel):
|
||||
prompt: Optional[str] = Field(default=None)
|
||||
is_static: bool = False
|
||||
allow_interrupt: bool = False
|
||||
add_global_prompt: bool = True
|
||||
|
||||
|
||||
class _ExtractionNodeDataMixin(BaseModel):
|
||||
extraction_enabled: bool = False
|
||||
extraction_prompt: Optional[str] = None
|
||||
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
|
||||
add_global_prompt: bool = True
|
||||
|
||||
|
||||
class _ToolDocumentRefsMixin(BaseModel):
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
document_uuids: Optional[List[str]] = None
|
||||
|
||||
|
||||
class StartCallNodeData(
|
||||
_NodeDataBase,
|
||||
_PromptedNodeDataMixin,
|
||||
_ExtractionNodeDataMixin,
|
||||
_ToolDocumentRefsMixin,
|
||||
):
|
||||
is_start: bool = True
|
||||
greeting: Optional[str] = None
|
||||
greeting_type: Optional[str] = None # 'text' or 'audio'
|
||||
greeting_recording_id: Optional[str] = None
|
||||
|
|
@ -61,14 +92,38 @@ class NodeDataDTO(BaseModel):
|
|||
detect_voicemail: bool = False
|
||||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
# Pre-call fetch (start node only)
|
||||
pre_call_fetch_enabled: bool = False
|
||||
pre_call_fetch_url: Optional[str] = None
|
||||
pre_call_fetch_credential_uuid: Optional[str] = None
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
document_uuids: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AgentNodeData(
|
||||
_NodeDataBase,
|
||||
_PromptedNodeDataMixin,
|
||||
_ExtractionNodeDataMixin,
|
||||
_ToolDocumentRefsMixin,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class EndCallNodeData(
|
||||
_NodeDataBase,
|
||||
_PromptedNodeDataMixin,
|
||||
_ExtractionNodeDataMixin,
|
||||
):
|
||||
is_end: bool = True
|
||||
|
||||
|
||||
class GlobalNodeData(_NodeDataBase, _PromptedNodeDataMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TriggerNodeData(_NodeDataBase):
|
||||
trigger_path: Optional[str] = None
|
||||
# Webhook node specific fields
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class WebhookNodeData(_NodeDataBase):
|
||||
enabled: bool = True
|
||||
http_method: Optional[str] = None
|
||||
endpoint_url: Optional[str] = None
|
||||
|
|
@ -76,30 +131,129 @@ class NodeDataDTO(BaseModel):
|
|||
custom_headers: Optional[list[CustomHeaderDTO]] = None
|
||||
payload_template: Optional[dict] = None
|
||||
retry_config: Optional[RetryConfigDTO] = None
|
||||
# QA node specific fields
|
||||
|
||||
|
||||
class QANodeData(_NodeDataBase):
|
||||
qa_enabled: bool = True
|
||||
qa_system_prompt: Optional[str] = None
|
||||
qa_use_workflow_llm: bool = True
|
||||
qa_provider: Optional[str] = None
|
||||
qa_model: Optional[str] = None
|
||||
qa_api_key: Optional[str] = None
|
||||
qa_endpoint: Optional[str] = None
|
||||
qa_system_prompt: Optional[str] = None
|
||||
qa_min_call_duration: int = 15
|
||||
qa_voicemail_calls: bool = False
|
||||
qa_sample_rate: int = 100
|
||||
|
||||
|
||||
class RFNodeDTO(BaseModel):
|
||||
# Union of every per-type data class — useful as a type annotation on
|
||||
# consumers that handle any node data without dispatching on type. Cannot
|
||||
# be called as a constructor; use the per-type class directly.
|
||||
NodeDataDTO = Union[
|
||||
StartCallNodeData,
|
||||
AgentNodeData,
|
||||
EndCallNodeData,
|
||||
GlobalNodeData,
|
||||
TriggerNodeData,
|
||||
WebhookNodeData,
|
||||
QANodeData,
|
||||
]
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Per-type RF nodes.
|
||||
#
|
||||
# RFNodeDTO is a discriminated Union over `type`. Pydantic dispatches to
|
||||
# the right variant when validating wire JSON. Direct instantiation must
|
||||
# use the concrete per-type class (StartCallRFNode, AgentRFNode, ...).
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RFNodeBase(BaseModel):
|
||||
id: str
|
||||
type: NodeType = Field(default=NodeType.agentNode)
|
||||
position: Position
|
||||
data: NodeDataDTO
|
||||
|
||||
|
||||
def _require_prompt(data, type_label: str) -> None:
|
||||
prompt = getattr(data, "prompt", None)
|
||||
if not prompt or len(prompt.strip()) == 0:
|
||||
raise ValueError(f"Prompt is required for {type_label} nodes")
|
||||
|
||||
|
||||
class StartCallRFNode(_RFNodeBase):
|
||||
type: Literal["startCall"] = "startCall"
|
||||
data: StartCallNodeData
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_prompt_required(self):
|
||||
"""Require prompt for all node types except trigger, webhook, and qa."""
|
||||
if self.type not in (NodeType.trigger, NodeType.webhook, NodeType.qa):
|
||||
if not self.data.prompt or len(self.data.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required for non-trigger nodes")
|
||||
def _validate(self):
|
||||
_require_prompt(self.data, "start")
|
||||
return self
|
||||
|
||||
|
||||
class AgentRFNode(_RFNodeBase):
|
||||
type: Literal["agentNode"] = "agentNode"
|
||||
data: AgentNodeData
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate(self):
|
||||
_require_prompt(self.data, "agent")
|
||||
return self
|
||||
|
||||
|
||||
class EndCallRFNode(_RFNodeBase):
|
||||
type: Literal["endCall"] = "endCall"
|
||||
data: EndCallNodeData
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate(self):
|
||||
_require_prompt(self.data, "end")
|
||||
return self
|
||||
|
||||
|
||||
class GlobalRFNode(_RFNodeBase):
|
||||
type: Literal["globalNode"] = "globalNode"
|
||||
data: GlobalNodeData
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate(self):
|
||||
_require_prompt(self.data, "global")
|
||||
return self
|
||||
|
||||
|
||||
class TriggerRFNode(_RFNodeBase):
|
||||
type: Literal["trigger"] = "trigger"
|
||||
data: TriggerNodeData
|
||||
|
||||
|
||||
class WebhookRFNode(_RFNodeBase):
|
||||
type: Literal["webhook"] = "webhook"
|
||||
data: WebhookNodeData
|
||||
|
||||
|
||||
class QARFNode(_RFNodeBase):
|
||||
type: Literal["qa"] = "qa"
|
||||
data: QANodeData
|
||||
|
||||
|
||||
RFNodeDTO = Annotated[
|
||||
Union[
|
||||
StartCallRFNode,
|
||||
AgentRFNode,
|
||||
EndCallRFNode,
|
||||
GlobalRFNode,
|
||||
TriggerRFNode,
|
||||
WebhookRFNode,
|
||||
QARFNode,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Edges
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EdgeDataDTO(BaseModel):
|
||||
label: str = Field(..., min_length=1)
|
||||
condition: str = Field(..., min_length=1)
|
||||
|
|
@ -144,3 +298,60 @@ class ReactFlowDTO(BaseModel):
|
|||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
# Node type → per-type data class. Keeps sanitize_workflow_definition in
|
||||
# step with RFNodeDTO's discriminated union.
|
||||
_NODE_DATA_CLASSES: dict[str, type[BaseModel]] = {
|
||||
NodeType.startNode.value: StartCallNodeData,
|
||||
NodeType.agentNode.value: AgentNodeData,
|
||||
NodeType.endNode.value: EndCallNodeData,
|
||||
NodeType.globalNode.value: GlobalNodeData,
|
||||
NodeType.trigger.value: TriggerNodeData,
|
||||
NodeType.webhook.value: WebhookNodeData,
|
||||
NodeType.qa.value: QANodeData,
|
||||
}
|
||||
|
||||
|
||||
def sanitize_workflow_definition(definition: dict | None) -> dict | None:
    """Return a shallow copy of ``definition`` with node/edge ``data`` scrubbed.

    UI-only runtime state (`invalid`, `validationMessage`, etc.) is dropped
    from each node's and edge's ``data`` so it doesn't leak into persisted
    workflow JSON. Everything outside ``.data`` — top-level keys on the
    definition, nodes, and edges — is carried over untouched.

    This is a stripper, not a validator: no required-field checks or
    model validators run here, so partial drafts still save. A falsy
    ``definition`` (``None`` or empty) is returned as-is.
    """
    if not definition:
        return definition

    cleaned = dict(definition)
    nodes, edges = cleaned.get("nodes"), cleaned.get("edges")
    # Only list-valued "nodes"/"edges" are rewritten; anything else is left alone.
    if isinstance(nodes, list):
        cleaned["nodes"] = [_sanitize_node(item) for item in nodes]
    if isinstance(edges, list):
        cleaned["edges"] = [_sanitize_edge(item) for item in edges]
    return cleaned
|
||||
|
||||
|
||||
def _sanitize_node(node):
|
||||
if not isinstance(node, dict):
|
||||
return node
|
||||
data_cls = _NODE_DATA_CLASSES.get(node.get("type"))
|
||||
raw_data = node.get("data")
|
||||
if not data_cls or not isinstance(raw_data, dict):
|
||||
return node
|
||||
allowed = data_cls.model_fields.keys()
|
||||
return {**node, "data": {k: v for k, v in raw_data.items() if k in allowed}}
|
||||
|
||||
|
||||
def _sanitize_edge(edge):
|
||||
if not isinstance(edge, dict):
|
||||
return edge
|
||||
raw_data = edge.get("data")
|
||||
if not isinstance(raw_data, dict):
|
||||
return edge
|
||||
allowed = EdgeDataDTO.model_fields.keys()
|
||||
return {**edge, "data": {k: v for k, v in raw_data.items() if k in allowed}}
|
||||
|
|
|
|||
105
api/services/workflow/layout.py
Normal file
105
api/services/workflow/layout.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Position reconciliation for LLM-edited workflows.
|
||||
|
||||
`save_workflow` re-parses LLM-authored TypeScript into workflow JSON,
|
||||
but the parser deliberately ignores positions (LLMs place nodes
|
||||
poorly, and the authoring surface stays tighter without coordinates).
|
||||
This module fills them back in by matching the newly-parsed nodes
|
||||
against the previously-stored workflow:
|
||||
|
||||
1. Named match: (type, data.name) — most reliable
|
||||
2. Unnamed match: (type, nth-occurrence-in-order) — best effort
|
||||
3. New nodes: placed adjacent to their first incoming neighbor
|
||||
(src.x + 400, src.y + 200), or (0, 0) if orphan
|
||||
|
||||
The UI has a proper dagre-based re-layout button
|
||||
(`ui/src/app/workflow/[workflowId]/utils/layoutNodes.ts`) users can
|
||||
invoke when they want a clean pass. This module only aims to avoid
|
||||
all-nodes-at-origin after a save.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Fallback coordinates used when a node carries no stored position. Callers
# in this module copy it (`dict(_DEFAULT_POSITION)`) before handing it out.
_DEFAULT_POSITION: dict[str, float] = {"x": 0.0, "y": 0.0}
# Horizontal / vertical offset for newly-introduced nodes relative to
# their first incoming neighbor. Chosen to roughly match the UI layout's
# node spacing without overlapping the neighbor's card.
_NEW_NODE_DX: float = 400.0
_NEW_NODE_DY: float = 200.0
|
||||
|
||||
|
||||
def reconcile_positions(
    new_wf: dict[str, Any],
    previous_wf: dict[str, Any] | None,
) -> dict[str, Any]:
    """Carry node positions over from `previous_wf` into `new_wf`.

    Nodes are matched first by `(type, data.name)`; anything still
    unmatched takes the next unused position among the previous *unnamed*
    nodes of the same type, in list order. Matched nodes receive a copy of
    the stored position. `_place_new_nodes` then runs over the result.

    `new_wf` is modified in place and also returned.
    """
    if previous_wf:
        by_name: dict[tuple[str, str], dict[str, float]] = {}
        by_type: dict[str, list[dict[str, float]]] = {}

        # Index the previous layout: named nodes by (type, name), unnamed
        # nodes as an ordered pool per type.
        for old in previous_wf.get("nodes") or []:
            kind = old.get("type") or ""
            label = ((old.get("data") or {}).get("name") or "").strip()
            where = old.get("position") or dict(_DEFAULT_POSITION)
            if label:
                by_name[(kind, label)] = where
            else:
                by_type.setdefault(kind, []).append(where)

        # Per-type index of the next unnamed position not yet handed out.
        next_slot: dict[str, int] = {}

        for node in new_wf.get("nodes") or []:
            kind = node.get("type") or ""
            label = ((node.get("data") or {}).get("name") or "").strip()

            match = by_name.get((kind, label)) if label else None
            if match is None:
                slot = next_slot.get(kind, 0)
                pool = by_type.get(kind, [])
                if slot < len(pool):
                    match = pool[slot]
                    next_slot[kind] = slot + 1
            if match is not None:
                # Copy so the new workflow never aliases the previous one.
                node["position"] = dict(match)

    _place_new_nodes(new_wf)
    return new_wf
|
||||
|
||||
|
||||
def _place_new_nodes(wf: dict[str, Any]) -> None:
    """Give approximate positions to nodes that are still at the origin.

    A node counts as unplaced when its `position` is missing or has
    neither a non-zero `x` nor a non-zero `y`. Each unplaced node is put
    at an offset of (`_NEW_NODE_DX`, `_NEW_NODE_DY`) from the source of
    the first edge that targets it, provided that source is a node in
    this workflow. Unplaced nodes with no usable incoming edge are
    orphans: they get an explicit (0, 0) position if they had none, so
    every node leaves this function with a `position`.

    Mutates `wf["nodes"]` in place. Nodes are processed in list order, so
    a new node downstream of another new node sees whatever position the
    upstream node has at that moment.
    """
    nodes = wf.get("nodes") or []
    if not nodes:
        return

    # Tolerate records without an id instead of raising KeyError.
    id_to_node = {n.get("id"): n for n in nodes if n.get("id") is not None}

    # Source of the FIRST edge pointing at each target, built in one pass
    # rather than rescanning the edge list for every node.
    first_source: dict[Any, Any] = {}
    for edge in wf.get("edges") or []:
        target = edge.get("target")
        if target is not None:
            first_source.setdefault(target, edge.get("source"))

    for node in nodes:
        pos = node.get("position") or {}
        if pos.get("x") or pos.get("y"):
            continue  # already has a non-origin position

        src_id = first_source.get(node.get("id"))
        src = id_to_node.get(src_id) if src_id else None
        if src is not None:
            src_pos = src.get("position") or _DEFAULT_POSITION
            node["position"] = {
                # `or 0.0` also covers a stored null coordinate.
                "x": float(src_pos.get("x") or 0.0) + _NEW_NODE_DX,
                "y": float(src_pos.get("y") or 0.0) + _NEW_NODE_DY,
            }
        elif not node.get("position"):
            # Orphan with no position at all: pin it to the origin so the
            # node still has coordinates. The UI's re-layout pass can pull
            # it into the graph later.
            node["position"] = dict(_DEFAULT_POSITION)
|
||||
82
api/services/workflow/node_specs/__init__.py
Normal file
82
api/services/workflow/node_specs/__init__.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""Node specification registry.
|
||||
|
||||
Adding a new node type:
|
||||
1. Create a new module under this package, define a `SPEC: NodeSpec`.
|
||||
2. Add it to the imports + REGISTRY below.
|
||||
3. The Pydantic discriminated-union variant in dto.py must use the same
|
||||
`name` value as `SPEC.name`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
SPEC_VERSION,
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
evaluate_display_options,
|
||||
)
|
||||
|
||||
# Global name -> NodeSpec table. Filled by `register()`; read by `get_spec()`
# and `all_specs()`.
REGISTRY: dict[str, NodeSpec] = {}
|
||||
|
||||
|
||||
def register(spec: NodeSpec) -> NodeSpec:
    """Store `spec` in `REGISTRY` under `spec.name` and return it.

    Returning the spec lets the call be used inline as an expression.

    Raises:
        ValueError: a spec with the same name is already registered.
    """
    if spec.name not in REGISTRY:
        REGISTRY[spec.name] = spec
        return spec
    raise ValueError(
        f"Duplicate NodeSpec registration for {spec.name!r}. "
        f"Each node type must have exactly one spec."
    )
|
||||
|
||||
|
||||
def get_spec(name: str) -> NodeSpec | None:
    """Return the spec registered as `name`, or `None` if there is none."""
    try:
        return REGISTRY[name]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def all_specs() -> list[NodeSpec]:
    """Return every registered spec, ordered by name so output is stable."""
    by_name = sorted(REGISTRY.items(), key=lambda item: item[0])
    return [spec for _, spec in by_name]
|
||||
|
||||
|
||||
# Public API of the package: the schema types re-exported from `_base`
# plus the registry and its helpers defined in this module.
__all__ = [
    "SPEC_VERSION",
    "REGISTRY",
    "DisplayOptions",
    "GraphConstraints",
    "NodeCategory",
    "NodeExample",
    "NodeSpec",
    "PropertyOption",
    "PropertySpec",
    "PropertyType",
    "all_specs",
    "evaluate_display_options",
    "get_spec",
    "register",
]
|
||||
|
||||
|
||||
# Import each node-spec module for its `SPEC` constant. The modules do not
# register themselves; the loop below passes every `SPEC` to `register()`.
# Keep at module bottom so the registry helpers are defined first.
|
||||
from api.services.workflow.node_specs import ( # noqa: E402, F401
|
||||
agent,
|
||||
end_call,
|
||||
global_node,
|
||||
qa,
|
||||
start_call,
|
||||
trigger,
|
||||
webhook,
|
||||
)
|
||||
|
||||
# Wire up registrations from the SPEC constants in each module.
# `register()` raises ValueError on a duplicate name, so two modules sharing
# a `SPEC.name` fail here at import time.
for _module in (start_call, agent, end_call, global_node, trigger, webhook, qa):
    register(_module.SPEC)
# Drop the loop variable so it does not linger as a package attribute.
del _module
|
||||
28
api/services/workflow/node_specs/__main__.py
Normal file
28
api/services/workflow/node_specs/__main__.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Dump the registered NodeSpecs to stdout as JSON.
|
||||
|
||||
Used by `scripts/generate_sdk.sh` to feed both SDK codegens without
|
||||
requiring a running backend. Shape matches the `/api/v1/node-types`
|
||||
HTTP response so either source is interchangeable.
|
||||
|
||||
python -m api.services.workflow.node_specs > specs.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from api.services.workflow.node_specs import SPEC_VERSION, all_specs
|
||||
|
||||
|
||||
def main() -> None:
    """Print the spec version and every registered NodeSpec as indented JSON.

    The document has two keys, `spec_version` and `node_types`, and is
    followed by a single trailing newline on stdout.
    """
    node_types = []
    for spec in all_specs():
        node_types.append(spec.model_dump(mode="json"))
    document = {"spec_version": SPEC_VERSION, "node_types": node_types}
    sys.stdout.write(json.dumps(document, indent=2, ensure_ascii=False))
    sys.stdout.write("\n")


if __name__ == "__main__":
    main()
|
||||
224
api/services/workflow/node_specs/_base.py
Normal file
224
api/services/workflow/node_specs/_base.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""Spec schema for node definitions.
|
||||
|
||||
A `NodeSpec` is the single source of truth for a node type. It drives:
|
||||
- Pydantic validation (the per-type DTOs in dto.py mirror these property types)
|
||||
- The generic UI renderer (frontend reads specs via /api/v1/node-types)
|
||||
- The LLM SDK (constructors and JSON-Schema derived from these specs)
|
||||
|
||||
Every property's `description` is read by both people (in the UI) and LLMs —
treat it as production documentation, not internal notes. Guidance meant only
for LLMs belongs in `llm_hint`. Spec lint enforces non-empty descriptions and
example coverage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# Spec contract version. Bump when adding new PropertyType values or making
# breaking changes to the NodeSpec wire shape. SDK clients warn on mismatch.
# This is the contract as a whole; each NodeSpec also carries its own
# `version` field.
SPEC_VERSION = "1.0.0"
|
||||
|
||||
|
||||
class PropertyType(str, Enum):
    """Bounded vocabulary of property types the renderer dispatches on.

    Adding a value here requires a matching arm in the frontend
    `<PropertyInput>` switch and (where relevant) the SDK codegen template.
    """

    # Generic value types. The `str` mixin makes every member compare equal
    # to its plain string value.
    string = "string"
    number = "number"
    boolean = "boolean"
    options = "options"  # single-select dropdown
    multi_options = "multi_options"  # multi-select
    fixed_collection = "fixed_collection"  # repeating rows of sub-properties
    json = "json"  # arbitrary JSON object editor

    # Domain-specific reference types — values are UUIDs/keys looked up against
    # a reference catalog (list_tools, list_documents, list_recordings,
    # list_credentials).
    tool_refs = "tool_refs"
    document_refs = "document_refs"
    recording_ref = "recording_ref"
    credential_ref = "credential_ref"

    # Domain-specific input widgets
    mention_textarea = "mention_textarea"  # textarea with {{var}} mentions
    url = "url"
|
||||
|
||||
|
||||
class NodeCategory(str, Enum):
    """Drives grouping in the AddNodePanel UI."""

    # Each member's value is the string that appears in serialized specs.
    call_node = "call_node"
    global_node = "global_node"
    trigger = "trigger"
    integration = "integration"
|
||||
|
||||
|
||||
class DisplayOptions(BaseModel):
    """Conditional visibility rules.

    `show` keys are AND-combined: this property is visible only when EVERY
    referenced field's value matches one of the listed values.

    `hide` keys are OR-combined: this property is hidden when ANY referenced
    field's value matches one of the listed values.

    Example:
        DisplayOptions(show={"extraction_enabled": [True]})
        DisplayOptions(show={"greeting_type": ["audio"]})
    """

    # Field name -> values of that field under which the property is shown.
    show: Optional[dict[str, list[Any]]] = None
    # Field name -> values of that field under which the property is hidden.
    hide: Optional[dict[str, list[Any]]] = None

    # Unknown keys are a validation error rather than being silently dropped.
    model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def _display_value_matches(value: Any, candidates: list[Any]) -> bool:
    """Return True when `value` equals one of `candidates`, keeping bools apart from numbers."""
    for candidate in candidates:
        # In Python `True == 1` and `False == 0`, so a plain `in` test would
        # let a boolean rule match an integer value (and vice versa). Only
        # compare values when both or neither are bools.
        if isinstance(value, bool) != isinstance(candidate, bool):
            continue
        if value is candidate or value == candidate:
            return True
    return False


def evaluate_display_options(
    rules: Optional[DisplayOptions | dict[str, Any]],
    values: dict[str, Any],
) -> bool:
    """Reference implementation of the display_options visibility check.

    Mirrored 1:1 in the TypeScript renderer
    (`ui/src/components/flow/renderer/displayOptions.ts`). The golden
    fixtures in `display_options_fixtures.json` lock the two
    implementations together — update both whenever the semantics change.

    Args:
        rules: A `DisplayOptions` instance, a plain dict with optional
            `show` / `hide` keys, or `None` for "no rules".
        values: Current field values of the node, keyed by field name. A
            missing field is read as `None`.

    Returns:
        True when the property should be visible: every `show` entry
        matches (AND) and no `hide` entry matches (OR). `None` rules, or
        empty / null `show` and `hide`, mean always visible.

    Matching is by equality, except that booleans only ever match
    booleans: `1` does not satisfy a `[True]` rule.
    NOTE(review): this assumes the TypeScript mirror compares with strict
    equality (where `true !== 1`) — confirm against displayOptions.ts.
    """
    if rules is None:
        return True

    # Plain dicts (e.g. JSON fixtures) use key lookup; anything else is
    # treated as a DisplayOptions model and read through its attributes.
    if isinstance(rules, dict):
        show = rules.get("show")
        hide = rules.get("hide")
    else:
        show = rules.show
        hide = rules.hide

    if show:
        for field, allowed in show.items():
            if not _display_value_matches(values.get(field), allowed):
                return False

    if hide:
        for field, hidden in hide.items():
            if _display_value_matches(values.get(field), hidden):
                return False

    return True
|
||||
|
||||
|
||||
class PropertyOption(BaseModel):
    """An option in an `options` or `multi_options` dropdown."""

    # The value stored on the node when this option is selected.
    value: str | int | bool | float
    # Text shown for this option.
    label: str
    # Optional longer explanation of the option.
    description: Optional[str] = None

    # Unknown keys are a validation error rather than being silently dropped.
    model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class PropertySpec(BaseModel):
    """Single field on a node.

    `description` is HUMAN-FACING — shown under the field in the edit
    dialog. Keep it concise and explain what the field does.

    `llm_hint` is LLM-FACING — appears only in the `get_node_type` MCP
    response and in SDK schema output. Use it for catalog tool references
    (e.g., "Use `list_recordings`"), array shape, expected value idioms,
    or anything that would be noise in the UI. Optional; omit when the
    `description` already suffices for both audiences.
    """

    # Key under which the value is stored in the node's data.
    name: str
    # Which input widget / value kind this field uses.
    type: PropertyType
    # Label shown for the field.
    display_name: str
    # Required and non-empty: `...` marks the field as mandatory and
    # `min_length=1` rejects the empty string.
    description: str = Field(
        ...,
        min_length=1,
        description="Human-facing explanation shown in the UI.",
    )
    llm_hint: Optional[str] = Field(
        default=None,
        description="LLM-only guidance; omitted from the UI.",
    )
    default: Any = None
    required: bool = False
    placeholder: Optional[str] = None

    # Conditional visibility rules; None means no visibility rules.
    display_options: Optional[DisplayOptions] = None

    # For `options` / `multi_options`
    options: Optional[list[PropertyOption]] = None

    # For `fixed_collection` — sub-properties of each row
    properties: Optional[list["PropertySpec"]] = None

    # Validation hints. Enforced by Pydantic where possible.
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    pattern: Optional[str] = None

    # Renderer hint, e.g. "textarea" vs single-line for `string`.
    editor: Optional[str] = None

    # Free-form metadata for renderer-specific behavior. Use sparingly.
    extra: dict[str, Any] = Field(default_factory=dict)

    # Unknown keys are a validation error rather than being silently dropped.
    model_config = ConfigDict(extra="forbid")


# `properties` refers to "PropertySpec" by name from inside the class body;
# rebuild the model now that the class exists so that reference resolves.
PropertySpec.model_rebuild()
|
||||
|
||||
|
||||
class NodeExample(BaseModel):
    """A worked example LLMs can pattern-match. Keep small and realistic."""

    # Identifier of the example.
    name: str
    # Optional sentence on what the example demonstrates.
    description: Optional[str] = None
    # Example node data, keyed by property name.
    data: dict[str, Any]

    # Unknown keys are a validation error rather than being silently dropped.
    model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class GraphConstraints(BaseModel):
    """Per-node-type graph rules. WorkflowGraph enforces these at validation."""

    # Bounds on the number of edges into / out of a node of this type.
    # None leaves that bound unspecified.
    min_incoming: Optional[int] = None
    max_incoming: Optional[int] = None
    min_outgoing: Optional[int] = None
    max_outgoing: Optional[int] = None

    # Unknown keys are a validation error rather than being silently dropped.
    model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class NodeSpec(BaseModel):
    """Single source of truth for a node type."""

    name: str  # machine name; matches the Pydantic discriminator value
    display_name: str
    # Required and non-empty: `...` marks the field as mandatory and
    # `min_length=1` rejects the empty string.
    description: str = Field(
        ...,
        min_length=1,
        description="Human-facing explanation shown in AddNodePanel.",
    )
    llm_hint: Optional[str] = Field(
        default=None,
        description="LLM-only guidance; omitted from the UI.",
    )
    category: NodeCategory
    icon: str  # lucide-react icon name (e.g., "Play")
    # Version of this individual node spec (separate from SPEC_VERSION).
    version: str = "1.0.0"
    # Fields of the node; list order is preserved in the serialized spec.
    properties: list[PropertySpec]
    examples: list[NodeExample] = Field(default_factory=list)
    # Optional edge-count rules for this node type.
    graph_constraints: Optional[GraphConstraints] = None

    # Unknown keys are a validation error rather than being silently dropped.
    model_config = ConfigDict(extra="forbid")
|
||||
168
api/services/workflow/node_specs/agent.py
Normal file
168
api/services/workflow/node_specs/agent.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""Spec for the Agent node — the workhorse mid-call node where the LLM
|
||||
executes a focused conversational step with optional tools and documents."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
# Declarative definition of the `agentNode` type. Every string below is
# user- or LLM-facing copy that ships in the serialized spec.
SPEC = NodeSpec(
    name="agentNode",
    display_name="Agent Node",
    description="Conversational step — the LLM runs one focused exchange.",
    llm_hint=(
        "Mid-call step executed by the LLM. Most workflows are a chain of "
        "agent nodes connected by edges that describe transition conditions. "
        "Each agent node can invoke tools and reference documents."
    ),
    category=NodeCategory.call_node,
    icon="Headset",
    properties=[
        PropertySpec(
            name="name",
            type=PropertyType.string,
            display_name="Name",
            description=(
                "Short identifier for this step (e.g., 'Qualify Budget'). "
                "Appears in call logs and edge transition tools."
            ),
            required=True,
            min_length=1,
            default="Agent",
        ),
        PropertySpec(
            name="prompt",
            type=PropertyType.mention_textarea,
            display_name="Prompt",
            description=(
                "Agent system prompt for this step. Supports "
                "{{template_variables}} from extraction or pre-call fetch."
            ),
            required=True,
            min_length=1,
            placeholder="Ask the caller about their budget and timeline.",
        ),
        PropertySpec(
            name="allow_interrupt",
            type=PropertyType.boolean,
            display_name="Allow Interruption",
            description=(
                "When true, the user can interrupt the agent mid-utterance. "
                "Set false for non-interruptible disclosures."
            ),
            default=True,
        ),
        PropertySpec(
            name="add_global_prompt",
            type=PropertyType.boolean,
            display_name="Add Global Prompt",
            description=(
                "When true and a Global node exists, prepends the global "
                "prompt to this node's prompt at runtime."
            ),
            default=True,
        ),
        PropertySpec(
            name="extraction_enabled",
            type=PropertyType.boolean,
            display_name="Enable Variable Extraction",
            description=(
                "When true, runs an LLM extraction pass on transition out of "
                "this node to capture variables from the conversation."
            ),
            default=False,
        ),
        # The two extraction fields below are shown only while
        # `extraction_enabled` is True.
        PropertySpec(
            name="extraction_prompt",
            type=PropertyType.string,
            display_name="Extraction Prompt",
            description="Overall instructions guiding variable extraction.",
            display_options=DisplayOptions(show={"extraction_enabled": [True]}),
            editor="textarea",
        ),
        PropertySpec(
            name="extraction_variables",
            type=PropertyType.fixed_collection,
            display_name="Variables to Extract",
            description=(
                "Each entry declares one variable to capture from the "
                "conversation, with its name, type, and per-variable hint."
            ),
            display_options=DisplayOptions(show={"extraction_enabled": [True]}),
            # Sub-properties of each row in the collection.
            properties=[
                PropertySpec(
                    name="name",
                    type=PropertyType.string,
                    display_name="Variable Name",
                    description="snake_case identifier used downstream.",
                    required=True,
                ),
                PropertySpec(
                    name="type",
                    type=PropertyType.options,
                    display_name="Type",
                    description="Data type of the extracted value.",
                    required=True,
                    default="string",
                    options=[
                        PropertyOption(value="string", label="String"),
                        PropertyOption(value="number", label="Number"),
                        PropertyOption(value="boolean", label="Boolean"),
                    ],
                ),
                PropertySpec(
                    name="prompt",
                    type=PropertyType.string,
                    display_name="Extraction Hint",
                    description="Per-variable hint describing what to look for.",
                    editor="textarea",
                ),
            ],
        ),
        PropertySpec(
            name="tool_uuids",
            type=PropertyType.tool_refs,
            display_name="Tools",
            description="Tools the agent can invoke during this step.",
            llm_hint="List of tool UUIDs from `list_tools`.",
        ),
        PropertySpec(
            name="document_uuids",
            type=PropertyType.document_refs,
            display_name="Knowledge Base Documents",
            description="Documents the agent can reference during this step.",
            llm_hint="List of document UUIDs from `list_documents`.",
        ),
    ],
    examples=[
        NodeExample(
            name="qualify_lead",
            data={
                "name": "Qualify Budget",
                "prompt": "Ask about budget and timeline. Capture both before transitioning.",
                "allow_interrupt": True,
                "extraction_enabled": True,
                "extraction_prompt": "Extract budget amount and rough timeline.",
                "extraction_variables": [
                    {
                        "name": "budget_usd",
                        "type": "number",
                        "prompt": "Stated budget in USD",
                    },
                    {
                        "name": "timeline",
                        "type": "string",
                        "prompt": "When they want to start",
                    },
                ],
            },
        ),
    ],
    # An agent node needs at least one incoming edge.
    graph_constraints=GraphConstraints(min_incoming=1),
)
|
||||
123
api/services/workflow/node_specs/display_options_fixtures.json
Normal file
123
api/services/workflow/node_specs/display_options_fixtures.json
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
{
|
||||
"_doc": "Golden fixtures for the display_options evaluator. Both the Python evaluator (api/services/workflow/node_specs/_base.py:evaluate_display_options) and the TypeScript evaluator (ui/src/components/flow/renderer/displayOptions.ts:evaluateDisplayOptions) must agree on every case here. Fixtures double as documentation for the show/hide semantics.",
|
||||
"cases": [
|
||||
{
|
||||
"name": "no_rules_visible",
|
||||
"rules": null,
|
||||
"values": {"a": 1},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "empty_rules_visible",
|
||||
"rules": {"show": null, "hide": null},
|
||||
"values": {},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "show_match_visible",
|
||||
"rules": {"show": {"extraction_enabled": [true]}},
|
||||
"values": {"extraction_enabled": true},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "show_mismatch_hidden",
|
||||
"rules": {"show": {"extraction_enabled": [true]}},
|
||||
"values": {"extraction_enabled": false},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "show_missing_field_hidden",
|
||||
"rules": {"show": {"extraction_enabled": [true]}},
|
||||
"values": {},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "show_multiple_allowed_values",
|
||||
"rules": {"show": {"greeting_type": ["text", "audio"]}},
|
||||
"values": {"greeting_type": "audio"},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "show_multiple_keys_all_match",
|
||||
"rules": {
|
||||
"show": {
|
||||
"qa_use_workflow_llm": [false],
|
||||
"qa_provider": ["azure"]
|
||||
}
|
||||
},
|
||||
"values": {"qa_use_workflow_llm": false, "qa_provider": "azure"},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "show_multiple_keys_one_mismatch_hides",
|
||||
"rules": {
|
||||
"show": {
|
||||
"qa_use_workflow_llm": [false],
|
||||
"qa_provider": ["azure"]
|
||||
}
|
||||
},
|
||||
"values": {"qa_use_workflow_llm": false, "qa_provider": "openai"},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "hide_match_hides",
|
||||
"rules": {"hide": {"locked": [true]}},
|
||||
"values": {"locked": true},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "hide_mismatch_visible",
|
||||
"rules": {"hide": {"locked": [true]}},
|
||||
"values": {"locked": false},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "hide_missing_field_visible",
|
||||
"rules": {"hide": {"locked": [true]}},
|
||||
"values": {},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "hide_or_combined_either_hides",
|
||||
"rules": {"hide": {"a": [1], "b": [2]}},
|
||||
"values": {"a": 0, "b": 2},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "show_and_hide_both_required",
|
||||
"rules": {"show": {"enabled": [true]}, "hide": {"locked": [true]}},
|
||||
"values": {"enabled": true, "locked": false},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "show_and_hide_show_passes_hide_blocks",
|
||||
"rules": {"show": {"enabled": [true]}, "hide": {"locked": [true]}},
|
||||
"values": {"enabled": true, "locked": true},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "show_and_hide_show_fails_hide_irrelevant",
|
||||
"rules": {"show": {"enabled": [true]}, "hide": {"locked": [true]}},
|
||||
"values": {"enabled": false, "locked": false},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "scalar_int_strict",
|
||||
"rules": {"show": {"sample_rate": [100]}},
|
||||
"values": {"sample_rate": 100},
|
||||
"expected": true
|
||||
},
|
||||
{
|
||||
"name": "scalar_int_mismatch",
|
||||
"rules": {"show": {"sample_rate": [100]}},
|
||||
"values": {"sample_rate": 99},
|
||||
"expected": false
|
||||
},
|
||||
{
|
||||
"name": "scalar_string_strict",
|
||||
"rules": {"show": {"http_method": ["POST", "PUT"]}},
|
||||
"values": {"http_method": "GET"},
|
||||
"expected": false
|
||||
}
|
||||
]
|
||||
}
|
||||
141
api/services/workflow/node_specs/end_call.py
Normal file
141
api/services/workflow/node_specs/end_call.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
"""Spec for the End Call node — terminal node that wraps up a conversation
|
||||
and optionally extracts variables before hangup."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
# Declarative definition of the `endCall` type. Every string below is
# user- or LLM-facing copy that ships in the serialized spec.
SPEC = NodeSpec(
    name="endCall",
    display_name="End Call",
    description="Closes the conversation and hangs up.",
    llm_hint=(
        "Terminal node that politely closes the conversation. Variable "
        "extraction can run before hangup. A workflow can have multiple "
        "endCall nodes reached via different edge conditions."
    ),
    category=NodeCategory.call_node,
    icon="OctagonX",
    properties=[
        PropertySpec(
            name="name",
            type=PropertyType.string,
            display_name="Name",
            description=(
                "Short identifier shown in call logs. Should describe the "
                "ending context (e.g., 'Successful close', 'Polite decline')."
            ),
            required=True,
            min_length=1,
            default="End Call",
        ),
        PropertySpec(
            name="prompt",
            type=PropertyType.mention_textarea,
            display_name="Prompt",
            description=(
                "Agent system prompt for the closing exchange. Supports "
                "{{template_variables}} from extraction or pre-call fetch."
            ),
            required=True,
            min_length=1,
            placeholder="Thank the caller and confirm next steps before ending the call.",
        ),
        PropertySpec(
            name="add_global_prompt",
            type=PropertyType.boolean,
            display_name="Add Global Prompt",
            description=(
                "When true and a Global node exists, prepends the global "
                "prompt to this node's prompt at runtime."
            ),
            default=False,
        ),
        PropertySpec(
            name="extraction_enabled",
            type=PropertyType.boolean,
            display_name="Enable Variable Extraction",
            description=(
                "When true, runs an LLM extraction pass before hangup to "
                "capture variables from the conversation."
            ),
            default=False,
        ),
        # The two extraction fields below are shown only while
        # `extraction_enabled` is True.
        PropertySpec(
            name="extraction_prompt",
            type=PropertyType.string,
            display_name="Extraction Prompt",
            description=(
                "Overall instructions guiding how variables should be "
                "extracted from the conversation."
            ),
            display_options=DisplayOptions(show={"extraction_enabled": [True]}),
            editor="textarea",
        ),
        PropertySpec(
            name="extraction_variables",
            type=PropertyType.fixed_collection,
            display_name="Variables to Extract",
            description=(
                "Each entry declares one variable to capture from the "
                "conversation, with its name, data type, and a per-variable "
                "extraction hint."
            ),
            display_options=DisplayOptions(show={"extraction_enabled": [True]}),
            # Sub-properties of each row in the collection.
            properties=[
                PropertySpec(
                    name="name",
                    type=PropertyType.string,
                    display_name="Variable Name",
                    description="snake_case identifier used downstream.",
                    required=True,
                ),
                PropertySpec(
                    name="type",
                    type=PropertyType.options,
                    display_name="Type",
                    description="The data type of the extracted value.",
                    required=True,
                    default="string",
                    options=[
                        PropertyOption(value="string", label="String"),
                        PropertyOption(value="number", label="Number"),
                        PropertyOption(value="boolean", label="Boolean"),
                    ],
                ),
                PropertySpec(
                    name="prompt",
                    type=PropertyType.string,
                    display_name="Extraction Hint",
                    description=(
                        "Per-variable hint describing what to look for in "
                        "the conversation."
                    ),
                    editor="textarea",
                ),
            ],
        ),
    ],
    examples=[
        NodeExample(
            name="successful_close",
            data={
                "name": "Successful Close",
                "prompt": "Confirm the appointment time, thank the caller, and end the call.",
                "add_global_prompt": False,
            },
        ),
    ],
    # Terminal node: at least one incoming edge and no outgoing edges.
    graph_constraints=GraphConstraints(
        min_incoming=1,
        min_outgoing=0,
        max_outgoing=0,
    ),
)
|
||||
77
api/services/workflow/node_specs/global_node.py
Normal file
77
api/services/workflow/node_specs/global_node.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Spec for the Global node — system-level instructions appended to every
|
||||
agent node that opts in via `add_global_prompt`."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
# Declarative definition of the `globalNode` type. Every string below is
# user- or LLM-facing copy that ships in the serialized spec.
SPEC = NodeSpec(
    name="globalNode",
    display_name="Global Node",
    description="Persona/tone appended to every agent node's prompt.",
    llm_hint=(
        "System-level prompt appended to every prompted node whose "
        "`add_global_prompt` is true. Use it for persona, tone, and shared "
        "rules that apply across the entire conversation. At most one "
        "global node per workflow."
    ),
    category=NodeCategory.global_node,
    icon="Globe",
    properties=[
        PropertySpec(
            name="name",
            type=PropertyType.string,
            display_name="Name",
            description=(
                "Short identifier shown in the canvas and call logs. Has no "
                "runtime effect."
            ),
            required=True,
            min_length=1,
            default="Global Node",
        ),
        PropertySpec(
            name="prompt",
            type=PropertyType.mention_textarea,
            display_name="Global Prompt",
            description=(
                "Text appended to every prompted node's system prompt when "
                "that node has `add_global_prompt=true`. Supports "
                "{{template_variables}}."
            ),
            required=True,
            min_length=1,
            placeholder="You are a friendly assistant calling on behalf of {{company_name}}.",
            default=(
                "You are a helpful assistant whose mode of interaction with "
                "the user is voice. So don't use any special characters which "
                "can not be pronounced. Use short sentences and simple language."
            ),
        ),
    ],
    examples=[
        NodeExample(
            name="basic_persona",
            description="Establishes a consistent persona across the call.",
            data={
                "name": "Persona",
                "prompt": (
                    "You are Sarah, a polite and warm representative from "
                    "Acme Corp. Always thank the caller for their time and "
                    "speak in short conversational sentences."
                ),
            },
        ),
    ],
    # The global node takes no edges in either direction.
    graph_constraints=GraphConstraints(
        min_incoming=0,
        max_incoming=0,
        min_outgoing=0,
        max_outgoing=0,
    ),
)
|
||||
196
api/services/workflow/node_specs/qa.py
Normal file
196
api/services/workflow/node_specs/qa.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"""Spec for the QA Analysis node — runs an LLM quality review on the call
|
||||
transcript after completion."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
# Default system prompt for the QA reviewer LLM. The `qa` node spec below uses
# this constant as the `default` of its `qa_system_prompt` property.
#
# The text carries three double-brace placeholders -- {{node_summary}},
# {{previous_conversation_summary}} and {{metrics}} -- lists the tags to look
# for, and asks for a JSON object with `tags`, `overall_sentiment`,
# `call_quality_score` and `summary`.
#
# NOTE(review): the `qa_system_prompt` property description also advertises a
# transcript placeholder, which this default does not contain -- confirm how the
# transcript reaches the model.
DEFAULT_QA_SYSTEM_PROMPT = """You are a QA analyst evaluating a specific segment of a voice AI conversation.
|
||||
|
||||
## Node Purpose
|
||||
{{node_summary}}
|
||||
|
||||
## Previous Conversation Context (For start of conversation, previous conversation summary can be empty.)
|
||||
{{previous_conversation_summary}}
|
||||
|
||||
## Tags to evaluate
|
||||
|
||||
Examine the conversation carefully and identify which of the following tags apply:
|
||||
|
||||
- UNCLEAR_CONVERSATION - The conversation is not coherent or clear, messages don't connect logically
|
||||
- ASSISTANT_IN_LOOP - The assistant asks the same question multiple times or gets stuck repeating itself
|
||||
- ASSISTANT_REPLY_IMPROPER - The assistant did not reply properly to the user's question/query or seems confused by what the user said
|
||||
- USER_FRUSTRATED - The user seems angry, frustrated, or is complaining about something in the call
|
||||
- USER_NOT_UNDERSTANDING - The user explicitly says they don't understand or repeatedly asks for clarification
|
||||
- HEARING_ISSUES - Either party can't hear the other ("hello?", "are you there?", "can you hear me?")
|
||||
- DEAD_AIR - Unusually long silences in the conversation (use the timestamps to judge)
|
||||
- USER_REQUESTING_FEATURE - The user asks for something the assistant can't fulfill
|
||||
- ASSISTANT_LACKS_EMPATHY - The assistant ignores the user's personal situation or emotional state and continues pitching or pushing the agenda.
|
||||
- USER_DETECTS_AI - The user suspects or identifies that they are talking to an AI/robot/bot rather than a real human.
|
||||
|
||||
## Call metrics (pre-computed)
|
||||
|
||||
Use these alongside the transcript for your analysis:
|
||||
{{metrics}}
|
||||
|
||||
## Output format
|
||||
|
||||
Return ONLY a valid JSON object (no markdown):
|
||||
{
|
||||
"tags": [
|
||||
{
|
||||
"tag": "TAG_NAME",
|
||||
"reason": "Short reason with evidence from the transcript"
|
||||
}
|
||||
],
|
||||
"overall_sentiment": "positive|neutral|negative",
|
||||
"call_quality_score": <1-10>,
|
||||
"summary": "1-2 sentence summary of this segment"
|
||||
}
|
||||
|
||||
If no tags apply, return an empty tags list. Always provide sentiment, score, and summary."""
|
||||
|
||||
|
||||
# Declarative spec for the QA Analysis node. Property order, option order and
# every string value are unchanged; only layout and keyword order differ.
SPEC = NodeSpec(
    name="qa",
    display_name="QA Analysis",
    category=NodeCategory.integration,
    icon="ClipboardCheck",
    description="Run LLM quality analysis on the call transcript.",
    llm_hint=(
        "Runs an LLM quality review on the call transcript after completion. "
        "Per-node analysis splits the conversation by node and evaluates each segment "
        "against the configured system prompt. "
        "Sampling, minimum duration, and voicemail filters are supported."
    ),
    properties=[
        # -- Identity and on/off switch --
        PropertySpec(
            name="name",
            display_name="Name",
            type=PropertyType.string,
            description="Short identifier for this QA configuration.",
            required=True,
            min_length=1,
            default="QA Analysis",
        ),
        PropertySpec(
            name="qa_enabled",
            display_name="Enabled",
            type=PropertyType.boolean,
            description="When false, the QA run is skipped.",
            default=True,
        ),
        # -- Reviewer instructions --
        # NOTE(review): this description documents single-brace placeholders,
        # while DEFAULT_QA_SYSTEM_PROMPT uses double braces ({{node_summary}}).
        # Confirm which form the renderer substitutes.
        PropertySpec(
            name="qa_system_prompt",
            display_name="System Prompt",
            type=PropertyType.string,
            editor="textarea",
            description=(
                "Instructions to the QA reviewer LLM. "
                "Supports placeholders: `{node_summary}`, "
                "`{previous_conversation_summary}`, `{transcript}`, `{metrics}`."
            ),
            default=DEFAULT_QA_SYSTEM_PROMPT,
        ),
        # -- Which calls get reviewed --
        PropertySpec(
            name="qa_min_call_duration",
            display_name="Minimum Call Duration (seconds)",
            type=PropertyType.number,
            description="Calls shorter than this are skipped.",
            min_value=0,
            default=15,
        ),
        PropertySpec(
            name="qa_voicemail_calls",
            display_name="Include Voicemail Calls",
            type=PropertyType.boolean,
            description="When false, calls flagged as voicemail are skipped.",
            default=False,
        ),
        PropertySpec(
            name="qa_sample_rate",
            display_name="Sample Rate (%)",
            type=PropertyType.number,
            description=(
                "Percent of eligible calls QA'd. "
                "100 means every call; lower values use random sampling."
            ),
            min_value=1,
            max_value=100,
            default=100,
        ),
        # -- Dedicated LLM settings --
        # The fields after the toggle carry display_options keyed on
        # qa_use_workflow_llm == False (presumably hidden otherwise).
        PropertySpec(
            name="qa_use_workflow_llm",
            display_name="Use Workflow's LLM",
            type=PropertyType.boolean,
            description=(
                "When true, the QA pass uses the same LLM the workflow runs with. "
                "Set false to specify a separate provider/model."
            ),
            default=True,
        ),
        PropertySpec(
            name="qa_provider",
            display_name="QA LLM Provider",
            type=PropertyType.options,
            description="LLM provider used for the QA pass.",
            options=[
                PropertyOption(value="openai", label="OpenAI"),
                PropertyOption(value="azure", label="Azure OpenAI"),
                PropertyOption(value="openrouter", label="OpenRouter"),
                PropertyOption(value="anthropic", label="Anthropic"),
            ],
            display_options=DisplayOptions(show={"qa_use_workflow_llm": [False]}),
        ),
        PropertySpec(
            name="qa_model",
            display_name="QA Model",
            type=PropertyType.string,
            description=(
                "Model identifier (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Provider-specific."
            ),
            default="default",
            display_options=DisplayOptions(show={"qa_use_workflow_llm": [False]}),
        ),
        PropertySpec(
            name="qa_api_key",
            display_name="API Key",
            type=PropertyType.string,
            description="API key for the chosen provider.",
            display_options=DisplayOptions(show={"qa_use_workflow_llm": [False]}),
        ),
        # Additionally keyed on the provider being "azure".
        PropertySpec(
            name="qa_endpoint",
            display_name="Azure Endpoint",
            type=PropertyType.url,
            description="Required for the Azure provider.",
            display_options=DisplayOptions(
                show={"qa_use_workflow_llm": [False], "qa_provider": ["azure"]}
            ),
        ),
    ],
    examples=[
        NodeExample(
            name="basic_qa",
            data={
                "name": "Compliance Check",
                "qa_enabled": True,
                "qa_system_prompt": (
                    "You are a compliance reviewer. "
                    "Review the transcript and produce a JSON object with `tags`, "
                    "`summary`, `call_quality_score`, and `overall_sentiment`."
                ),
                "qa_min_call_duration": 30,
                "qa_sample_rate": 100,
            },
        ),
    ],
)
|
||||
248
api/services/workflow/node_specs/start_call.py
Normal file
248
api/services/workflow/node_specs/start_call.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""Spec for the Start Call node — the single entry point of every workflow.
|
||||
Carries greeting, pre-call data fetch, and the same prompt/extraction/tools
|
||||
fields as agent nodes."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
DisplayOptions,
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
# Declarative spec for the Start Call node. Its graph constraints allow no
# incoming edges (max_incoming=0) and require at least one outgoing edge.
# Property order, option order and every string value are unchanged; only layout
# and keyword order differ.
SPEC = NodeSpec(
    name="startCall",
    display_name="Start Call",
    category=NodeCategory.call_node,
    icon="Play",
    description="Entry point of the workflow — plays a greeting and opens the conversation.",
    llm_hint=(
        "The entry point of every workflow (exactly one required). "
        "Plays an optional greeting, can fetch context from an external API "
        "before the call begins, and executes the first conversational turn."
    ),
    properties=[
        PropertySpec(
            name="name",
            display_name="Name",
            type=PropertyType.string,
            description="Short identifier shown in the canvas and call logs.",
            required=True,
            min_length=1,
            default="Start Call",
        ),
        # -- Greeting: `greeting_type` selects which of the two fields applies --
        PropertySpec(
            name="greeting_type",
            display_name="Greeting Type",
            type=PropertyType.options,
            description=(
                "Whether the optional greeting is spoken via TTS from text or "
                "played from a pre-recorded audio file."
            ),
            options=[
                PropertyOption(value="text", label="Text (TTS)"),
                PropertyOption(value="audio", label="Pre-recorded Audio"),
            ],
            default="text",
        ),
        PropertySpec(
            name="greeting",
            display_name="Greeting Text",
            type=PropertyType.string,
            editor="textarea",
            description=(
                "Text spoken via TTS at the start of the call. "
                "Supports {{template_variables}}. "
                "Leave empty to skip the greeting."
            ),
            placeholder="Hi {{first_name}}, this is Sarah from Acme.",
            display_options=DisplayOptions(show={"greeting_type": ["text"]}),
        ),
        PropertySpec(
            name="greeting_recording_id",
            display_name="Greeting Recording",
            type=PropertyType.recording_ref,
            description="Pre-recorded audio file played at the start of the call.",
            llm_hint=(
                "Value is the `recording_id` string. "
                "Use the `list_recordings` MCP tool to discover available recordings."
            ),
            display_options=DisplayOptions(show={"greeting_type": ["audio"]}),
        ),
        # -- Opening-turn prompt --
        PropertySpec(
            name="prompt",
            display_name="Prompt",
            type=PropertyType.mention_textarea,
            description=(
                "Agent system prompt for the opening turn. "
                "Supports {{template_variables}} from pre-call fetch and the initial context."
            ),
            required=True,
            min_length=1,
            placeholder="Greet the caller warmly and ask how you can help today.",
        ),
        # -- Behaviour switches --
        PropertySpec(
            name="allow_interrupt",
            display_name="Allow Interruption",
            type=PropertyType.boolean,
            description="When true, the user can interrupt the agent mid-utterance.",
            default=False,
        ),
        # NOTE(review): this text says the global prompt is "prepended"; the
        # Global node spec describes it as "appended". Confirm which is right.
        PropertySpec(
            name="add_global_prompt",
            display_name="Add Global Prompt",
            type=PropertyType.boolean,
            description=(
                "When true and a Global node exists, "
                "prepends the global prompt to this node's prompt at runtime."
            ),
            default=True,
        ),
        PropertySpec(
            name="delayed_start",
            display_name="Delayed Start",
            type=PropertyType.boolean,
            description=(
                "When true, the agent waits before speaking after pickup. "
                "Useful for outbound calls where the called party needs a moment to settle."
            ),
            default=False,
        ),
        PropertySpec(
            name="delayed_start_duration",
            display_name="Delay Duration (seconds)",
            type=PropertyType.number,
            description="Seconds to wait before the agent speaks. 0.1–10.",
            min_value=0.1,
            max_value=10.0,
            default=2.0,
            display_options=DisplayOptions(show={"delayed_start": [True]}),
        ),
        # -- Variable extraction: dependent fields keyed on extraction_enabled --
        PropertySpec(
            name="extraction_enabled",
            display_name="Enable Variable Extraction",
            type=PropertyType.boolean,
            description=(
                "When true, runs an LLM extraction pass on transition out of this node "
                "to capture variables from the opening turn."
            ),
            default=False,
        ),
        PropertySpec(
            name="extraction_prompt",
            display_name="Extraction Prompt",
            type=PropertyType.string,
            editor="textarea",
            description="Overall instructions guiding variable extraction.",
            display_options=DisplayOptions(show={"extraction_enabled": [True]}),
        ),
        PropertySpec(
            name="extraction_variables",
            display_name="Variables to Extract",
            type=PropertyType.fixed_collection,
            description=(
                "Each entry declares one variable to capture, "
                "with its name, data type, and per-variable extraction hint."
            ),
            display_options=DisplayOptions(show={"extraction_enabled": [True]}),
            # Schema of a single collection entry.
            properties=[
                PropertySpec(
                    name="name",
                    display_name="Variable Name",
                    type=PropertyType.string,
                    description="snake_case identifier used downstream.",
                    required=True,
                ),
                PropertySpec(
                    name="type",
                    display_name="Type",
                    type=PropertyType.options,
                    description="Data type of the extracted value.",
                    required=True,
                    options=[
                        PropertyOption(value="string", label="String"),
                        PropertyOption(value="number", label="Number"),
                        PropertyOption(value="boolean", label="Boolean"),
                    ],
                    default="string",
                ),
                PropertySpec(
                    name="prompt",
                    display_name="Extraction Hint",
                    type=PropertyType.string,
                    editor="textarea",
                    description="Per-variable hint describing what to look for.",
                ),
            ],
        ),
        # -- References to tools and knowledge-base documents --
        PropertySpec(
            name="tool_uuids",
            display_name="Tools",
            type=PropertyType.tool_refs,
            description="Tools the agent can invoke during the opening turn.",
            llm_hint="List of tool UUIDs from `list_tools`.",
        ),
        PropertySpec(
            name="document_uuids",
            display_name="Knowledge Base Documents",
            type=PropertyType.document_refs,
            description="Documents the agent can reference.",
            llm_hint="List of document UUIDs from `list_documents`.",
        ),
        # -- Pre-call data fetch: dependent fields keyed on pre_call_fetch_enabled --
        PropertySpec(
            name="pre_call_fetch_enabled",
            display_name="Pre-Call Data Fetch",
            type=PropertyType.boolean,
            description=(
                "When true, makes a POST request to an external API before the call "
                "starts and merges the JSON response into the call context as "
                "template variables."
            ),
            default=False,
        ),
        PropertySpec(
            name="pre_call_fetch_url",
            display_name="Endpoint URL",
            type=PropertyType.url,
            description=(
                "URL the pre-call POST request is sent to. "
                "The request body includes caller and called numbers."
            ),
            placeholder="https://api.example.com/customer-lookup",
            display_options=DisplayOptions(show={"pre_call_fetch_enabled": [True]}),
        ),
        PropertySpec(
            name="pre_call_fetch_credential_uuid",
            display_name="Authentication",
            type=PropertyType.credential_ref,
            description="Optional credential attached to the pre-call request.",
            llm_hint="Credential UUID from `list_credentials`.",
            display_options=DisplayOptions(show={"pre_call_fetch_enabled": [True]}),
        ),
    ],
    examples=[
        NodeExample(
            name="warm_greeting",
            data={
                "name": "Greeting",
                "prompt": "Greet warmly and ask the caller's reason for calling.",
                "greeting_type": "text",
                "greeting": "Hi {{first_name}}, this is Sarah from Acme.",
                "allow_interrupt": True,
            },
        ),
    ],
    graph_constraints=GraphConstraints(
        min_incoming=0,
        max_incoming=0,
        min_outgoing=1,
    ),
)
|
||||
61
api/services/workflow/node_specs/trigger.py
Normal file
61
api/services/workflow/node_specs/trigger.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""Spec for the API Trigger node — exposes a public webhook URL that
|
||||
external systems can hit to launch the workflow."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
GraphConstraints,
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
# Declarative spec for the API Trigger node. Its graph constraints allow no
# incoming edges (max_incoming=0); no outgoing constraints are set here.
# Property order and every string value are unchanged; only layout and keyword
# order differ.
SPEC = NodeSpec(
    name="trigger",
    display_name="API Trigger",
    category=NodeCategory.trigger,
    icon="Webhook",
    description="Public HTTP endpoint that launches the workflow.",
    llm_hint=(
        "Exposes a public HTTP POST endpoint. "
        "External systems call the URL (derived from the auto-generated "
        "`trigger_path`) to launch this workflow. "
        "Requires an API key in the `X-API-Key` header."
    ),
    properties=[
        PropertySpec(
            name="name",
            display_name="Name",
            type=PropertyType.string,
            description="Short identifier shown in the canvas. No runtime effect.",
            required=True,
            min_length=1,
            default="API Trigger",
        ),
        PropertySpec(
            name="enabled",
            display_name="Enabled",
            type=PropertyType.boolean,
            description="When false, the trigger URL returns 404.",
            default=True,
        ),
        # Declared without a default; the description says it is auto-generated.
        PropertySpec(
            name="trigger_path",
            display_name="Trigger Path",
            type=PropertyType.string,
            description=(
                "Auto-generated UUID-style path segment that uniquely identifies this "
                "trigger. Do not edit manually."
            ),
        ),
    ],
    examples=[
        NodeExample(
            name="default",
            data={"name": "Inbound Trigger", "enabled": True},
        ),
    ],
    graph_constraints=GraphConstraints(
        min_incoming=0,
        max_incoming=0,
    ),
)
|
||||
135
api/services/workflow/node_specs/webhook.py
Normal file
135
api/services/workflow/node_specs/webhook.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
"""Spec for the Webhook node — sends an HTTP request to an external system
|
||||
after the workflow completes."""
|
||||
|
||||
from api.services.workflow.node_specs._base import (
|
||||
NodeCategory,
|
||||
NodeExample,
|
||||
NodeSpec,
|
||||
PropertyOption,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
)
|
||||
|
||||
# Declarative spec for the Webhook node. No graph_constraints are passed for
# this node. Property order, option order, dict key order and every string value
# are unchanged; only layout and keyword order differ.
SPEC = NodeSpec(
    name="webhook",
    display_name="Webhook",
    category=NodeCategory.integration,
    icon="Link2",
    description="Send HTTP request after the workflow completes.",
    llm_hint=(
        "Sends an HTTP request to an external system after the workflow completes. "
        "The payload is a Jinja-templated JSON body with access to "
        "`workflow_run_id`, `initial_context`, `gathered_context`, `annotations`, "
        "and call metadata."
    ),
    properties=[
        PropertySpec(
            name="name",
            display_name="Name",
            type=PropertyType.string,
            description="Short identifier shown in the canvas and run logs.",
            required=True,
            min_length=1,
            default="Webhook",
        ),
        PropertySpec(
            name="enabled",
            display_name="Enabled",
            type=PropertyType.boolean,
            description="When false, the webhook is skipped at run time.",
            default=True,
        ),
        # -- Request target --
        PropertySpec(
            name="http_method",
            display_name="HTTP Method",
            type=PropertyType.options,
            description="HTTP verb used for the outbound request.",
            options=[
                PropertyOption(value="GET", label="GET"),
                PropertyOption(value="POST", label="POST"),
                PropertyOption(value="PUT", label="PUT"),
                PropertyOption(value="PATCH", label="PATCH"),
                PropertyOption(value="DELETE", label="DELETE"),
            ],
            default="POST",
        ),
        PropertySpec(
            name="endpoint_url",
            display_name="Endpoint URL",
            type=PropertyType.url,
            description="URL the request is sent to.",
            placeholder="https://api.example.com/webhook",
        ),
        # -- Authentication and extra headers --
        PropertySpec(
            name="credential_uuid",
            display_name="Authentication",
            type=PropertyType.credential_ref,
            description="Optional credential applied as the Authorization header.",
            llm_hint="Credential UUID from `list_credentials`.",
        ),
        PropertySpec(
            name="custom_headers",
            display_name="Custom Headers",
            type=PropertyType.fixed_collection,
            description="Additional HTTP headers to include with the request.",
            # Schema of a single header entry.
            properties=[
                PropertySpec(
                    name="key",
                    display_name="Header Name",
                    type=PropertyType.string,
                    description="HTTP header name (e.g., 'X-Source').",
                    required=True,
                ),
                PropertySpec(
                    name="value",
                    display_name="Header Value",
                    type=PropertyType.string,
                    description="Header value (supports {{template_variables}}).",
                    required=True,
                ),
            ],
        ),
        # -- Body and retry behaviour --
        PropertySpec(
            name="payload_template",
            display_name="Payload Template",
            type=PropertyType.json,
            description=(
                "JSON body of the request. "
                "Values are Jinja-rendered against the run context — "
                "`{{workflow_run_id}}`, `{{gathered_context.foo}}`, "
                "`{{annotations.qa_xxx}}`, etc."
            ),
            default={
                "call_id": "{{workflow_run_id}}",
                "first_name": "{{initial_context.first_name}}",
                "rsvp": "{{gathered_context.rsvp}}",
                "duration": "{{cost_info.call_duration_seconds}}",
                "recording_url": "{{recording_url}}",
                "transcript_url": "{{transcript_url}}",
            },
        ),
        PropertySpec(
            name="retry_config",
            display_name="Retry Configuration",
            type=PropertyType.json,
            description=(
                "Optional retry settings: `enabled` (bool), "
                "`max_retries` (int), `retry_delay_seconds` (int)."
            ),
        ),
    ],
    examples=[
        NodeExample(
            name="post_to_crm",
            data={
                "name": "Notify CRM",
                "enabled": True,
                "http_method": "POST",
                "endpoint_url": "https://crm.example.com/calls",
                "payload_template": {
                    "run_id": "{{workflow_run_id}}",
                    "outcome": "{{gathered_context.call_disposition}}",
                },
            },
        ),
    ],
)
|
||||
|
|
@ -243,6 +243,7 @@ class PipecatEngine:
|
|||
else 16000,
|
||||
queue_frame=self._transport_output.queue_frame,
|
||||
transcript=result.transcript,
|
||||
persist_to_logs=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
|
|
@ -252,7 +253,11 @@ class PipecatEngine:
|
|||
logger.info(f"Playing transition speech: {transition_speech}")
|
||||
self._queued_speech_mute_state = "waiting"
|
||||
await self.task.queue_frame(
|
||||
TTSSpeakFrame(transition_speech, append_to_context=False)
|
||||
TTSSpeakFrame(
|
||||
transition_speech,
|
||||
append_to_context=False,
|
||||
persist_to_logs=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Set context for the new node, so that when the function call result
|
||||
|
|
|
|||
|
|
@ -100,6 +100,7 @@ class CustomToolManager:
|
|||
else 16000,
|
||||
queue_frame=self._engine._transport_output.queue_frame,
|
||||
transcript=result.transcript,
|
||||
persist_to_logs=True,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
|
|
@ -110,7 +111,11 @@ class CustomToolManager:
|
|||
custom_message = config.get("customMessage", "")
|
||||
if custom_message:
|
||||
await self._engine.task.queue_frame(
|
||||
TTSSpeakFrame(custom_message, append_to_context=append_to_context)
|
||||
TTSSpeakFrame(
|
||||
custom_message,
|
||||
append_to_context=append_to_context,
|
||||
persist_to_logs=True,
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
@ -311,6 +316,7 @@ class CustomToolManager:
|
|||
else 16000,
|
||||
queue_frame=self._engine._transport_output.queue_frame,
|
||||
transcript=result.transcript,
|
||||
persist_to_logs=True,
|
||||
)
|
||||
elif custom_message:
|
||||
logger.info(
|
||||
|
|
@ -318,7 +324,11 @@ class CustomToolManager:
|
|||
)
|
||||
self._engine._queued_speech_mute_state = "waiting"
|
||||
await self._engine.task.queue_frame(
|
||||
TTSSpeakFrame(custom_message, append_to_context=False)
|
||||
TTSSpeakFrame(
|
||||
custom_message,
|
||||
append_to_context=False,
|
||||
persist_to_logs=True,
|
||||
)
|
||||
)
|
||||
|
||||
result = await execute_http_tool(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from loguru import logger
|
|||
from api.db.models import WorkflowRunModel
|
||||
from api.services.gen_ai.json_parser import parse_llm_json
|
||||
from api.services.pipecat.service_factory import create_llm_service_from_provider
|
||||
from api.services.workflow.dto import QANodeData
|
||||
from api.services.workflow.qa.conversation import (
|
||||
build_conversation_structure,
|
||||
format_transcript,
|
||||
|
|
@ -77,7 +78,7 @@ async def _generate_conversation_summary(
|
|||
|
||||
|
||||
async def run_per_node_qa_analysis(
|
||||
qa_node_data: dict[str, Any],
|
||||
qa_data: QANodeData,
|
||||
workflow_run: WorkflowRunModel,
|
||||
workflow_run_id: int,
|
||||
workflow_definition: dict,
|
||||
|
|
@ -106,18 +107,16 @@ async def run_per_node_qa_analysis(
|
|||
logger.info(
|
||||
f"Events lack node_id for run {workflow_run_id}, falling back to whole-call QA"
|
||||
)
|
||||
return await _run_whole_call_qa_analysis(
|
||||
qa_node_data, workflow_run, workflow_run_id
|
||||
)
|
||||
return await _run_whole_call_qa_analysis(qa_data, workflow_run, workflow_run_id)
|
||||
|
||||
system_prompt = qa_node_data.get("qa_system_prompt", "")
|
||||
system_prompt = qa_data.qa_system_prompt or ""
|
||||
if not system_prompt:
|
||||
logger.warning("No system prompt defined for QA Node")
|
||||
return {"error": "no_system_prompt", "node_results": {}}
|
||||
|
||||
# Resolve LLM config
|
||||
provider, model, api_key, service_kwargs = await resolve_llm_config(
|
||||
qa_node_data, workflow_run
|
||||
qa_data, workflow_run
|
||||
)
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
|
|
@ -127,7 +126,7 @@ async def run_per_node_qa_analysis(
|
|||
|
||||
# Ensure node summaries
|
||||
node_summaries = await ensure_node_summaries(
|
||||
workflow_definition, definition_id, workflow_run, qa_node_data
|
||||
workflow_definition, definition_id, workflow_run, qa_data
|
||||
)
|
||||
|
||||
# Set up Langfuse tracing
|
||||
|
|
@ -228,7 +227,7 @@ async def run_per_node_qa_analysis(
|
|||
|
||||
|
||||
async def _run_whole_call_qa_analysis(
|
||||
qa_node_data: dict[str, Any],
|
||||
qa_data: QANodeData,
|
||||
workflow_run: WorkflowRunModel,
|
||||
workflow_run_id: int,
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -254,13 +253,13 @@ async def _run_whole_call_qa_analysis(
|
|||
metrics = compute_call_metrics(rtf_events, call_duration)
|
||||
|
||||
# Resolve LLM config
|
||||
system_prompt = qa_node_data.get("qa_system_prompt", "")
|
||||
system_prompt = qa_data.qa_system_prompt or ""
|
||||
if not system_prompt:
|
||||
logger.warning("No system prompt defined for QA Node")
|
||||
return {"error": "no_system_prompt", "node_results": {}}
|
||||
|
||||
provider, model, api_key, service_kwargs = await resolve_llm_config(
|
||||
qa_node_data, workflow_run
|
||||
qa_data, workflow_run
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ import random
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.workflow.dto import QANodeData
|
||||
|
||||
|
||||
async def resolve_llm_config(
|
||||
qa_node_data: dict, workflow_run: WorkflowRunModel
|
||||
qa_data: QANodeData, workflow_run: WorkflowRunModel
|
||||
) -> tuple[str, str, str, dict]:
|
||||
"""Resolve the LLM provider, model, API key, and extra kwargs for QA analysis.
|
||||
|
||||
|
|
@ -18,24 +19,23 @@ async def resolve_llm_config(
|
|||
(provider, model, api_key, service_kwargs) tuple — service_kwargs can be
|
||||
passed directly to create_llm_service_from_provider as keyword arguments.
|
||||
"""
|
||||
if not qa_node_data.get("qa_use_workflow_llm", True):
|
||||
provider = qa_node_data.get("qa_provider", "openai")
|
||||
if not qa_data.qa_use_workflow_llm:
|
||||
provider = qa_data.qa_provider or "openai"
|
||||
kwargs = {}
|
||||
if provider == "azure":
|
||||
kwargs["endpoint"] = qa_node_data.get("qa_endpoint", "")
|
||||
kwargs["endpoint"] = qa_data.qa_endpoint or ""
|
||||
return (
|
||||
provider,
|
||||
qa_node_data.get("qa_model"),
|
||||
qa_node_data.get("qa_api_key"),
|
||||
qa_data.qa_model,
|
||||
qa_data.qa_api_key,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
# Fall back to user's configured LLM
|
||||
provider, model, api_key, kwargs = await resolve_user_llm_config(workflow_run)
|
||||
|
||||
qa_model = qa_node_data.get("qa_model", "default")
|
||||
if qa_model and qa_model != "default":
|
||||
model = qa_model
|
||||
if qa_data.qa_model and qa_data.qa_model != "default":
|
||||
model = qa_data.qa_model
|
||||
|
||||
return provider, model, api_key, kwargs
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from loguru import logger
|
|||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.pipecat.service_factory import create_llm_service_from_provider
|
||||
from api.services.workflow.dto import NodeType
|
||||
from api.services.workflow.dto import NodeType, QANodeData
|
||||
from api.services.workflow.qa.llm_config import resolve_llm_config
|
||||
from api.services.workflow.qa.tracing import create_node_summary_trace
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
|
@ -48,7 +48,7 @@ async def ensure_node_summaries(
|
|||
workflow_definition: dict,
|
||||
definition_id: int | None,
|
||||
workflow_run: WorkflowRunModel,
|
||||
qa_node_data: dict,
|
||||
qa_data: QANodeData,
|
||||
) -> dict[str, Any]:
|
||||
"""Ensure every agentNode/startCall node has a summary in the definition.
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ async def ensure_node_summaries(
|
|||
return existing_summaries
|
||||
|
||||
provider, model, api_key, service_kwargs = await resolve_llm_config(
|
||||
qa_node_data, workflow_run
|
||||
qa_data, workflow_run
|
||||
)
|
||||
if not api_key:
|
||||
logger.warning("No API key for node summary generation, skipping")
|
||||
|
|
|
|||
|
|
@ -242,7 +242,6 @@ async def _perform_retrieval(
|
|||
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=128,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import re
|
|||
from collections import Counter
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO, NodeType, ReactFlowDTO
|
||||
from api.services.workflow.dto import EdgeDataDTO, NodeType, ReactFlowDTO
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
|
||||
# Regex for matching {{ variable }} template placeholders.
|
||||
|
|
@ -61,32 +61,38 @@ class Edge:
|
|||
|
||||
|
||||
class Node:
|
||||
def __init__(self, id: str, node_type: NodeType, data: NodeDataDTO):
|
||||
def __init__(self, id: str, node_type: NodeType, data):
|
||||
self.id, self.node_type, self.data = id, node_type, data
|
||||
self.out: Dict[str, "Node"] = {} # forward nodes
|
||||
self.out_edges: List[Edge] = [] # forward edges with properties
|
||||
|
||||
# name/is_start/is_end live on every per-type data class (base).
|
||||
self.name = data.name
|
||||
self.prompt = data.prompt
|
||||
self.is_static = data.is_static
|
||||
self.is_start = data.is_start
|
||||
self.is_end = data.is_end
|
||||
self.allow_interrupt = data.allow_interrupt
|
||||
self.extraction_enabled = data.extraction_enabled
|
||||
self.extraction_prompt = data.extraction_prompt
|
||||
self.extraction_variables = data.extraction_variables
|
||||
self.add_global_prompt = data.add_global_prompt
|
||||
self.greeting = data.greeting
|
||||
self.greeting_type = data.greeting_type
|
||||
self.greeting_recording_id = data.greeting_recording_id
|
||||
self.detect_voicemail = data.detect_voicemail
|
||||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
self.tool_uuids = data.tool_uuids
|
||||
self.document_uuids = data.document_uuids
|
||||
self.pre_call_fetch_enabled = data.pre_call_fetch_enabled
|
||||
self.pre_call_fetch_url = data.pre_call_fetch_url
|
||||
self.pre_call_fetch_credential_uuid = data.pre_call_fetch_credential_uuid
|
||||
|
||||
# Type-specific fields — read with getattr so this works for every
|
||||
# node variant in the discriminated union.
|
||||
self.prompt = getattr(data, "prompt", None)
|
||||
self.is_static = getattr(data, "is_static", False)
|
||||
self.allow_interrupt = getattr(data, "allow_interrupt", False)
|
||||
self.extraction_enabled = getattr(data, "extraction_enabled", False)
|
||||
self.extraction_prompt = getattr(data, "extraction_prompt", None)
|
||||
self.extraction_variables = getattr(data, "extraction_variables", None)
|
||||
self.add_global_prompt = getattr(data, "add_global_prompt", True)
|
||||
self.greeting = getattr(data, "greeting", None)
|
||||
self.greeting_type = getattr(data, "greeting_type", None)
|
||||
self.greeting_recording_id = getattr(data, "greeting_recording_id", None)
|
||||
self.detect_voicemail = getattr(data, "detect_voicemail", False)
|
||||
self.delayed_start = getattr(data, "delayed_start", False)
|
||||
self.delayed_start_duration = getattr(data, "delayed_start_duration", None)
|
||||
self.tool_uuids = getattr(data, "tool_uuids", None)
|
||||
self.document_uuids = getattr(data, "document_uuids", None)
|
||||
self.pre_call_fetch_enabled = getattr(data, "pre_call_fetch_enabled", False)
|
||||
self.pre_call_fetch_url = getattr(data, "pre_call_fetch_url", None)
|
||||
self.pre_call_fetch_credential_uuid = getattr(
|
||||
data, "pre_call_fetch_credential_uuid", None
|
||||
)
|
||||
|
||||
self.data = data
|
||||
|
||||
|
|
@ -98,9 +104,11 @@ class WorkflowGraph:
|
|||
"""
|
||||
|
||||
def __init__(self, dto: ReactFlowDTO):
|
||||
# build adjacency list
|
||||
# build adjacency list. n.type comes off the discriminated-union
|
||||
# variant as a literal string; coerce to NodeType for downstream
|
||||
# comparisons.
|
||||
self.nodes: Dict[str, Node] = {
|
||||
n.id: Node(n.id, n.type, n.data) for n in dto.nodes
|
||||
n.id: Node(n.id, NodeType(n.type), n.data) for n in dto.nodes
|
||||
}
|
||||
|
||||
# Store all edges
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
"""ARQ background task for processing knowledge base documents."""
|
||||
"""ARQ background task for processing knowledge base documents.
|
||||
|
||||
Document conversion and chunking live in the Model Proxy Service (MPS);
|
||||
this task downloads the file from S3, calls MPS, then handles the embedding
|
||||
and DB writes locally.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from docling.chunking import HybridChunker
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import KnowledgeBaseChunkModel
|
||||
from api.services.gen_ai import OpenAIEmbeddingService
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
# For tokenization/chunking
|
||||
TOKENIZER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
MAX_FILE_SIZE_BYTES = 5 * 1024 * 1024
|
||||
|
||||
|
||||
async def process_knowledge_base_document(
|
||||
|
|
@ -24,93 +24,84 @@ async def process_knowledge_base_document(
|
|||
document_id: int,
|
||||
s3_key: str,
|
||||
organization_id: int,
|
||||
created_by_provider_id: str,
|
||||
max_tokens: int = 128,
|
||||
retrieval_mode: str = "chunked",
|
||||
):
|
||||
"""Process a knowledge base document: download, chunk, embed, and store.
|
||||
"""Process a knowledge base document via MPS: download, call MPS, embed, store.
|
||||
|
||||
Args:
|
||||
ctx: ARQ context
|
||||
document_id: Database ID of the document
|
||||
s3_key: S3 key where the file is stored
|
||||
organization_id: Organization ID
|
||||
created_by_provider_id: Uploading user's provider ID (for OSS-mode auth to MPS)
|
||||
max_tokens: Maximum number of tokens per chunk (default: 128)
|
||||
retrieval_mode: "chunked" for vector search or "full_document" for full text
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting knowledge base document processing for document_id={document_id}, "
|
||||
f"s3_key={s3_key}, organization_id={organization_id}"
|
||||
f"Processing knowledge base document: document_id={document_id}, "
|
||||
f"s3_key={s3_key}, org={organization_id}, mode={retrieval_mode}"
|
||||
)
|
||||
|
||||
temp_file_path = None
|
||||
|
||||
try:
|
||||
# Update status to processing
|
||||
await db_client.update_document_status(document_id, "processing")
|
||||
|
||||
# Extract file extension from S3 key
|
||||
filename = s3_key.split("/")[-1]
|
||||
file_extension = (
|
||||
os.path.splitext(filename)[1] or ".bin"
|
||||
) # Default to .bin if no extension
|
||||
file_extension = os.path.splitext(filename)[1] or ".bin"
|
||||
|
||||
# Create temp file for download with correct extension
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
# Download file from S3
|
||||
logger.info(f"Downloading file from S3: {s3_key}")
|
||||
download_success = await storage_fs.adownload_file(s3_key, temp_file_path)
|
||||
|
||||
if not download_success:
|
||||
raise Exception(f"Failed to download file from S3: {s3_key}")
|
||||
|
||||
if not os.path.exists(temp_file_path):
|
||||
raise FileNotFoundError(f"Downloaded file not found: {temp_file_path}")
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.info(f"Downloaded file size: {file_size} bytes")
|
||||
|
||||
# Validate file size (max 5MB)
|
||||
max_file_size = 5 * 1024 * 1024
|
||||
if file_size > max_file_size:
|
||||
error_message = f"File size ({file_size / (1024 * 1024):.1f}MB) exceeds the maximum allowed size of 5MB."
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
error_message = (
|
||||
f"File size ({file_size / (1024 * 1024):.1f}MB) exceeds the "
|
||||
f"maximum allowed size of {MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB."
|
||||
)
|
||||
logger.warning(f"Document {document_id}: {error_message}")
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=error_message
|
||||
)
|
||||
return
|
||||
|
||||
# Compute file hash and get mime type
|
||||
file_hash = db_client.compute_file_hash(temp_file_path)
|
||||
mime_type = db_client.get_mime_type(temp_file_path)
|
||||
filename = s3_key.split("/")[-1]
|
||||
|
||||
# Get document record
|
||||
document = await db_client.get_document_by_id(document_id)
|
||||
if not document:
|
||||
raise Exception(f"Document {document_id} not found")
|
||||
|
||||
# Check if a document with this hash already exists (reject duplicates)
|
||||
# Reject duplicates (same hash already ingested for this org).
|
||||
existing_doc = await db_client.get_document_by_hash(file_hash, organization_id)
|
||||
if existing_doc and existing_doc.id != document_id:
|
||||
error_message = (
|
||||
f"This file is a duplicate of '{existing_doc.filename}'. "
|
||||
f"Please delete the duplicate files and consolidate them into a single unique file before uploading."
|
||||
f"Please delete the duplicate files and consolidate them into a "
|
||||
f"single unique file before uploading."
|
||||
)
|
||||
logger.warning(
|
||||
f"Duplicate document detected: {document_id} is duplicate of {existing_doc.id} "
|
||||
f"({existing_doc.filename})"
|
||||
f"Duplicate document detected: {document_id} is duplicate of "
|
||||
f"{existing_doc.id} ({existing_doc.filename})"
|
||||
)
|
||||
# Update file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
file_hash=file_hash,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
# Mark as failed with duplicate error message
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"failed",
|
||||
|
|
@ -122,7 +113,6 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
return
|
||||
|
||||
# Update document with file metadata
|
||||
await db_client.update_document_metadata(
|
||||
document_id,
|
||||
file_size_bytes=file_size,
|
||||
|
|
@ -130,52 +120,35 @@ async def process_knowledge_base_document(
|
|||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
# Full document mode: extract text and store it, skip chunking/embedding
|
||||
logger.info(f"Delegating document processing to MPS (mode={retrieval_mode})")
|
||||
mps_response = await mps_service_key_client.process_document(
|
||||
file_path=temp_file_path,
|
||||
filename=filename,
|
||||
content_type=mime_type or "application/octet-stream",
|
||||
retrieval_mode=retrieval_mode,
|
||||
max_tokens=max_tokens,
|
||||
organization_id=organization_id,
|
||||
created_by=created_by_provider_id,
|
||||
)
|
||||
|
||||
docling_metadata = mps_response.get("docling_metadata", {})
|
||||
|
||||
if retrieval_mode == "full_document":
|
||||
logger.info(f"Document {document_id}: full_document mode, extracting text")
|
||||
|
||||
plain_text_extensions = {".txt", ".json"}
|
||||
if file_extension.lower() in plain_text_extensions:
|
||||
with open(temp_file_path, "r", encoding="utf-8") as f:
|
||||
full_text = f.read()
|
||||
if file_extension.lower() == ".json":
|
||||
try:
|
||||
parsed = json.loads(full_text)
|
||||
full_text = json.dumps(parsed, indent=2, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
docling_metadata = {"document_type": "PlainText"}
|
||||
else:
|
||||
converter = DocumentConverter()
|
||||
conversion_result = converter.convert(temp_file_path)
|
||||
doc = conversion_result.document
|
||||
full_text = doc.export_to_text()
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Store full text on the document record
|
||||
full_text = mps_response.get("full_text") or ""
|
||||
await db_client.update_document_full_text(document_id, full_text)
|
||||
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=0,
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed full_document {document_id}. "
|
||||
f"Text length: {len(full_text)} chars"
|
||||
)
|
||||
return
|
||||
|
||||
# Initialize the OpenAI embedding service
|
||||
logger.info(
|
||||
f"Initializing OpenAI embedding service with max_tokens={max_tokens}"
|
||||
)
|
||||
# Try to get user's embeddings configuration
|
||||
# Chunked mode: fetch user embedding config, embed via OpenAI, persist chunks.
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
|
|
@ -187,7 +160,6 @@ async def process_knowledge_base_document(
|
|||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
logger.info(f"Using user embeddings config: model={embeddings_model}")
|
||||
|
||||
# Check if API key is configured
|
||||
if not embeddings_api_key:
|
||||
error_message = (
|
||||
"OpenAI API key not configured. Please set your API key in "
|
||||
|
|
@ -199,190 +171,57 @@ async def process_knowledge_base_document(
|
|||
)
|
||||
return
|
||||
|
||||
service = OpenAIEmbeddingService(
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
max_tokens=max_tokens,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
# Step 1: Initialize tokenizer for chunking
|
||||
logger.info(
|
||||
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
|
||||
)
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
|
||||
tokenizer = HuggingFaceTokenizer(
|
||||
tokenizer=hf_tokenizer,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
mps_chunks = mps_response.get("chunks", [])
|
||||
if not mps_chunks:
|
||||
logger.warning(f"Document {document_id}: MPS returned zero chunks")
|
||||
|
||||
chunk_texts = []
|
||||
chunk_records = []
|
||||
token_counts = []
|
||||
|
||||
# Check if file is a plain text format that docling doesn't support
|
||||
plain_text_extensions = {".txt", ".json"}
|
||||
if file_extension.lower() in plain_text_extensions:
|
||||
# Read text content directly
|
||||
logger.info(f"Reading {file_extension} file directly (bypassing docling)")
|
||||
with open(temp_file_path, "r", encoding="utf-8") as f:
|
||||
raw_content = f.read()
|
||||
|
||||
# For JSON files, pretty-print for better readability
|
||||
if file_extension.lower() == ".json":
|
||||
try:
|
||||
parsed = json.loads(raw_content)
|
||||
raw_content = json.dumps(parsed, indent=2, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"JSON file is not valid JSON, treating as plain text"
|
||||
)
|
||||
|
||||
docling_metadata = {
|
||||
"num_pages": None,
|
||||
"document_type": "PlainText",
|
||||
}
|
||||
|
||||
# Token-based chunking for plain text
|
||||
tokens = hf_tokenizer.encode(raw_content, add_special_tokens=False)
|
||||
total_tokens = len(tokens)
|
||||
logger.info(
|
||||
f"Total tokens in file: {total_tokens}, chunking with max_tokens={max_tokens}"
|
||||
chunk_texts = []
|
||||
for chunk in mps_chunks:
|
||||
contextualized = chunk.get("contextualized_text") or chunk["chunk_text"]
|
||||
chunk_records.append(
|
||||
KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk["chunk_text"],
|
||||
contextualized_text=contextualized,
|
||||
chunk_index=chunk["chunk_index"],
|
||||
chunk_metadata=chunk.get("chunk_metadata") or {},
|
||||
embedding_model=embedding_service.get_model_id(),
|
||||
embedding_dimension=embedding_service.get_embedding_dimension(),
|
||||
token_count=chunk.get("token_count", 0),
|
||||
)
|
||||
)
|
||||
chunk_texts.append(contextualized)
|
||||
|
||||
start = 0
|
||||
chunk_index = 0
|
||||
while start < total_tokens:
|
||||
end = min(start + max_tokens, total_tokens)
|
||||
chunk_token_ids = tokens[start:end]
|
||||
chunk_text = hf_tokenizer.decode(
|
||||
chunk_token_ids, skip_special_tokens=True
|
||||
)
|
||||
|
||||
token_count = len(chunk_token_ids)
|
||||
token_counts.append(token_count)
|
||||
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=chunk_text,
|
||||
chunk_index=chunk_index,
|
||||
chunk_metadata={},
|
||||
embedding_model=service.get_model_id(),
|
||||
embedding_dimension=service.get_embedding_dimension(),
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(chunk_text)
|
||||
chunk_index += 1
|
||||
start = end
|
||||
|
||||
total_chunks = len(chunk_records)
|
||||
logger.info(f"Generated {total_chunks} chunks from plain text")
|
||||
|
||||
else:
|
||||
# Use docling for structured formats (PDF, DOCX, etc.)
|
||||
logger.info("Converting document with docling")
|
||||
converter = DocumentConverter()
|
||||
conversion_result = converter.convert(temp_file_path)
|
||||
doc = conversion_result.document
|
||||
|
||||
docling_metadata = {
|
||||
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
|
||||
"document_type": type(doc).__name__,
|
||||
}
|
||||
|
||||
# Initialize chunker
|
||||
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
|
||||
chunker = HybridChunker(tokenizer=tokenizer)
|
||||
|
||||
# Chunk the document
|
||||
logger.info(f"Chunking document with max_tokens={max_tokens}")
|
||||
chunks = list(chunker.chunk(dl_doc=doc))
|
||||
total_chunks = len(chunks)
|
||||
logger.info(f"Generated {total_chunks} chunks")
|
||||
|
||||
# Process each chunk
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_text = chunk.text
|
||||
contextualized_text = chunker.contextualize(chunk=chunk)
|
||||
|
||||
text_to_tokenize = (
|
||||
contextualized_text if contextualized_text else chunk_text
|
||||
)
|
||||
token_count = len(
|
||||
tokenizer.tokenizer.encode(
|
||||
text_to_tokenize, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
token_counts.append(token_count)
|
||||
|
||||
chunk_metadata = {}
|
||||
if hasattr(chunk, "meta") and chunk.meta:
|
||||
chunk_metadata = {
|
||||
"doc_items": (
|
||||
[str(item) for item in chunk.meta.doc_items]
|
||||
if hasattr(chunk.meta, "doc_items")
|
||||
else []
|
||||
),
|
||||
"headings": (
|
||||
chunk.meta.headings
|
||||
if hasattr(chunk.meta, "headings")
|
||||
else []
|
||||
),
|
||||
}
|
||||
|
||||
chunk_record = KnowledgeBaseChunkModel(
|
||||
document_id=document_id,
|
||||
organization_id=organization_id,
|
||||
chunk_text=chunk_text,
|
||||
contextualized_text=contextualized_text,
|
||||
chunk_index=i,
|
||||
chunk_metadata=chunk_metadata,
|
||||
embedding_model=service.get_model_id(),
|
||||
embedding_dimension=service.get_embedding_dimension(),
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
chunk_records.append(chunk_record)
|
||||
chunk_texts.append(text_to_tokenize)
|
||||
|
||||
# Log chunk statistics
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
min_tokens = min(token_counts)
|
||||
max_tokens_actual = max(token_counts)
|
||||
logger.info("Chunk token statistics:")
|
||||
logger.info(f" - Average: {avg_tokens:.1f} tokens")
|
||||
logger.info(f" - Min: {min_tokens} tokens")
|
||||
logger.info(f" - Max: {max_tokens_actual} tokens")
|
||||
|
||||
# Step 6: Generate embeddings using OpenAI
|
||||
logger.info(f"Generating embeddings using {service.get_model_id()}")
|
||||
embeddings = await service.embed_texts(chunk_texts)
|
||||
|
||||
# Step 7: Attach embeddings to chunk records
|
||||
logger.info(
|
||||
f"Generating embeddings for {len(chunk_texts)} chunks "
|
||||
f"using {embedding_service.get_model_id()}"
|
||||
)
|
||||
embeddings = await embedding_service.embed_texts(chunk_texts)
|
||||
for chunk_record, embedding in zip(chunk_records, embeddings):
|
||||
chunk_record.embedding = embedding
|
||||
|
||||
# Step 8: Save chunks in database
|
||||
logger.info("Storing chunks in database")
|
||||
await db_client.create_chunks_batch(chunk_records)
|
||||
|
||||
# Step 9: Update document status to completed
|
||||
await db_client.update_document_status(
|
||||
document_id,
|
||||
"completed",
|
||||
total_chunks=total_chunks,
|
||||
total_chunks=len(chunk_records),
|
||||
docling_metadata=docling_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed knowledge base document {document_id}. "
|
||||
f"Total chunks: {total_chunks}"
|
||||
f"Total chunks: {len(chunk_records)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -390,14 +229,12 @@ async def process_knowledge_base_document(
|
|||
f"Error processing knowledge base document {document_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Update document status to failed
|
||||
await db_client.update_document_status(
|
||||
document_id, "failed", error_message=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,19 @@ from typing import Any, Dict, Optional
|
|||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.constants import BACKEND_API_ENDPOINT
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.pipecat.tracing_config import register_org_langfuse_credentials
|
||||
from api.services.workflow.dto import (
|
||||
QANodeData,
|
||||
QARFNode,
|
||||
WebhookNodeData,
|
||||
WebhookRFNode,
|
||||
)
|
||||
from api.services.workflow.qa import run_per_node_qa_analysis
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
from api.utils.template_renderer import render_template
|
||||
|
|
@ -19,34 +26,34 @@ from pipecat.utils.run_context import set_current_org_id, set_current_run_id
|
|||
|
||||
|
||||
def _should_skip_qa(
|
||||
node_data: dict,
|
||||
qa_data: QANodeData,
|
||||
workflow_run: WorkflowRunModel,
|
||||
) -> str | None:
|
||||
"""Check whether QA analysis should be skipped for this call.
|
||||
|
||||
Returns a reason string if the call should be skipped, or None if it should proceed.
|
||||
"""
|
||||
# Check minimum call duration
|
||||
min_duration = node_data.get("qa_min_call_duration", 15)
|
||||
usage_info = workflow_run.usage_info or {}
|
||||
call_duration = usage_info.get("call_duration_seconds")
|
||||
if call_duration is not None and call_duration < min_duration:
|
||||
return f"call duration ({call_duration:.1f}s) below minimum ({min_duration}s)"
|
||||
if call_duration is not None and call_duration < qa_data.qa_min_call_duration:
|
||||
return (
|
||||
f"call duration ({call_duration:.1f}s) below minimum "
|
||||
f"({qa_data.qa_min_call_duration}s)"
|
||||
)
|
||||
|
||||
# Check voicemail calls
|
||||
qa_voicemail_calls = node_data.get("qa_voicemail_calls", False)
|
||||
if not qa_voicemail_calls:
|
||||
if not qa_data.qa_voicemail_calls:
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
call_disposition = gathered_context.get("call_disposition", "")
|
||||
if call_disposition == EndTaskReason.VOICEMAIL_DETECTED.value:
|
||||
return "voicemail call and QA voicemail calls is disabled"
|
||||
|
||||
# Check sample rate
|
||||
sample_rate = node_data.get("qa_sample_rate", 100)
|
||||
if sample_rate < 100:
|
||||
if qa_data.qa_sample_rate < 100:
|
||||
roll = random.randint(1, 100)
|
||||
if roll > sample_rate:
|
||||
return f"excluded by sampling ({sample_rate}% sample rate, rolled {roll})"
|
||||
if roll > qa_data.qa_sample_rate:
|
||||
return (
|
||||
f"excluded by sampling ({qa_data.qa_sample_rate}% sample rate, "
|
||||
f"rolled {roll})"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -66,15 +73,22 @@ async def _run_qa_nodes(
|
|||
results: Dict[str, Any] = {}
|
||||
|
||||
for node in qa_nodes:
|
||||
node_data = node.get("data", {})
|
||||
node_id = node.get("id", "unknown")
|
||||
node_name = node_data.get("name", "QA Analysis")
|
||||
try:
|
||||
qa_node = QARFNode.model_validate(node)
|
||||
except ValidationError as e:
|
||||
logger.warning(f"QA node #{node_id} failed validation, skipping: {e}")
|
||||
results[f"qa_{node_id}"] = {"error": "validation_failed"}
|
||||
continue
|
||||
|
||||
if not node_data.get("qa_enabled", True):
|
||||
qa_data = qa_node.data
|
||||
node_name = qa_data.name
|
||||
|
||||
if not qa_data.qa_enabled:
|
||||
logger.debug(f"QA node '{node_name}' is disabled, skipping")
|
||||
continue
|
||||
|
||||
skip_reason = _should_skip_qa(node_data, workflow_run)
|
||||
skip_reason = _should_skip_qa(qa_data, workflow_run)
|
||||
if skip_reason:
|
||||
logger.info(f"Skipping QA node '{node_name}' (#{node_id}): {skip_reason}")
|
||||
results[f"qa_{node_id}"] = {"skipped": True, "reason": skip_reason}
|
||||
|
|
@ -83,7 +97,7 @@ async def _run_qa_nodes(
|
|||
try:
|
||||
logger.info(f"Running QA analysis for node '{node_name}' (#{node_id})")
|
||||
result = await run_per_node_qa_analysis(
|
||||
node_data,
|
||||
qa_data,
|
||||
workflow_run,
|
||||
workflow_run_id,
|
||||
workflow_definition,
|
||||
|
|
@ -260,7 +274,16 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
|
||||
# Step 8: Execute each webhook node
|
||||
for node in webhook_nodes:
|
||||
webhook_data = node.get("data", {})
|
||||
node_id = node.get("id", "unknown")
|
||||
try:
|
||||
webhook_node = WebhookRFNode.model_validate(node)
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f"Webhook node #{node_id} failed validation, skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
webhook_data = webhook_node.data
|
||||
try:
|
||||
await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
|
|
@ -268,10 +291,7 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
organization_id=organization_id,
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but continue with other webhooks
|
||||
logger.warning(
|
||||
f"Failed to execute webhook '{webhook_data.get('name', 'unknown')}': {e}"
|
||||
)
|
||||
logger.warning(f"Failed to execute webhook '{webhook_data.name}': {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations: {e}", exc_info=True)
|
||||
|
|
@ -323,7 +343,7 @@ def _build_render_context(
|
|||
|
||||
|
||||
async def _execute_webhook_node(
|
||||
webhook_data: Dict[str, Any],
|
||||
webhook_data: WebhookNodeData,
|
||||
render_context: Dict[str, Any],
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
|
|
@ -331,31 +351,27 @@ async def _execute_webhook_node(
|
|||
Execute a single webhook node.
|
||||
|
||||
Args:
|
||||
webhook_data: The webhook node's data dict from workflow definition
|
||||
webhook_data: The validated webhook node data
|
||||
render_context: Context for template rendering
|
||||
organization_id: For credential lookup
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
webhook_name = webhook_data.get("name", "Unnamed Webhook")
|
||||
webhook_name = webhook_data.name
|
||||
|
||||
# 1. Check if enabled
|
||||
if not webhook_data.get("enabled", True):
|
||||
if not webhook_data.enabled:
|
||||
logger.debug(f"Webhook '{webhook_name}' is disabled, skipping")
|
||||
return True
|
||||
|
||||
# 2. Validate endpoint URL
|
||||
url = webhook_data.get("endpoint_url")
|
||||
url = webhook_data.endpoint_url
|
||||
if not url:
|
||||
logger.warning(f"Webhook '{webhook_name}' has no endpoint URL")
|
||||
return False
|
||||
|
||||
# 3. Build headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# 4. Add auth header if credential configured
|
||||
credential_uuid = webhook_data.get("credential_uuid")
|
||||
credential_uuid = webhook_data.credential_uuid
|
||||
if credential_uuid:
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
credential_uuid, organization_id
|
||||
|
|
@ -369,18 +385,13 @@ async def _execute_webhook_node(
|
|||
f"Credential {credential_uuid} not found for webhook '{webhook_name}'"
|
||||
)
|
||||
|
||||
# 5. Add custom headers
|
||||
custom_headers = webhook_data.get("custom_headers", [])
|
||||
for h in custom_headers:
|
||||
if h.get("key") and h.get("value"):
|
||||
headers[h["key"]] = h["value"]
|
||||
for h in webhook_data.custom_headers or []:
|
||||
if h.key and h.value:
|
||||
headers[h.key] = h.value
|
||||
|
||||
# 6. Render payload template
|
||||
payload_template = webhook_data.get("payload_template", {})
|
||||
payload = render_template(payload_template, render_context)
|
||||
payload = render_template(webhook_data.payload_template or {}, render_context)
|
||||
|
||||
# 7. Make HTTP request
|
||||
method = webhook_data.get("http_method", "POST").upper()
|
||||
method = (webhook_data.http_method or "POST").upper()
|
||||
|
||||
logger.info(f"Executing webhook '{webhook_name}': {method}")
|
||||
|
||||
|
|
|
|||
|
|
@ -14,14 +14,17 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
AgentNodeData,
|
||||
AgentRFNode,
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
ExtractionVariableDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
|
@ -252,11 +255,10 @@ def simple_workflow() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
|
|
@ -273,11 +275,10 @@ def simple_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
|
|
@ -317,11 +318,10 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
|
|
@ -338,11 +338,10 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
AgentRFNode(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
|
|
@ -358,11 +357,10 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
|
|
@ -411,11 +409,10 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
|
|
@ -432,11 +429,10 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
AgentRFNode(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
|
|
@ -444,11 +440,10 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
|
|
@ -493,11 +488,10 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
|
|
@ -506,11 +500,10 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
AgentRFNode(
|
||||
id="agent",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
prompt=AGENT_SYSTEM_PROMPT,
|
||||
allow_interrupt=False,
|
||||
|
|
@ -518,11 +511,10 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
|
|
|
|||
39
api/tests/test_display_options_evaluator.py
Normal file
39
api/tests/test_display_options_evaluator.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Golden-test parity for the display_options evaluator.
|
||||
|
||||
Both the Python `evaluate_display_options` and the TypeScript
|
||||
`evaluateDisplayOptions` (in `ui/src/components/flow/renderer/displayOptions.ts`)
|
||||
must agree on every fixture in `display_options_fixtures.json`. The TS
|
||||
side is verified by `ui/scripts/test-display-options.mjs`.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.node_specs import evaluate_display_options
|
||||
|
||||
# Shared golden fixtures: two directories above this test file, then
# services/workflow/node_specs/display_options_fixtures.json.
FIXTURES_PATH = Path(__file__).parent.parent.joinpath(
    "services",
    "workflow",
    "node_specs",
    "display_options_fixtures.json",
)
|
||||
|
||||
|
||||
def load_cases():
    """Return the list stored under the ``"cases"`` key of the fixtures file.

    The file is opened explicitly as UTF-8 so the fixtures decode the same
    way on every platform instead of depending on the locale's default
    encoding.

    Raises:
        OSError: if the file at ``FIXTURES_PATH`` cannot be opened.
        KeyError: if the JSON document has no ``"cases"`` key.
    """
    with open(FIXTURES_PATH, encoding="utf-8") as f:
        return json.load(f)["cases"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", load_cases(), ids=lambda c: c["name"])
def test_python_evaluator_matches_fixture(case):
    """Check the Python evaluator returns the fixture's expected value.

    The comparison uses identity (``is``), so the evaluator must return the
    very same object as the decoded fixture value, not merely an equal one.
    """
    rules, values, expected = (case[key] for key in ("rules", "values", "expected"))
    actual = evaluate_display_options(rules, values)
    failure_message = (
        f"{case['name']}: expected {expected}, got {actual} "
        f"for rules={rules!r} values={values!r}"
    )
    assert actual is expected, failure_message
|
||||
235
api/tests/test_dograh_sdk.py
Normal file
235
api/tests/test_dograh_sdk.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
"""Tests for the Python runtime SDK (`dograh_sdk`).
|
||||
|
||||
Uses a stub client backed by the in-process spec registry rather than
|
||||
exercising the HTTP layer — the HTTP client is a thin wrapper that's
|
||||
easier to test manually against a live server.
|
||||
|
||||
Covers:
|
||||
- Workflow builder round-trips through ReactFlowDTO validation
|
||||
- Validation errors fail at the `add()` call site
|
||||
- from_json preserves node IDs and subsequent add() doesn't collide
|
||||
- Edge labels / conditions are required
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dograh_sdk import Workflow
|
||||
from dograh_sdk._generated_models import NodeSpec
|
||||
from dograh_sdk.errors import ValidationError
|
||||
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.node_specs import all_specs, get_spec
|
||||
|
||||
|
||||
class _StubClient:
    """In-process replacement for DograhClient that reads the spec registry.

    Honors the real client's contract: `get_node_type(name)` hands back a
    `NodeSpec` Pydantic model."""

    def get_node_type(self, name: str) -> NodeSpec:
        """Look up `name` in the registry and return it as an SDK `NodeSpec`.

        Raises:
            ValueError: if the registry has no spec called `name`.
        """
        registered = get_spec(name)
        if registered is not None:
            # Dump in JSON mode, then re-validate into the SDK's own model.
            return NodeSpec.model_validate(registered.model_dump(mode="json"))
        raise ValueError(f"Unknown spec: {name}")
|
||||
|
||||
|
||||
@pytest.fixture
def client() -> _StubClient:
    """Provide a newly constructed stub client."""
    stub = _StubClient()
    return stub
|
||||
|
||||
|
||||
# ─── Builder + to_json round-trip ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_builds_minimal_workflow_and_roundtrips_through_dto(client: _StubClient):
    """A start -> end workflow built with the SDK validates as a ReactFlowDTO."""
    builder = Workflow(client=client, name="minimal")
    opening = builder.add(
        type="startCall", name="greeting", prompt="Say hi to the caller."
    )
    closing = builder.add(
        type="endCall", name="close", prompt="Thank the caller and hang up."
    )
    builder.edge(
        opening, closing, label="done", condition="When the greeting is complete"
    )

    # Wire format must validate through the backend Pydantic union — if
    # it doesn't, the SDK has silently drifted from the spec schema.
    dto = ReactFlowDTO.model_validate(builder.to_json())

    assert len(dto.nodes) == 2
    node_types = {node.type for node in dto.nodes}
    assert node_types == {"startCall", "endCall"}
    assert len(dto.edges) == 1
|
||||
|
||||
|
||||
def test_defaults_applied_from_spec(client: _StubClient):
    """Fields the caller omits are filled from the spec defaults
    (e.g., `allow_interrupt=False` on startCall)."""
    builder = Workflow(client=client, name="defaults")
    # The returned node reference is not needed; only the serialized data is.
    builder.add(type="startCall", name="greeting", prompt="hello")

    node_data = builder.to_json()["nodes"][0]["data"]
    assert node_data["allow_interrupt"] is False  # spec default
    assert node_data["add_global_prompt"] is True  # spec default
|
||||
|
||||
|
||||
def test_webhook_complex_fields_validate(client: _StubClient):
    """A webhook's json and fixed_collection (custom_headers) fields survive
    the round-trip through the DTO."""
    builder = Workflow(client=client, name="wh")
    webhook_fields = {
        "type": "webhook",
        "name": "notify",
        "enabled": True,
        "http_method": "POST",
        "endpoint_url": "https://api.example.com/hook",
        "custom_headers": [{"key": "X-Source", "value": "dograh"}],
        "payload_template": {"run": "{{workflow_run_id}}"},
    }
    builder.add(**webhook_fields)

    # Webhook has no incoming/outgoing graph requirements — render as a
    # standalone node in the graph for the DTO round-trip.
    ReactFlowDTO.model_validate(builder.to_json())
|
||||
|
||||
|
||||
# ─── Validation errors at call site ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_field_raises_at_add(client: _StubClient):
    """A misspelled keyword is rejected by `add()` itself."""
    builder = Workflow(client=client, name="typo")
    fields = {
        "type": "startCall",
        "name": "greeting",
        "prompt": "hi",
        "promt": "typo",  # extra misspelled field
    }
    with pytest.raises(ValidationError, match="unknown field"):
        builder.add(**fields)
|
||||
|
||||
|
||||
def test_missing_required_raises_at_add(client: _StubClient):
    """Leaving out startCall's `prompt` makes `add()` raise."""
    builder = Workflow(client=client, name="missing")
    expect_failure = pytest.raises(ValidationError, match="required field missing")
    with expect_failure:
        builder.add(type="startCall", name="greeting")  # no prompt
|
||||
|
||||
|
||||
def test_wrong_scalar_type_raises(client: _StubClient):
    """A string passed for the boolean `allow_interrupt` field is rejected."""
    builder = Workflow(client=client, name="wrongtype")
    fields = {
        "type": "agentNode",
        "name": "x",
        "prompt": "y",
        "allow_interrupt": "yes",
    }
    with pytest.raises(ValidationError, match="expected boolean"):
        builder.add(**fields)
|
||||
|
||||
|
||||
def test_invalid_options_value_raises(client: _StubClient):
    """A `greeting_type` outside the allowed options is rejected."""
    builder = Workflow(client=client, name="wrongenum")
    fields = {
        "type": "startCall",
        "name": "greeting",
        "prompt": "hi",
        "greeting_type": "video",  # only text|audio allowed
    }
    with pytest.raises(ValidationError, match="not in allowed"):
        builder.add(**fields)
|
||||
|
||||
|
||||
def test_unknown_node_type_raises(client: _StubClient):
    """An unregistered node type surfaces the stub client's ValueError."""
    builder = Workflow(client=client, name="x")
    expect_failure = pytest.raises(ValueError, match="Unknown spec")
    with expect_failure:
        builder.add(type="nonExistentType", name="x")
|
||||
|
||||
|
||||
def test_validation_error_surfaces_llm_hint(client: _StubClient):
    """A property's `llm_hint` is echoed in the validation error so LLMs can
    self-correct on retry. `tool_uuids` on agentNode has the hint
    'List of tool UUIDs from `list_tools`.'"""
    builder = Workflow(client=client, name="hint")
    fields = {
        "type": "agentNode",
        "name": "x",
        "prompt": "y",
        "tool_uuids": "single-uuid-not-a-list",  # wrong shape: str, not list
    }
    with pytest.raises(ValidationError) as exc_info:
        builder.add(**fields)

    message = str(exc_info.value)
    for fragment in ("tool_uuids", "Hint:", "list_tools"):
        assert fragment in message
|
||||
|
||||
|
||||
def test_no_hint_message_when_spec_has_none(client: _StubClient):
    """When the property has no `llm_hint`, the error carries no dangling
    'Hint:' line."""
    builder = Workflow(client=client, name="no-hint")
    with pytest.raises(ValidationError) as exc_info:
        builder.add(type="agentNode", name="x", prompt="y", allow_interrupt="yes")

    message = str(exc_info.value)
    assert "Hint:" not in message
|
||||
|
||||
|
||||
def test_edge_requires_label_and_condition(client: _StubClient):
    """`edge()` rejects an empty label and, separately, an empty condition."""
    builder = Workflow(client=client, name="edge")
    a = builder.add(type="startCall", name="a", prompt="hi")
    b = builder.add(type="endCall", name="b", prompt="bye")

    # (label, condition, expected error fragment), checked in this order.
    bad_inputs = (
        ("", "condition", "label is required"),
        ("label", "", "condition is required"),
    )
    for label, condition, expected_error in bad_inputs:
        with pytest.raises(ValidationError, match=expected_error):
            builder.edge(a, b, label=label, condition=condition)
|
||||
|
||||
|
||||
# ─── Round-trip from_json → edit → to_json ────────────────────────────────
|
||||
|
||||
|
||||
def test_from_json_preserves_ids_and_next_id_doesnt_collide(client: _StubClient):
    """Reloading keeps node IDs, and a later `add()` allocates a larger one."""
    original = Workflow(client=client, name="w0")
    start = original.add(type="startCall", name="g", prompt="hi")
    end = original.add(type="endCall", name="e", prompt="bye")
    original.edge(start, end, label="done", condition="done")

    reloaded = Workflow.from_json(
        original.to_json(), client=client, name="w0-reload"
    )

    # IDs are preserved
    reloaded_ids = [node.id for node in reloaded._nodes]
    assert reloaded_ids == [start.id, end.id]

    # Next add() gets a fresh ID, not colliding with the existing ones
    added = reloaded.add(type="agentNode", name="qualify", prompt="ask stuff")
    assert added.id != start.id
    assert added.id != end.id
    highest_existing = max(int(start.id), int(end.id))
    assert int(added.id) > highest_existing
|
||||
|
||||
|
||||
def test_from_json_validates_data(client: _StubClient):
    """`from_json` raises on a payload whose node data has an unexpected
    field — drift is not silently accepted."""
    node_with_extra_field = {
        "id": "1",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {"name": "g", "prompt": "hi", "bogus": 1},
    }
    bad = {"nodes": [node_with_extra_field], "edges": []}

    expect_failure = pytest.raises(ValidationError, match="unknown field")
    with expect_failure:
        Workflow.from_json(bad, client=client)
|
||||
|
||||
|
||||
# ─── Sanity: all registered specs are reachable by name ───────────────────
|
||||
|
||||
|
||||
def test_every_registered_spec_is_reachable_by_sdk(client: _StubClient):
    """Every spec returned by `all_specs()` can be fetched by name via the stub."""
    # The workflow is only constructed, never populated.
    wf = Workflow(client=client, name="probe")
    for registered in all_specs():
        # Just fetch the spec via the client; doesn't add anything. This
        # ensures the `_StubClient` wiring works for all types.
        fetched = client.get_node_type(registered.name)
        assert fetched.name == registered.name
    _ = wf
|
||||
128
api/tests/test_dograh_sdk_typed.py
Normal file
128
api/tests/test_dograh_sdk_typed.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
"""Tests for the typed SDK (`dograh_sdk.typed`).
|
||||
|
||||
Covers:
|
||||
- Generated classes import cleanly and declare the correct spec name
|
||||
- `Workflow.add_typed(node)` produces the same wire format as
|
||||
`Workflow.add(type=..., **kwargs)`
|
||||
- Typed-class construction respects required/optional field defaults
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dograh_sdk import Workflow
|
||||
from dograh_sdk._generated_models import NodeSpec
|
||||
from dograh_sdk.typed import (
|
||||
AgentNode,
|
||||
EndCall,
|
||||
GlobalNode,
|
||||
Qa,
|
||||
StartCall,
|
||||
Trigger,
|
||||
TypedNode,
|
||||
Webhook,
|
||||
)
|
||||
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.node_specs import get_spec
|
||||
|
||||
|
||||
class _StubClient:
    """Stand-in for the SDK client, backed by the in-process spec registry."""

    def get_node_type(self, name: str) -> NodeSpec:
        """Return the registry spec for `name` as an SDK `NodeSpec` model.

        Raises:
            ValueError: if no spec is registered under `name`. Without this
                guard a missing spec would surface as an AttributeError on
                `None`, hiding the real cause.
        """
        spec = get_spec(name)
        if spec is None:
            raise ValueError(f"Unknown spec: {name}")
        # Dump in JSON mode, then re-validate into the SDK's own model.
        return NodeSpec.model_validate(spec.model_dump(mode="json"))
|
||||
|
||||
|
||||
@pytest.fixture
def client() -> _StubClient:
    """Provide a newly constructed stub client."""
    stub = _StubClient()
    return stub
|
||||
|
||||
|
||||
# ─── Generated classes declare the correct discriminator ──────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "cls,expected_type",
    [
        (StartCall, "startCall"),
        (AgentNode, "agentNode"),
        (EndCall, "endCall"),
        (GlobalNode, "globalNode"),
        (Trigger, "trigger"),
        (Webhook, "webhook"),
        (Qa, "qa"),
    ],
    ids=lambda v: v.__name__ if isinstance(v, type) else v,
)
def test_typed_class_declares_spec_name(cls: type[TypedNode], expected_type: str):
    """Each generated class, and an instance of it, exposes the spec name
    through `type`."""
    assert cls.type == expected_type

    # Constructor arguments per class; any class not listed (Qa) falls back
    # to a bare name.
    constructor_kwargs = {
        StartCall: {"name": "g", "prompt": "hi"},
        AgentNode: {"name": "a", "prompt": "hi"},
        EndCall: {"name": "e", "prompt": "hi"},
        GlobalNode: {"name": "g", "prompt": "hi"},
        Trigger: {"name": "t"},
        Webhook: {"name": "wh"},
    }
    # Instances inherit the ClassVar
    inst = cls(**constructor_kwargs.get(cls, {"name": "qa"}))
    assert inst.type == expected_type
|
||||
|
||||
|
||||
# ─── add_typed integrates with Workflow and round-trips through DTO ──────
|
||||
|
||||
|
||||
def test_add_typed_builds_valid_workflow(client: _StubClient):
    """Typed nodes added with `add_typed` serialize to a DTO-valid payload."""
    builder = Workflow(client=client, name="typed-e2e")
    opening = builder.add_typed(StartCall(name="greeting", prompt="Hi there!"))
    closing = builder.add_typed(EndCall(name="done", prompt="Bye."))
    builder.edge(opening, closing, label="done", condition="conversation over")

    payload = builder.to_json()
    dto = ReactFlowDTO.model_validate(payload)
    assert len(dto.nodes) == 2

    serialized_nodes = payload["nodes"]
    assert serialized_nodes[0]["type"] == "startCall"
    assert serialized_nodes[1]["type"] == "endCall"
|
||||
|
||||
|
||||
def test_add_typed_and_add_produce_identical_data(client: _StubClient):
    """Equivalent inputs yield the same node data whether they go through
    `add_typed` or the generic `add`."""

    def first_node_data(workflow: Workflow) -> dict:
        return workflow.to_json()["nodes"][0]["data"]

    via_typed = Workflow(client=client)
    via_typed.add_typed(AgentNode(name="q", prompt="ask"))

    via_generic = Workflow(client=client)
    via_generic.add(type="agentNode", name="q", prompt="ask")

    typed_data = first_node_data(via_typed)
    generic_data = first_node_data(via_generic)
    assert typed_data == generic_data
|
||||
|
||||
|
||||
def test_webhook_mutable_defaults_dont_share_state(client: _StubClient):
    """Two default-constructed Webhooks serialize with distinct
    `payload_template` objects."""
    builder = Workflow(client=client)
    for webhook_name in ("a", "b"):
        builder.add_typed(Webhook(name=webhook_name))

    serialized_nodes = builder.to_json()["nodes"]
    first_data = serialized_nodes[0]["data"]
    second_data = serialized_nodes[1]["data"]

    # Both instances must end up with payload_template populated from the
    # factory; mutating one must not affect the other.
    assert first_data["payload_template"] is not second_data["payload_template"]
|
||||
|
||||
|
||||
def test_typed_sdk_surfaces_spec_default_to_field(client: _StubClient):
    """Spec defaults reach the typed class fields: a StartCall built with only
    a prompt gets the `"Start Call"` name literal and the boolean defaults."""
    node = StartCall(prompt="hi")

    assert node.name == "Start Call"
    assert node.allow_interrupt is False  # matches spec default
    assert node.add_global_prompt is True
|
||||
|
|
@ -1,11 +1,98 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
|
||||
_FIXTURES_DIR = Path(__file__).parent / "definitions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dto():
|
||||
# assert no exceptions are raised
|
||||
with open("tests/definitions/rf-1.json", "r") as f:
|
||||
# Path resolved relative to this test file so the test works regardless
|
||||
# of the cwd pytest is invoked from.
|
||||
with open(_FIXTURES_DIR / "rf-1.json", "r") as f:
|
||||
dto = ReactFlowDTO.model_validate_json(f.read())
|
||||
assert dto is not None
|
||||
|
||||
|
||||
def test_sanitize_strips_ui_runtime_fields():
    """Sanitizing drops unknown keys from node/edge `data` and leaves
    top-level keys and node-level ReactFlow fields untouched."""
    start_node = {
        "id": "n1",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "width": 200,  # ReactFlow-computed, preserved
        "selected": True,  # ReactFlow runtime, preserved
        "data": {
            "name": "Start",
            "prompt": "hi",
            "greeting": "hello",
            "invalid": True,  # UI-only, should be stripped
            "validationMessage": "oops",  # UI-only, should be stripped
            "mystery_field": 42,  # unknown, should be stripped
        },
    }
    agent_node = {
        "id": "n2",
        "type": "agentNode",
        "position": {"x": 1, "y": 1},
        "data": {"name": "A", "prompt": "p", "invalid": False},
    }
    connecting_edge = {
        "id": "e1",
        "source": "n1",
        "target": "n2",
        "data": {
            "label": "next",
            "condition": "true",
            "invalid": True,  # UI-only, should be stripped
        },
    }
    definition = {
        "viewport": {"x": 0, "y": 0, "zoom": 1},
        "nodes": [start_node, agent_node],
        "edges": [connecting_edge],
    }

    out = sanitize_workflow_definition(definition)
    first_out, second_out = out["nodes"][0], out["nodes"][1]

    # Top-level keys preserved
    assert out["viewport"] == {"x": 0, "y": 0, "zoom": 1}
    # ReactFlow runtime fields on the node itself preserved
    assert first_out["width"] == 200
    assert first_out["selected"] is True

    # node.data stripped of unknowns, known fields kept
    n1_data = first_out["data"]
    assert n1_data == {"name": "Start", "prompt": "hi", "greeting": "hello"}
    for stripped_key in ("invalid", "validationMessage", "mystery_field"):
        assert stripped_key not in n1_data

    assert second_out["data"] == {"name": "A", "prompt": "p"}

    # edge.data stripped
    assert out["edges"][0]["data"] == {"label": "next", "condition": "true"}
|
||||
|
||||
|
||||
def test_sanitize_noop_on_empty_and_unknown_types():
    """`None` and `{}` come back as-is, and data on a node of an unknown type
    is left alone."""
    assert sanitize_workflow_definition(None) is None
    assert sanitize_workflow_definition({}) == {}

    # Unknown node type: pass through unchanged rather than wipe data
    unknown_node = {
        "id": "n1",
        "type": "unknownType",
        "position": {"x": 0, "y": 0},
        "data": {"anything": "goes"},
    }
    definition = {"nodes": [unknown_node], "edges": []}

    sanitized = sanitize_workflow_definition(definition)
    assert sanitized["nodes"][0]["data"] == {"anything": "goes"}
|
||||
|
|
|
|||
124
api/tests/test_layout.py
Normal file
124
api/tests/test_layout.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Tests for position reconciliation after the LLM save round-trip."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from api.services.workflow.layout import reconcile_positions
|
||||
|
||||
|
||||
def _node(
|
||||
id: str,
|
||||
type: str,
|
||||
*,
|
||||
name: str | None = None,
|
||||
x: float = 0.0,
|
||||
y: float = 0.0,
|
||||
) -> dict:
|
||||
data: dict = {}
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
return {"id": id, "type": type, "position": {"x": x, "y": y}, "data": data}
|
||||
|
||||
|
||||
def _edge(src: str, tgt: str) -> dict:
|
||||
return {
|
||||
"id": f"{src}-{tgt}",
|
||||
"source": src,
|
||||
"target": tgt,
|
||||
"data": {"label": "x", "condition": "y"},
|
||||
}
|
||||
|
||||
|
||||
def test_named_match_preserves_position():
    """A node with the same type and name as a previous node takes that
    node's position, even though the IDs differ."""
    stored = {
        "nodes": [_node("99", "startCall", name="greeting", x=100, y=200)],
        "edges": [],
    }
    incoming = {"nodes": [_node("1", "startCall", name="greeting")], "edges": []}

    reconciled = reconcile_positions(incoming, stored)

    assert reconciled["nodes"][0]["position"] == {"x": 100, "y": 200}
|
||||
|
||||
|
||||
def test_unnamed_match_by_type_ordering():
    """Unnamed nodes of one type take the previous positions in list order."""
    stored_nodes = [
        _node("7", "agentNode", x=-648, y=-158),
        _node("8", "agentNode", x=500, y=-100),
    ]
    incoming_nodes = [_node("1", "agentNode"), _node("2", "agentNode")]

    reconciled = reconcile_positions(
        {"nodes": incoming_nodes, "edges": []},
        {"nodes": stored_nodes, "edges": []},
    )

    first, second = reconciled["nodes"][0], reconciled["nodes"][1]
    assert first["position"] == {"x": -648, "y": -158}
    assert second["position"] == {"x": 500, "y": -100}
|
||||
|
||||
|
||||
def test_new_node_placed_relative_to_incoming_neighbor():
    """A node with no previous counterpart is positioned at an offset from
    the node that has an edge into it."""
    stored = {
        "nodes": [_node("99", "startCall", name="greeting", x=100, y=200)],
        "edges": [],
    }
    incoming_nodes = [
        _node("1", "startCall", name="greeting"),
        _node("2", "agentNode", name="new_node"),
    ]
    incoming = {"nodes": incoming_nodes, "edges": [_edge("1", "2")]}

    reconciled = reconcile_positions(incoming, stored)
    kept, placed = reconciled["nodes"][0], reconciled["nodes"][1]

    # Start call keeps its previous position.
    assert kept["position"] == {"x": 100, "y": 200}
    # New node offset from its incoming neighbor.
    assert placed["position"] == {"x": 500, "y": 400}
|
||||
|
||||
|
||||
def test_orphan_new_node_stays_at_origin():
    """With no previous workflow and no edges, the single node keeps (0, 0)."""
    incoming = {"nodes": [_node("1", "agentNode", name="orphan")], "edges": []}

    reconciled = reconcile_positions(incoming, None)

    assert reconciled["nodes"][0]["position"] == {"x": 0.0, "y": 0.0}
|
||||
|
||||
|
||||
def test_named_wins_over_unnamed_ordering():
    """Name matching is applied first; the leftover unnamed node then takes
    the unnamed previous node's position."""
    stored_nodes = [
        _node("7", "agentNode", x=-648, y=-158),  # unnamed
        _node("8", "agentNode", name="qualify", x=900, y=900),
    ]
    incoming_nodes = [
        _node("1", "agentNode", name="qualify"),  # matches named
        _node("2", "agentNode"),  # falls to unnamed queue
    ]

    reconciled = reconcile_positions(
        {"nodes": incoming_nodes, "edges": []},
        {"nodes": stored_nodes, "edges": []},
    )

    named, unnamed = reconciled["nodes"][0], reconciled["nodes"][1]
    assert named["position"] == {"x": 900, "y": 900}
    assert unnamed["position"] == {"x": -648, "y": -158}
|
||||
|
||||
|
||||
def test_no_previous_keeps_origin_for_all_matched_positions():
    """Without a previous workflow, the root node stays at the origin and its
    successor is offset from it."""
    incoming_nodes = [
        _node("1", "startCall", name="greeting"),
        _node("2", "agentNode", name="reply"),
    ]
    incoming = {"nodes": incoming_nodes, "edges": [_edge("1", "2")]}

    reconciled = reconcile_positions(incoming, None)
    root, successor = reconciled["nodes"][0], reconciled["nodes"][1]

    # No previous → first node stays at origin (no incoming), second
    # node placed relative to its incoming neighbor at origin.
    assert root["position"] == {"x": 0.0, "y": 0.0}
    assert successor["position"] == {"x": 400.0, "y": 200.0}
|
||||
225
api/tests/test_mcp_save_workflow.py
Normal file
225
api/tests/test_mcp_save_workflow.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
"""Integration tests for the `save_workflow` MCP tool.
|
||||
|
||||
Mocks `authenticate_mcp_request` and the db_client so tests don't need
|
||||
a live DB, but exercises the real TS validator subprocess end-to-end —
|
||||
parse is part of the contract the LLM relies on.
|
||||
|
||||
Round-trip and pure-parser tests live in `test_ts_bridge.py`; this file
|
||||
focuses on the MCP tool's error-routing, version tagging, and DB-call
|
||||
shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.tools.save_workflow import save_workflow
|
||||
|
||||
# Module-wide marker: skip every test in this file when no `node` executable
# can be found on PATH.
pytestmark = pytest.mark.skipif(
    shutil.which("node") is None,
    reason="node binary not available",
)
|
||||
|
||||
|
||||
# ─── Fixtures & helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeDraft:
|
||||
version_number: int = 2
|
||||
status: str = "draft"
|
||||
|
||||
|
||||
class _FakeWorkflowModel:
|
||||
id = 1
|
||||
organization_id = 1
|
||||
name = "test"
|
||||
# reconcile_positions reads whichever of these holds the previous
|
||||
# stored workflow JSON; None on all three is fine for a greenfield
|
||||
# test and causes reconcile_positions to fall back to the placement
|
||||
# heuristic for any new node.
|
||||
current_definition = None
|
||||
released_definition = None
|
||||
workflow_definition = None
|
||||
|
||||
|
||||
@pytest.fixture
def authed_user() -> MagicMock:
    """Mock user whose `id` and `selected_organization_id` are both 1."""
    user = MagicMock()
    user.configure_mock(selected_organization_id=1, id=1)
    return user
|
||||
|
||||
|
||||
@pytest.fixture
def mock_backends(authed_user: MagicMock):
    """Patch auth and the db_client calls in the save_workflow tool module.

    Yields:
        `(save_mock, update_mock)`: the AsyncMocks standing in for
        `db_client.save_workflow_draft` and `db_client.update_workflow`.
    """
    save_mock = AsyncMock(return_value=_FakeDraft())
    update_mock = AsyncMock(return_value=_FakeWorkflowModel())

    # Patchers are created up front and entered together below, in the same
    # order as they are listed here.
    auth_patch = patch(
        "api.mcp_server.tools.save_workflow.authenticate_mcp_request",
        AsyncMock(return_value=authed_user),
    )
    get_workflow_patch = patch(
        "api.mcp_server.tools.save_workflow.db_client.get_workflow",
        AsyncMock(return_value=_FakeWorkflowModel()),
    )
    save_draft_patch = patch(
        "api.mcp_server.tools.save_workflow.db_client.save_workflow_draft",
        save_mock,
    )
    update_patch = patch(
        "api.mcp_server.tools.save_workflow.db_client.update_workflow",
        update_mock,
    )
    draft_version_patch = patch(
        "api.mcp_server.tools.save_workflow.db_client.get_draft_version",
        AsyncMock(return_value=None),
    )

    with (
        auth_patch,
        get_workflow_patch,
        save_draft_patch,
        update_patch,
        draft_version_patch,
    ):
        yield save_mock, update_mock
|
||||
|
||||
|
||||
def _valid_code(name: str = "tool-test") -> str:
|
||||
return f'''import {{ Workflow }} from "@dograh/sdk";
|
||||
import {{ startCall, endCall }} from "@dograh/sdk/typed";
|
||||
|
||||
const wf = new Workflow({{ name: "{name}" }});
|
||||
|
||||
const greeting = wf.addTyped(startCall({{ name: "greeting", prompt: "Hi!" }}));
|
||||
const done = wf.addTyped(endCall({{ name: "done", prompt: "Bye." }}));
|
||||
|
||||
wf.edge(greeting, done, {{ label: "done", condition: "conversation complete" }});
|
||||
'''
|
||||
|
||||
|
||||
# ─── Happy path ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_happy_path_saves_draft(mock_backends):
    """Valid code under the stored name saves one draft and skips the rename."""
    save_mock, update_mock = mock_backends
    stored_name = _FakeWorkflowModel.name

    # Match the stored name so the rename branch stays dormant here.
    outcome = await save_workflow(workflow_id=1, code=_valid_code(name=stored_name))

    assert outcome["saved"] is True
    assert outcome["workflow_id"] == 1
    assert outcome["version_number"] == 2
    assert outcome["status"] == "draft"
    assert outcome["node_count"] == 2
    assert outcome["edge_count"] == 1
    assert outcome["renamed"] is False
    assert outcome["name"] == stored_name

    save_mock.assert_awaited_once()
    update_mock.assert_not_awaited()

    saved_definition = save_mock.call_args.kwargs["workflow_definition"]
    assert len(saved_definition["nodes"]) == 2
    assert len(saved_definition["edges"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rename_propagates_to_update_workflow(mock_backends):
    """A workflow name differing from the stored one triggers a single
    `update_workflow` call carrying the new name, alongside the draft save."""
    save_mock, update_mock = mock_backends

    outcome = await save_workflow(workflow_id=1, code=_valid_code(name="renamed"))

    assert outcome["saved"] is True
    assert outcome["renamed"] is True
    assert outcome["name"] == "renamed"

    update_mock.assert_awaited_once()
    update_kwargs = update_mock.call_args.kwargs
    assert update_kwargs["workflow_id"] == 1
    assert update_kwargs["name"] == "renamed"
    assert update_kwargs["workflow_definition"] is None
    assert update_kwargs["organization_id"] == 1

    save_mock.assert_awaited_once()
|
||||
|
||||
|
||||
# ─── Parse-stage rejections ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parser_rejects_disallowed_top_level(mock_backends):
    """Appending a top-level function declaration yields `parse_error` and
    nothing is written."""
    save_mock, update_mock = mock_backends
    tampered = _valid_code() + "function evil() { return 1; }\n"

    outcome = await save_workflow(workflow_id=1, code=tampered)

    assert outcome["saved"] is False
    assert outcome["error_code"] == "parse_error"
    save_mock.assert_not_awaited()
    update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parser_rejects_unknown_factory(mock_backends):
    """Calling an unrecognised node factory yields `parse_error` with an
    "Unknown node type" message and nothing is written."""
    save_mock, update_mock = mock_backends
    source = """import { Workflow } from "@dograh/sdk";
const wf = new Workflow({ name: "x" });
const n = wf.addTyped(fakeNode({ name: "x", prompt: "y" }));
"""

    outcome = await save_workflow(workflow_id=1, code=source)

    assert outcome["saved"] is False
    assert outcome["error_code"] == "parse_error"
    assert "Unknown node type" in outcome["error"]
    save_mock.assert_not_awaited()
    update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── Validation-stage rejections ─────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_unknown_field_surfaces_validation_error(mock_backends):
    """A misspelled field on a known node type yields `validation_error` with
    an "Unknown field" message and nothing is written."""
    save_mock, update_mock = mock_backends
    source = """import { Workflow } from "@dograh/sdk";
import { startCall } from "@dograh/sdk/typed";
const wf = new Workflow({ name: "x" });
const n = wf.addTyped(startCall({ name: "g", prompt: "hi", promt: "typo" }));
"""

    outcome = await save_workflow(workflow_id=1, code=source)

    assert outcome["saved"] is False
    assert outcome["error_code"] == "validation_error"
    assert "Unknown field" in outcome["error"]
    save_mock.assert_not_awaited()
    update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── Graph-stage rejections ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_validation_catches_missing_start_node(mock_backends):
    """A workflow with no start node yields `graph_validation` and nothing is
    written."""
    save_mock, update_mock = mock_backends
    # Only an end node — WorkflowGraph requires exactly one start node.
    source = """import { Workflow } from "@dograh/sdk";
import { endCall } from "@dograh/sdk/typed";
const wf = new Workflow({ name: "orphan" });
const only = wf.addTyped(endCall({ name: "only", prompt: "bye" }));
"""

    outcome = await save_workflow(workflow_id=1, code=source)

    assert outcome["saved"] is False
    assert outcome["error_code"] == "graph_validation"
    save_mock.assert_not_awaited()
    update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── Workflow not found / unauthorized ───────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_unknown_workflow_raises_404(authed_user: MagicMock):
    """``save_workflow`` must raise a 404 ``HTTPException`` for a missing workflow.

    Authentication is patched to return ``authed_user`` and the workflow
    lookup is patched to return ``None``.
    """
    auth_patch = patch(
        "api.mcp_server.tools.save_workflow.authenticate_mcp_request",
        AsyncMock(return_value=authed_user),
    )
    lookup_patch = patch(
        "api.mcp_server.tools.save_workflow.db_client.get_workflow",
        AsyncMock(return_value=None),
    )

    with auth_patch, lookup_patch, pytest.raises(HTTPException) as exc_info:
        await save_workflow(workflow_id=999, code=_valid_code())

    assert exc_info.value.status_code == 404
|
||||
196
api/tests/test_node_specs.py
Normal file
196
api/tests/test_node_specs.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"""Spec-quality lint.
|
||||
|
||||
Catches drift between NodeSpecs and the rest of the system before it lands:
|
||||
- Placeholder/empty descriptions
|
||||
- Missing examples
|
||||
- display_options referencing fields that don't exist
|
||||
- Examples that don't validate against the per-type Pydantic DTO
|
||||
- Spec name not matching a discriminator value in dto.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import NodeType, ReactFlowDTO
|
||||
from api.services.workflow.node_specs import (
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
PropertyType,
|
||||
all_specs,
|
||||
)
|
||||
|
||||
# Matches a description that is nothing but a stub token (todo, fixme, tbd,
# xxx, "...", placeholder, description, n/a or na, "?"), optionally followed
# by a single period and surrounded by whitespace. Case-insensitive.
PLACEHOLDER_DESCRIPTION_PATTERN = re.compile(
    r"^\s*(todo|fixme|tbd|xxx|\.\.\.|placeholder|description|n/?a|\?)\s*\.?\s*$",
    flags=re.I,
)
|
||||
|
||||
|
||||
def _walk_properties(props: list[PropertySpec], path: str = ""):
    """Yield ``(dotted_path, property)`` pairs, parents before their children.

    ``path`` is the dotted prefix of the enclosing property; it is empty for
    top-level properties, whose path is just their own name.
    """
    for entry in props:
        dotted = entry.name if not path else f"{path}.{entry.name}"
        yield dotted, entry
        nested = entry.properties
        if nested:
            yield from _walk_properties(nested, dotted)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Lint
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_node_spec_has_non_placeholder_description(spec: NodeSpec):
    """Each NodeSpec needs a real description of at least 20 characters.

    All three checks run on the whitespace-stripped text, so padding cannot
    satisfy the minimum-length requirement.
    """
    description = spec.description.strip()
    assert description, f"{spec.name}: empty description"
    assert not PLACEHOLDER_DESCRIPTION_PATTERN.match(description), (
        f"{spec.name}: description looks like a placeholder: {spec.description!r}"
    )
    # Measure the stripped text: surrounding whitespace carries no information.
    assert len(description) >= 20, (
        f"{spec.name}: description too short to be useful for an LLM "
        f"({len(description)} chars)"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_node_spec_has_at_least_one_example(spec: NodeSpec):
    """Each NodeSpec must ship a non-empty ``examples`` collection."""
    examples = spec.examples
    assert examples, (
        f"{spec.name}: must have at least one NodeExample "
        "so LLMs have a realistic shape to pattern-match."
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_property_descriptions_non_placeholder(spec: NodeSpec):
    """Every property, nested ones included, needs a non-stub description."""
    for dotted, field in _walk_properties(spec.properties):
        where = f"{spec.name}.{dotted}"
        assert field.description.strip(), f"{where}: empty description"
        assert not PLACEHOLDER_DESCRIPTION_PATTERN.match(field.description), (
            f"{where}: description looks like a placeholder: {field.description!r}"
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_display_options_reference_real_fields(spec: NodeSpec):
    """The keys of ``display_options.show`` / ``.hide`` must name sibling
    properties. Nested properties are checked against the other properties
    of the same parent."""

    def _verify_scope(siblings: list[PropertySpec], prefix: str = "") -> None:
        known = {sibling.name for sibling in siblings}
        for field in siblings:
            here = field.name if not prefix else f"{prefix}.{field.name}"
            opts = field.display_options
            if opts:
                # Union of the field names used as show and hide conditions.
                referenced = {*(opts.show or {}), *(opts.hide or {})}
                unknown = referenced - known
                assert not unknown, (
                    f"{spec.name}.{here}: display_options references "
                    f"unknown sibling fields: {sorted(unknown)}"
                )
            if field.properties:
                _verify_scope(field.properties, here)

    _verify_scope(spec.properties)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_options_properties_have_options(spec: NodeSpec):
    """``options`` and ``multi_options`` properties must list their choices."""
    for dotted, field in _walk_properties(spec.properties):
        if field.type not in (PropertyType.options, PropertyType.multi_options):
            continue
        assert field.options, (
            f"{spec.name}.{dotted}: type={field.type.value} requires at "
            "least one PropertyOption."
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_fixed_collection_has_sub_properties(spec: NodeSpec):
    """``fixed_collection`` properties must declare nested ``properties``."""
    for dotted, field in _walk_properties(spec.properties):
        if field.type != PropertyType.fixed_collection:
            continue
        assert field.properties, (
            f"{spec.name}.{dotted}: fixed_collection requires nested "
            "`properties` describing each row."
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_spec_name_matches_dto_discriminator(spec: NodeSpec):
    """A spec's name must equal the value of some ``NodeType`` member."""
    discriminators = {member.value for member in NodeType}
    assert spec.name in discriminators, (
        f"NodeSpec {spec.name!r} doesn't match any NodeType discriminator. "
        f"Valid: {sorted(discriminators)}"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
def test_examples_validate_against_dto(spec: NodeSpec):
    """Run every NodeExample of ``spec`` through ``ReactFlowDTO`` validation.

    Each example is wrapped in a wire-format node. When the spec's graph
    constraints ask for at least one outgoing or incoming edge, a synthetic
    ``endCall`` / ``startCall`` peer and a connecting edge are added. Any
    error raised by ``model_validate`` fails the test.
    """
    for example in spec.examples:
        graph: dict = {
            "nodes": [
                {
                    "id": "example",
                    "type": spec.name,
                    "position": {"x": 0, "y": 0},
                    "data": example.data,
                }
            ],
            "edges": [],
        }
        limits = spec.graph_constraints

        # Synthetic downstream peer when an outgoing edge is mandatory.
        if limits and (limits.min_outgoing or 0) > 0:
            graph["nodes"].append(
                {
                    "id": "downstream",
                    "type": "endCall",
                    "position": {"x": 0, "y": 0},
                    "data": {"name": "End", "prompt": "End", "is_end": True},
                }
            )
            graph["edges"].append(
                {
                    "id": "e_out",
                    "source": "example",
                    "target": "downstream",
                    "data": {"label": "next", "condition": "next"},
                }
            )

        # Synthetic upstream peer when an incoming edge is mandatory.
        if limits and (limits.min_incoming or 0) > 0:
            graph["nodes"].append(
                {
                    "id": "upstream",
                    "type": "startCall",
                    "position": {"x": 0, "y": 0},
                    "data": {"name": "Start", "prompt": "Hello", "is_start": True},
                }
            )
            graph["edges"].append(
                {
                    "id": "e_in",
                    "source": "upstream",
                    "target": "example",
                    "data": {"label": "in", "condition": "in"},
                }
            )

        # A raise here means the example no longer matches the DTO schema.
        ReactFlowDTO.model_validate(graph)
|
||||
|
||||
|
||||
def test_all_dto_types_have_specs():
    """Every ``NodeType`` value needs a registered NodeSpec, so a node type
    added to dto.py without a matching spec is caught here."""
    registered = {spec.name for spec in all_specs()}
    unspecced = sorted(
        member.value for member in NodeType if member.value not in registered
    )
    assert not unspecced, f"NodeType discriminators without specs: {unspecced}"
|
||||
|
|
@ -27,12 +27,13 @@ import pytest
|
|||
from api.enums import ToolCategory
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
|
@ -1013,11 +1014,10 @@ class TestEndCallExtractionBehavior:
|
|||
# Create a workflow where start node has NO extraction
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
|
|
@ -1026,11 +1026,10 @@ class TestEndCallExtractionBehavior:
|
|||
extraction_enabled=False, # No extraction
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
|
|
|
|||
99
api/tests/test_sdk_sync.py
Normal file
99
api/tests/test_sdk_sync.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""Drift guard: committed SDK typed files must match what codegen
|
||||
produces from the current `node_specs/` registry.
|
||||
|
||||
Fails loudly if a spec was edited without running
|
||||
`./scripts/generate_sdk.sh`. CI also runs the full script and asserts
|
||||
an empty `git diff` as the authoritative cross-language check; this
|
||||
test is the fast local feedback loop inside pytest.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure the Python SDK package is importable without requiring a
|
||||
# `pip install -e sdk/python`. The codegen lives there because it ships
|
||||
# with the SDK wheel, but tests need to reach it directly.
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
SDK_PY_SRC = REPO_ROOT / "sdk" / "python" / "src"
|
||||
if str(SDK_PY_SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SDK_PY_SRC))
|
||||
|
||||
from dograh_sdk.codegen import generate_all # noqa: E402
|
||||
|
||||
from api.services.workflow.node_specs import SPEC_VERSION, all_specs # noqa: E402
|
||||
|
||||
PY_OUT = REPO_ROOT / "sdk" / "python" / "src" / "dograh_sdk" / "typed"
|
||||
TS_OUT = REPO_ROOT / "sdk" / "typescript" / "src" / "typed"
|
||||
TS_CODEGEN = REPO_ROOT / "sdk" / "typescript" / "scripts" / "codegen.mts"
|
||||
REGEN_HINT = "Run ./scripts/generate_sdk.sh to regenerate."
|
||||
|
||||
|
||||
def _specs_payload() -> dict:
    """Return the spec version plus each registered spec dumped in JSON mode."""
    serialized = [spec.model_dump(mode="json") for spec in all_specs()]
    return {"spec_version": SPEC_VERSION, "node_types": serialized}
|
||||
|
||||
|
||||
def _compare_trees(expected_dir: Path, actual_dir: Path, *, skip: set[str]) -> None:
    """Fail the test unless both directories hold the same files with equal text.

    Only regular files directly inside each directory are compared;
    sub-directories are ignored.

    Args:
        expected_dir: Directory with the committed files.
        actual_dir: Directory with the freshly generated files.
        skip: File names to leave out of the comparison on both sides.

    Calls ``pytest.fail`` when the file sets differ, or with the list of every
    file whose content differs.
    """

    def tree(d: Path) -> dict[str, str]:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        return {
            p.name: p.read_text(encoding="utf-8")
            for p in d.iterdir()
            if p.is_file() and p.name not in skip
        }

    expected = tree(expected_dir)
    actual = tree(actual_dir)

    if expected.keys() != actual.keys():
        pytest.fail(
            f"File set differs in {expected_dir.name}/.\n"
            f"  committed: {sorted(expected)}\n"
            f"  generated: {sorted(actual)}\n"
            f"{REGEN_HINT}"
        )

    # Report every drifted file at once rather than stopping at the first.
    stale = sorted(name for name, text in expected.items() if text != actual[name])
    if stale:
        listing = ", ".join(f"{expected_dir.name}/{name}" for name in stale)
        pytest.fail(f"{listing} is out of sync with node_specs. {REGEN_HINT}")
|
||||
|
||||
|
||||
def test_python_sdk_typed_in_sync(tmp_path: Path) -> None:
    """Generate the Python typed files into ``tmp_path`` and compare them
    with the committed ones under ``PY_OUT``."""
    node_types = _specs_payload()["node_types"]
    generate_all(node_types, tmp_path)

    # _base.py is hand-written and lives alongside generated files.
    excluded = {"_base.py", "__pycache__"}
    _compare_trees(PY_OUT, tmp_path, skip=excluded)
|
||||
|
||||
|
||||
@pytest.mark.skipif(shutil.which("node") is None, reason="node binary not available")
def test_typescript_sdk_typed_in_sync(tmp_path: Path) -> None:
    """Run the TypeScript codegen and compare its output with ``TS_OUT``.

    The spec payload is written to a temp JSON file and passed to the codegen
    script through ``node``. The subprocess is given a timeout so a wedged
    ``node`` process fails the test instead of hanging the session.
    """
    specs_file = tmp_path / "specs.json"
    specs_file.write_text(json.dumps(_specs_payload()), encoding="utf-8")
    out = tmp_path / "ts_out"

    result = subprocess.run(
        [
            "node",
            str(TS_CODEGEN),
            "--input",
            str(specs_file),
            "--out",
            str(out),
        ],
        capture_output=True,
        text=True,
        # The return code is asserted below so stdout/stderr reach the message.
        check=False,
        # Generous upper bound; raises TimeoutExpired rather than blocking.
        timeout=120,
    )
    assert result.returncode == 0, (
        f"TS codegen failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
    )
    _compare_trees(TS_OUT, out, skip=set())
|
||||
|
|
@ -15,12 +15,13 @@ import pytest
|
|||
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
|
@ -64,11 +65,10 @@ def text_workflow() -> WorkflowGraph:
|
|||
"""Start->End workflow with text greeting and text transition speech."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
prompt=START_PROMPT,
|
||||
is_start=True,
|
||||
|
|
@ -79,11 +79,10 @@ def text_workflow() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
prompt=END_PROMPT,
|
||||
is_end=True,
|
||||
|
|
@ -115,11 +114,10 @@ def audio_workflow() -> WorkflowGraph:
|
|||
"""Start->End workflow with audio greeting and audio transition speech."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
prompt=START_PROMPT,
|
||||
is_start=True,
|
||||
|
|
@ -130,11 +128,10 @@ def audio_workflow() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
prompt=END_PROMPT,
|
||||
is_end=True,
|
||||
|
|
@ -293,11 +290,10 @@ class TestStartGreeting:
|
|||
"""No greeting configured should return None."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start",
|
||||
prompt="Prompt",
|
||||
is_start=True,
|
||||
|
|
@ -305,11 +301,10 @@ class TestStartGreeting:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End",
|
||||
prompt="End",
|
||||
is_end=True,
|
||||
|
|
@ -338,11 +333,10 @@ class TestStartGreeting:
|
|||
"""Text greeting with {{variable}} placeholders should be rendered."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
StartCallRFNode(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
data=StartCallNodeData(
|
||||
name="Start",
|
||||
prompt="Prompt",
|
||||
is_start=True,
|
||||
|
|
@ -352,11 +346,10 @@ class TestStartGreeting:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
EndCallRFNode(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
data=EndCallNodeData(
|
||||
name="End",
|
||||
prompt="End",
|
||||
is_end=True,
|
||||
|
|
|
|||
275
api/tests/test_ts_bridge.py
Normal file
275
api/tests/test_ts_bridge.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
"""End-to-end tests for the Node TS validator bridge.
|
||||
|
||||
Exercises the real `node` subprocess — slow-ish but the whole point is
|
||||
that code → JSON and JSON → code round-trip losslessly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from api.mcp_server.ts_bridge import TsBridgeError, generate_code, parse_code
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
shutil.which("node") is None, reason="node binary not available"
|
||||
)
|
||||
|
||||
|
||||
def _minimal_workflow() -> dict:
    """Build a fresh Start → End workflow with one edge (ReactFlowDTO shape)."""
    start_node = {
        "id": "1",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {
            "name": "Greeting",
            "prompt": "Greet warmly.",
            "greeting_type": "text",
            "greeting": "Hi {{first_name}}!",
            "allow_interrupt": True,
        },
    }
    end_node = {
        "id": "2",
        "type": "endCall",
        "position": {"x": 200, "y": 0},
        "data": {"name": "Done", "prompt": "Say goodbye."},
    }
    link = {
        "id": "1-2",
        "source": "1",
        "target": "2",
        "data": {"label": "done", "condition": "conversation complete"},
    }
    return {
        "nodes": [start_node, end_node],
        "edges": [link],
        "viewport": {"x": 0, "y": 0, "zoom": 1},
    }
|
||||
|
||||
|
||||
def _normalize(wf: dict) -> dict:
    """Strip cosmetics before comparing a round-tripped workflow.

    Keeps only ``id``/``type``/``position``/``data`` on nodes and
    ``id``/``source``/``target``/``data`` on edges; every other key,
    including top-level ones such as ``viewport``, is dropped.

    Node IDs are regenerated deterministically by the parser
    (1, 2, 3, ...) so the inputs already match if constructed that way.
    Position is preserved. Edge ids follow `source-target`.
    """
    node_keys = ("id", "type", "position", "data")
    edge_keys = ("id", "source", "target", "data")
    return {
        "nodes": [{key: node[key] for key in node_keys} for node in wf["nodes"]],
        "edges": [{key: edge[key] for key in edge_keys} for edge in wf["edges"]],
    }
|
||||
|
||||
|
||||
# ─── generate_code ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_emits_imports_and_factories():
    """Generated code contains the SDK import, both factories and an edge call."""
    emitted = await generate_code(_minimal_workflow(), workflow_name="test")
    expected_snippets = (
        'import { Workflow } from "@dograh/sdk";',
        "startCall",
        "endCall",
        "wf.addTyped(startCall(",
        "wf.edge(",
    )
    for snippet in expected_snippets:
        assert snippet in emitted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_strips_spec_defaults():
    """A stored value equal to its spec default must not appear in emitted code.

    The start node is given ``add_global_prompt=True`` explicitly; without it
    the assertion below would pass even if the generator stripped nothing.
    """
    wf = _minimal_workflow()
    # nodes[0] is the startCall node built by _minimal_workflow().
    wf["nodes"][0]["data"]["add_global_prompt"] = True
    code = await generate_code(wf)
    # `add_global_prompt=True` is a spec default for startCall; emitted
    # code should omit it. Keeps the LLM-facing projection tight.
    assert "add_global_prompt" not in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_omits_position():
    """Node positions must not show up in the generated code.

    Positions are hidden from the LLM — auto-layout post-processing
    (future) reassigns them on save. Keeping them out of the edit
    surface avoids the LLM producing cramped/overlapping layouts.
    """
    emitted = await generate_code(_minimal_workflow())
    assert "position" not in emitted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_strips_legacy_ui_state_fields():
    """UI-state keys in stored node data must not leak into generated code.

    Stored workflows from before spec validation carry UI-state fields
    (`invalid`, `selected`, `is_start`, etc.). `get_workflow_code` hides
    those from the LLM so edits don't round-trip the noise.
    """
    ui_noise = {
        "invalid": False,
        "validationMessage": None,
        "is_start": True,
        "selected": True,
        "dragging": False,
    }
    start_node = {
        "id": "1",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {"name": "g", "prompt": "hi", **ui_noise},
    }
    stored = {
        "nodes": [start_node],
        "edges": [],
        "viewport": {"x": 0, "y": 0, "zoom": 1},
    }

    emitted = await generate_code(stored)

    for dropped in ui_noise:
        assert dropped not in emitted, f"{dropped} should be stripped"
    assert 'prompt: "hi"' in emitted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_strips_unknown_edge_fields():
    """Extra keys on an edge's ``data`` must not appear in generated code."""
    stored = _minimal_workflow()
    stored["edges"][0]["data"].update({"invalid": False, "validationMessage": None})

    emitted = await generate_code(stored)

    for dropped in ("invalid", "validationMessage"):
        assert dropped not in emitted
|
||||
|
||||
|
||||
# ─── parse_code ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_accepts_minimal_code():
    """A two-node, one-edge program parses into the matching workflow dict."""
    source = """import { Workflow } from "@dograh/sdk";
import { startCall, endCall } from "@dograh/sdk/typed";

const wf = new Workflow({ name: "min" });
const a = wf.addTyped(startCall({ name: "g", prompt: "hi" }));
const b = wf.addTyped(endCall({ name: "d", prompt: "bye" }));
wf.edge(a, b, { label: "done", condition: "wrapped" });
"""
    parsed = await parse_code(source)
    assert parsed["ok"] is True

    workflow = parsed["workflow"]
    nodes, edges = workflow["nodes"], workflow["edges"]
    assert len(nodes) == 2
    assert len(edges) == 1
    assert nodes[0]["type"] == "startCall"
    # The edge must start at the first node that was declared.
    assert edges[0]["source"] == nodes[0]["id"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_rejects_function_declaration():
    """A ``function`` declaration in the source is refused at the parse stage."""
    source = """import { Workflow } from "@dograh/sdk";
const wf = new Workflow({ name: "x" });
function evil() { return 1; }
"""
    parsed = await parse_code(source)

    assert parsed["ok"] is False
    assert parsed["stage"] == "parse"
    assert any(
        "FunctionDeclaration" in issue["message"] for issue in parsed["errors"]
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_rejects_unknown_field():
    """A misspelled field (``promt``) is refused at the validate stage."""
    source = """import { Workflow } from "@dograh/sdk";
import { startCall } from "@dograh/sdk/typed";
const wf = new Workflow({ name: "x" });
const a = wf.addTyped(startCall({ name: "g", prompt: "hi", promt: "typo" }));
"""
    parsed = await parse_code(source)

    assert parsed["ok"] is False
    assert parsed["stage"] == "validate"
    assert any("Unknown field" in issue["message"] for issue in parsed["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_rejects_unknown_variable_in_edge():
    """An edge that names an undeclared node variable fails at the parse stage."""
    source = """import { Workflow } from "@dograh/sdk";
import { startCall, endCall } from "@dograh/sdk/typed";
const wf = new Workflow({ name: "x" });
const a = wf.addTyped(startCall({ name: "g", prompt: "hi" }));
wf.edge(a, missing, { label: "done", condition: "c" });
"""
    parsed = await parse_code(source)

    assert parsed["ok"] is False
    assert parsed["stage"] == "parse"
    assert any(
        "Unknown node variable" in issue["message"] for issue in parsed["errors"]
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_requires_label_and_condition_on_edge():
    """An edge with an empty ``label`` is refused at the parse stage.

    Only the empty-label case is exercised here; an empty ``condition`` is
    not covered by this test.
    """
    source = """import { Workflow } from "@dograh/sdk";
import { startCall, endCall } from "@dograh/sdk/typed";
const wf = new Workflow({ name: "x" });
const a = wf.addTyped(startCall({ name: "g", prompt: "hi" }));
const b = wf.addTyped(endCall({ name: "d", prompt: "bye" }));
wf.edge(a, b, { label: "", condition: "c" });
"""
    parsed = await parse_code(source)

    assert parsed["ok"] is False
    assert parsed["stage"] == "parse"
|
||||
|
||||
|
||||
# ─── Round-trip ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_round_trip_minimal():
    """Workflow → code → workflow keeps node types, node data and edges.

    The node count is compared before pairing input and output nodes, because
    ``zip`` would otherwise stop at the shorter list and let dropped nodes go
    unnoticed.
    """
    wf = _minimal_workflow()
    code = await generate_code(wf, workflow_name="rt")
    result = await parse_code(code)
    assert result["ok"] is True, result

    out_nodes = result["workflow"]["nodes"]
    assert len(out_nodes) == len(wf["nodes"]), (
        f"expected {len(wf['nodes'])} nodes after round-trip, got {len(out_nodes)}"
    )

    # Positions are intentionally not preserved — they'll be reassigned
    # by a downstream auto-layout pass. Parser defaults to {0, 0}.
    for in_node, out_node in zip(wf["nodes"], out_nodes):
        assert out_node["type"] == in_node["type"]
        assert out_node["position"] == {"x": 0, "y": 0}
        for k, v in in_node["data"].items():
            # Check presence first so a dropped key fails with a readable
            # message instead of a bare KeyError.
            assert k in out_node["data"], f"{k}: missing after round-trip"
            assert out_node["data"][k] == v, (
                f"{k}: {out_node['data'].get(k)!r} != {v!r}"
            )

    assert _normalize({"nodes": [], "edges": result["workflow"]["edges"]})["edges"] == [
        {
            "id": "1-2",
            "source": "1",
            "target": "2",
            "data": {"label": "done", "condition": "conversation complete"},
        }
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_fails_on_unknown_type():
    """``generate_code`` raises ``TsBridgeError`` for an unrecognised node type."""
    bogus_node = {
        "id": "1",
        "type": "doesNotExist",
        "position": {"x": 0, "y": 0},
        "data": {},
    }
    stored = {
        "nodes": [bogus_node],
        "edges": [],
        "viewport": {"x": 0, "y": 0, "zoom": 1},
    }
    with pytest.raises(TsBridgeError, match="Unknown node type"):
        await generate_code(stored)
|
||||
Loading…
Add table
Add a link
Reference in a new issue