Implement cost calculator for Tuber (#471)

* Adding cost calculation in tuner for BYOK

* fix

* implement cost calculator for Tuner

* Update api/services/integrations/tuner/completion.py

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* feat: expose render_options in node spec

* Update api/services/integrations/registry.py

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: mohamed salem <259547077+mohamedsalem-bot@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
Mohamed-Mamdouh 2026-07-02 08:21:14 +01:00 committed by GitHub
parent 97803b8121
commit 65d46bc313
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 479 additions and 188 deletions

View file

@ -18,5 +18,5 @@ bcrypt==5.0.0
email-validator==2.3.0
posthog==7.19.1
fastmcp==3.2.4
tuner-pipecat-sdk==0.2.0
tuner-pipecat-sdk==0.2.4
PyNaCl==1.6.2

View file

@ -2,6 +2,8 @@ from __future__ import annotations
from typing import Any
from loguru import logger
from api.services.integrations.base import (
IntegrationCompletionContext,
IntegrationNodeRegistration,
@ -122,7 +124,15 @@ async def run_completion_handlers(
for package, nodes in iter_completion_packages(context.workflow_definition):
if package.run_completion is None:
continue
package_result = await package.run_completion(nodes, context)
try:
package_result = await package.run_completion(nodes, context)
except Exception as exc:
logger.exception(
f"Integration completion handler failed for package "
f"{package.name!r}: {exc}"
)
results[f"integration_{package.name}"] = {"error": "completion_handler_failed"}
continue
if package_result:
results.update(package_result)
return results

View file

@ -1,48 +1,21 @@
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
MetricsFrame,
StartFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
from pipecat.observers.user_bot_latency_observer import UserBotLatencyObserver
from pipecat.processors.frame_processor import FrameDirection
from pipecat.utils.context.message_sanitization import strip_thought_ids_from_messages
from tuner_pipecat_sdk.accumulator import CallAccumulator
from tuner_pipecat_sdk.payload_builder import build_payload
from tuner_pipecat_sdk import Observer
from api.enums import WorkflowRunMode
TUNER_RECORDING_PLACEHOLDER = "pipecat://no-recording"
@dataclass(frozen=True)
class _PayloadConfig:
call_id: str
call_type: str
recording_url: str
asr_model: str
llm_model: str
tts_model: str
sip_call_id: str | None = None
sip_headers: dict[str, str] | None = None
agent_version: int | None = None
# Placeholder credentials for the SDK Observer's TunerConfig. Real BYOK credentials
# (api_key / workspace_id / agent_id) are per tuner node and are applied later during
# the deferred delivery phase (completion.py), so they are not known here. TunerConfig
# validators require a non-empty api_key/agent_id and a positive workspace_id, hence
# these placeholders.
_DEFERRED_API_KEY = "deferred"
_DEFERRED_WORKSPACE_ID = 1
_DEFERRED_AGENT_ID = "deferred"
def mode_to_tuner_call_type(mode: str | None) -> str:
@ -54,8 +27,15 @@ def mode_to_tuner_call_type(mode: str | None) -> str:
return "phone_call"
class TunerCollector(BaseObserver):
"""Collect runtime call metadata and build a deferred Tuner payload."""
class DeferredTunerObserver(Observer):
"""SDK ``Observer`` that builds the Tuner payload from the live frame stream but
defers delivery to the completion phase instead of POSTing on call end.
The SDK ``Observer`` normally fire-and-forgets ``post_call`` when the call ends.
Dograh instead snapshots the payload into ``workflow_run.logs`` and delivers it
later (``completion.py``) once per tuner node with that node's BYOK credentials,
after injecting the real ``recording_url`` and a locally-computed ``call_cost``.
"""
def __init__(
self,
@ -66,126 +46,33 @@ class TunerCollector(BaseObserver):
llm_model: str = "",
tts_model: str = "",
agent_version: int | None = None,
max_frames: int = 500,
) -> None:
super().__init__()
self._call_id = str(workflow_run_id)
self._call_type = call_type
self._asr_model = asr_model
self._llm_model = llm_model
self._tts_model = tts_model
self._agent_version = agent_version
self._acc = CallAccumulator()
self._acc.call_start_abs_ns = time.time_ns()
self._pipeline_start_rel_ns: int | None = None
self._context_provider: Callable[[], list[dict[str, Any]]] | None = None
self._processed_frames: set[int] = set()
self._frame_history: deque[int] = deque(maxlen=max_frames)
super().__init__(
api_key=_DEFERRED_API_KEY,
workspace_id=_DEFERRED_WORKSPACE_ID,
agent_id=_DEFERRED_AGENT_ID,
call_id=str(workflow_run_id),
call_type=call_type,
recording_url=TUNER_RECORDING_PLACEHOLDER,
asr_model=asr_model,
llm_model=llm_model,
tts_model=tts_model,
agent_version=agent_version,
)
def attach_context(self, provider: Callable[[], list[dict[str, Any]]]) -> None:
self._context_provider = provider
async def _flush(self) -> None:
# Suppress the SDK's runtime post_call; delivery is deferred (see class docstring).
return None
def set_disconnection_reason(self, reason: str | None) -> None:
if reason:
self._acc.set_disconnection_reason(reason)
def attach_turn_tracking_observer(
self, turn_tracker: TurnTrackingObserver | None
) -> None:
if turn_tracker is None:
return
@turn_tracker.event_handler("on_turn_started")
async def _on_turn_started(_tracker: Any, turn_number: int) -> None:
self._acc.on_turn_started(turn_number, time.time_ns())
@turn_tracker.event_handler("on_turn_ended")
async def _on_turn_ended(
_tracker: Any, turn_number: int, _duration: float, was_interrupted: bool
) -> None:
self._acc.on_turn_ended(turn_number, was_interrupted)
def attach_latency_observer(
self, latency_observer: UserBotLatencyObserver | None
) -> None:
if latency_observer is None:
return
@latency_observer.event_handler("on_latency_measured")
async def _on_latency_measured(_observer: Any, latency: float) -> None:
self._acc.on_latency_measured(latency)
@latency_observer.event_handler("on_latency_breakdown")
async def _on_latency_breakdown(_observer: Any, breakdown: Any) -> None:
self._acc.on_latency_breakdown(breakdown)
async def on_push_frame(self, data: FramePushed):
if data.direction != FrameDirection.DOWNSTREAM:
return
if data.frame.id in self._processed_frames:
return
self._processed_frames.add(data.frame.id)
self._frame_history.append(data.frame.id)
if len(self._processed_frames) > len(self._frame_history):
self._processed_frames = set(self._frame_history)
frame = data.frame
# data.timestamp is a pipeline-relative clock (ns since pipeline start).
# Convert to absolute ns so the accumulator's _rel_ms() works correctly.
if self._pipeline_start_rel_ns is None:
self._pipeline_start_rel_ns = data.timestamp
timestamp_ns = self._acc.call_start_abs_ns + (
data.timestamp - self._pipeline_start_rel_ns
)
if isinstance(frame, StartFrame):
self._acc.on_start(timestamp_ns)
elif isinstance(frame, FunctionCallInProgressFrame):
self._acc.on_function_call_in_progress(frame, timestamp_ns)
elif isinstance(frame, FunctionCallResultFrame):
self._acc.on_function_call_result(frame.tool_call_id, timestamp_ns)
elif isinstance(frame, MetricsFrame):
self._acc.on_metrics_frame(frame)
elif isinstance(frame, UserStartedSpeakingFrame):
self._acc.on_user_started_speaking(timestamp_ns)
elif isinstance(frame, UserStoppedSpeakingFrame):
self._acc.on_user_stopped_speaking(timestamp_ns)
self._acc.on_user_turn_stopped(timestamp_ns)
elif isinstance(frame, BotStartedSpeakingFrame):
self._acc.on_bot_started_speaking(timestamp_ns)
elif isinstance(frame, BotStoppedSpeakingFrame):
self._acc.on_bot_stopped(timestamp_ns)
elif isinstance(frame, VADUserStoppedSpeakingFrame):
self._acc.on_vad_stopped(timestamp_ns)
elif isinstance(frame, (CancelFrame, EndFrame)):
self._acc.on_call_end(timestamp_ns)
def build_payload_snapshot(
self,
*,
recording_url: str = TUNER_RECORDING_PLACEHOLDER,
) -> dict[str, Any] | None:
if self._context_provider is None:
logger.warning(
"[tuner] no context provider attached; skipping payload snapshot"
)
return None
transcript = strip_thought_ids_from_messages(list(self._context_provider()))
payload = build_payload(
self._acc,
_PayloadConfig(
call_id=self._call_id,
call_type=self._call_type,
recording_url=recording_url,
asr_model=self._asr_model,
llm_model=self._llm_model,
tts_model=self._tts_model,
agent_version=self._agent_version,
),
transcript,
)
self._config.recording_url = recording_url
payload = self._acc.build_payload(self._config, None)
return payload.to_dict()

View file

@ -11,6 +11,7 @@ from api.services.integrations.base import IntegrationCompletionContext
from .client import TunerDeliveryConfig, post_call
from .collector import TUNER_RECORDING_PLACEHOLDER
from .cost import compute_call_cost_cents
from .node import TunerNodeData
@ -55,6 +56,14 @@ async def run_completion(
payload = copy.deepcopy(payload_snapshot)
payload["recording_url"] = recording_url
call_cost = compute_call_cost_cents(
tuner_data,
context.workflow_run.usage_info,
transcript_segments=payload.get("transcript_with_tool_calls"),
)
if call_cost is not None:
payload["call_cost"] = call_cost
try:
config = TunerDeliveryConfig(
base_url=TUNER_BASE_URL,
@ -67,6 +76,7 @@ async def run_completion(
**delivery,
"workspace_id": tuner_data.tuner_workspace_id,
"agent_id": tuner_data.tuner_agent_id,
**({"call_cost": call_cost} if call_cost is not None else {}),
"exported_at": datetime.now(UTC).isoformat(),
}
except Exception as exc:

View file

@ -0,0 +1,131 @@
"""Per-call cost computation for the Tuner export.
Dograh no longer rates calls locally, so when a user wants Tuner to show a
cost they provide their own per-unit prices on the Tuner node (the "bring your
own keys" model). This module turns those rates plus the call's measured usage
(`workflow_run.usage_info`) into a single `call_cost` value in cents, which is
what Tuner's public API stores.
Rates are optional: a blank rate contributes nothing. Usage metrics come from
the pipeline aggregator and are reliable for LLM tokens and TTS characters.
STT seconds are not measured, so the STT and telephony rates are applied
per-minute against the call's wall-clock duration (an approximation).
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .node import TunerNodeData
def _sum_llm_tokens(usage_info: dict[str, Any]) -> tuple[int, int, int]:
"""Sum prompt, completion, and cached-input tokens across all llm entries.
Cached-input tokens (``cache_read_input_tokens``) are reported as a discounted
subset of ``prompt_tokens`` (OpenAI convention), not in addition to it.
"""
prompt_tokens = 0
completion_tokens = 0
cached_tokens = 0
for entry in (usage_info.get("llm") or {}).values():
if isinstance(entry, dict):
prompt_tokens += entry.get("prompt_tokens") or 0
completion_tokens += entry.get("completion_tokens") or 0
cached_tokens += entry.get("cache_read_input_tokens") or 0
return prompt_tokens, completion_tokens, cached_tokens
def _sum_tts_characters(usage_info: dict[str, Any]) -> int:
"""Sum TTS characters across every tts processor/model entry."""
total = 0
for value in (usage_info.get("tts") or {}).values():
if isinstance(value, (int, float)):
total += value
return int(total)
# Transcript roles that represent bot-spoken text sent to TTS. Excludes
# "user" (STT input) and "agent_function"/"agent_result" (tool calls).
_SPOKEN_ROLES = {"agent", "assistant", "bot"}
def _count_transcript_tts_characters(
transcript_segments: list[dict[str, Any]] | None,
) -> int:
"""Count characters of bot-spoken transcript turns (TTS proxy).
Used when the pipeline did not measure TTS characters directly (e.g. the
Deepgram websocket TTS service does not emit usage metrics). The spoken
transcript text closely matches what was sent to the TTS engine.
"""
if not transcript_segments:
return 0
total = 0
for segment in transcript_segments:
if isinstance(segment, dict) and segment.get("role") in _SPOKEN_ROLES:
total += len(segment.get("text") or "")
return total
def compute_call_cost_cents(
tuner_data: "TunerNodeData",
usage_info: dict[str, Any] | None,
transcript_segments: list[dict[str, Any]] | None = None,
) -> float | None:
"""Compute the call cost in cents from node rates and measured usage.
Returns ``None`` when cost calculation is disabled or no rates are
configured, so the caller can omit ``call_cost`` from the payload entirely
rather than report a misleading zero.
"""
if not tuner_data.cost_calculation_enabled:
return None
raw_rates = (
tuner_data.cost_llm_input_rate,
tuner_data.cost_llm_cached_input_rate,
tuner_data.cost_llm_output_rate,
tuner_data.cost_tts_rate,
tuner_data.cost_stt_rate,
tuner_data.cost_telephony_rate,
)
if all(rate is None for rate in raw_rates):
return None
usage_info = usage_info or {}
prompt_tokens, completion_tokens, cached_tokens = _sum_llm_tokens(usage_info)
# Prefer the pipeline-measured TTS characters; fall back to the spoken
# transcript when the TTS service did not report usage (e.g. Deepgram websocket).
tts_characters = _sum_tts_characters(usage_info)
if tts_characters == 0:
tts_characters = _count_transcript_tts_characters(transcript_segments)
duration_minutes = (usage_info.get("call_duration_seconds") or 0) / 60.0
llm_input_rate = tuner_data.cost_llm_input_rate or 0.0
cached_input_rate = tuner_data.cost_llm_cached_input_rate
llm_output_rate = tuner_data.cost_llm_output_rate or 0.0
tts_rate = tuner_data.cost_tts_rate or 0.0
stt_rate = tuner_data.cost_stt_rate or 0.0
telephony_rate = tuner_data.cost_telephony_rate or 0.0
# Cached tokens are a discounted subset of prompt tokens. Only split them out
# when a cached rate is configured; otherwise bill all prompt tokens normally.
if cached_input_rate is not None:
uncached_prompt_tokens = max(prompt_tokens - cached_tokens, 0)
llm_input_usd = (
uncached_prompt_tokens * llm_input_rate + cached_tokens * cached_input_rate
) / 1_000_000
else:
llm_input_usd = prompt_tokens * llm_input_rate / 1_000_000
cost_usd = (
llm_input_usd
+ completion_tokens * llm_output_rate / 1_000_000
+ tts_characters * tts_rate / 1_000
+ duration_minutes * stt_rate
+ duration_minutes * telephony_rate
)
return round(cost_usd * 100, 4)

View file

@ -5,9 +5,13 @@ from pydantic import model_validator
from api.services.integrations.base import IntegrationNodeRegistration
from api.services.workflow.node_data import BaseNodeData
from api.services.workflow.node_specs._base import (
DisplayOptions,
GraphConstraints,
NodeCategory,
NodeExample,
NumberInputOptions,
PropertyLayoutOptions,
PropertyRendererOptions,
PropertyType,
)
from api.services.workflow.node_specs.model_spec import (
@ -16,6 +20,13 @@ from api.services.workflow.node_specs.model_spec import (
spec_field,
)
# Cost rate fields are only shown once the user turns on cost calculation.
_COST_FIELDS_VISIBLE = DisplayOptions(show={"cost_calculation_enabled": [True]})
_COST_RATE_RENDERER_OPTIONS = PropertyRendererOptions(
layout=PropertyLayoutOptions(column_span=6),
number_input=NumberInputOptions(fractional=True),
)
@node_spec(
name="tuner",
@ -48,6 +59,13 @@ from api.services.workflow.node_specs.model_spec import (
"tuner_agent_id",
"tuner_workspace_id",
"tuner_api_key",
"cost_calculation_enabled",
"cost_llm_input_rate",
"cost_llm_cached_input_rate",
"cost_llm_output_rate",
"cost_tts_rate",
"cost_stt_rate",
"cost_telephony_rate",
),
field_overrides={
"name": {
@ -103,6 +121,73 @@ class TunerNodeData(BaseNodeData):
description="Bearer token used when posting completed calls to Tuner.",
)
cost_calculation_enabled: bool = spec_field(
default=False,
ui_type=PropertyType.boolean,
display_name="Calculate cost",
description="Send a per-call cost to Tuner, computed from your own provider rates (BYOK). All rates below are optional.",
)
cost_llm_input_rate: float | None = spec_field(
default=None,
ge=0,
le=1000,
ui_type=PropertyType.number,
display_name="LLM input",
description="USD per 1M tokens",
display_options=_COST_FIELDS_VISIBLE,
renderer_options=_COST_RATE_RENDERER_OPTIONS,
)
cost_llm_cached_input_rate: float | None = spec_field(
default=None,
ge=0,
le=1000,
ui_type=PropertyType.number,
display_name="LLM cached input",
description="USD per 1M cached tokens",
display_options=_COST_FIELDS_VISIBLE,
renderer_options=_COST_RATE_RENDERER_OPTIONS,
)
cost_llm_output_rate: float | None = spec_field(
default=None,
ge=0,
le=1000,
ui_type=PropertyType.number,
display_name="LLM output",
description="USD per 1M tokens",
display_options=_COST_FIELDS_VISIBLE,
renderer_options=_COST_RATE_RENDERER_OPTIONS,
)
cost_tts_rate: float | None = spec_field(
default=None,
ge=0,
le=100,
ui_type=PropertyType.number,
display_name="TTS",
description="USD per 1K characters",
display_options=_COST_FIELDS_VISIBLE,
renderer_options=_COST_RATE_RENDERER_OPTIONS,
)
cost_stt_rate: float | None = spec_field(
default=None,
ge=0,
le=100,
ui_type=PropertyType.number,
display_name="STT",
description="USD per minute",
display_options=_COST_FIELDS_VISIBLE,
renderer_options=_COST_RATE_RENDERER_OPTIONS,
)
cost_telephony_rate: float | None = spec_field(
default=None,
ge=0,
le=100,
ui_type=PropertyType.number,
display_name="Telephony",
description="USD per minute",
display_options=_COST_FIELDS_VISIBLE,
renderer_options=_COST_RATE_RENDERER_OPTIONS,
)
@model_validator(mode="after")
def _validate_enabled_config(self):
if not self.tuner_enabled:

View file

@ -8,7 +8,7 @@ from api.services.integrations.base import (
IntegrationRuntimeSession,
)
from .collector import TunerCollector, mode_to_tuner_call_type
from .collector import DeferredTunerObserver, mode_to_tuner_call_type
def _format_model_label(provider: str | None, model: str | None) -> str:
@ -53,23 +53,25 @@ def _resolve_model_labels(context: IntegrationRuntimeContext) -> tuple[str, str,
class TunerRuntimeSession(IntegrationRuntimeSession):
name = "tuner"
def __init__(self, collector: TunerCollector) -> None:
self._collector = collector
def __init__(self, observer: DeferredTunerObserver) -> None:
self._observer = observer
def attach(self, task: Any) -> None:
self._collector.attach_turn_tracking_observer(task.turn_tracking_observer)
self._collector.attach_latency_observer(task.user_bot_latency_observer)
task.add_observer(self._collector)
self._observer.attach_turn_tracking_observer(task.turn_tracking_observer)
task.add_observer(self._observer)
# The SDK Observer wires latency into the accumulator via its own latency
# observer, which must itself be registered to receive frames.
task.add_observer(self._observer.latency_observer)
async def on_call_finished(
self,
*,
gathered_context: dict[str, Any],
) -> dict[str, Any] | None:
self._collector.set_disconnection_reason(
self._observer.set_disconnection_reason(
gathered_context.get("call_disposition")
)
payload = self._collector.build_payload_snapshot()
payload = self._observer.build_payload_snapshot()
if payload is None:
return None
return {"tuner_payload": payload}
@ -88,7 +90,7 @@ def create_runtime_sessions(
asr_model, llm_model, tts_model = _resolve_model_labels(context)
collector = TunerCollector(
observer = DeferredTunerObserver(
workflow_run_id=context.workflow_run_id,
call_type=mode_to_tuner_call_type(context.workflow_run.mode),
asr_model=asr_model,
@ -96,6 +98,5 @@ def create_runtime_sessions(
tts_model=tts_model,
agent_version=getattr(context.run_definition, "version_number", None),
)
collector.attach_context(context.context_messages_provider)
return [TunerRuntimeSession(collector)]
return [TunerRuntimeSession(observer)]

View file

@ -14,7 +14,10 @@ from api.services.workflow.node_specs._base import (
NodeCategory,
NodeExample,
NodeSpec,
NumberInputOptions,
PropertyLayoutOptions,
PropertyOption,
PropertyRendererOptions,
PropertySpec,
PropertyType,
evaluate_display_options,
@ -65,7 +68,10 @@ __all__ = [
"NodeCategory",
"NodeExample",
"NodeSpec",
"NumberInputOptions",
"PropertyLayoutOptions",
"PropertyOption",
"PropertyRendererOptions",
"PropertySpec",
"PropertyType",
"all_specs",

View file

@ -133,6 +133,42 @@ class PropertyOption(BaseModel):
return out
class PropertyLayoutOptions(BaseModel):
"""Renderer layout hints for a property in the node editor."""
column_span: Optional[int] = Field(
default=None,
ge=1,
le=12,
description="Number of columns to occupy in the editor's 12-column grid.",
)
model_config = ConfigDict(extra="forbid")
class NumberInputOptions(BaseModel):
"""Renderer hints for numeric inputs."""
fractional: bool = Field(
default=False,
description="Allow arbitrary fractional values via step='any'.",
)
model_config = ConfigDict(extra="forbid")
class PropertyRendererOptions(BaseModel):
"""Typed renderer metadata for node properties.
Add new renderer behavior here instead of using free-form property metadata.
"""
layout: Optional[PropertyLayoutOptions] = None
number_input: Optional[NumberInputOptions] = None
model_config = ConfigDict(extra="forbid")
class PropertySpec(BaseModel):
"""Single field on a node.
@ -180,8 +216,9 @@ class PropertySpec(BaseModel):
# Renderer hint, e.g. "textarea" vs single-line for `string`.
editor: Optional[str] = None
# Free-form metadata for renderer-specific behavior. Use sparingly.
extra: dict[str, Any] = Field(default_factory=dict)
# Typed metadata for renderer-specific behavior. Extend
# `PropertyRendererOptions` when the renderer needs a new hint.
renderer_options: Optional[PropertyRendererOptions] = None
model_config = ConfigDict(extra="forbid")
@ -192,7 +229,7 @@ class PropertySpec(BaseModel):
description, llm_hint, requiredness, default, enum options, nested
row properties, and validation bounds. UI-rendering concerns
(`display_name`, `placeholder`, `display_options`, `editor`,
`extra`) and null/empty fields are omitted they're noise in the
`renderer_options`) and null/empty fields are omitted they're noise in the
model's context and never appear in authored SDK code.
"""
out: dict[str, Any] = {

View file

@ -16,6 +16,7 @@ from api.services.workflow.node_specs._base import (
NodeExample,
NodeSpec,
PropertyOption,
PropertyRendererOptions,
PropertySpec,
PropertyType,
)
@ -50,7 +51,7 @@ def spec_field(
display_options: DisplayOptions | None = None,
options: list[PropertyOption] | None = None,
editor: str | None = None,
extra: dict[str, Any] | None = None,
renderer_options: PropertyRendererOptions | None = None,
spec_exclude: bool = False,
min_value: float | None = None,
max_value: float | None = None,
@ -69,7 +70,7 @@ def spec_field(
"display_options": display_options,
"options": options,
"editor": editor,
"extra": extra or {},
"renderer_options": renderer_options,
"spec_exclude": spec_exclude,
"min_value": min_value,
"max_value": max_value,
@ -206,7 +207,7 @@ def _build_property_spec(
max_length=max_length,
pattern=pattern,
editor=meta.get("editor"),
extra=meta.get("extra") or {},
renderer_options=meta.get("renderer_options"),
)

View file

@ -13,6 +13,7 @@ from __future__ import annotations
import re
import pytest
from pydantic import ValidationError
from api.services.workflow.dto import (
ReactFlowDTO,
@ -22,6 +23,7 @@ from api.services.workflow.dto import (
from api.services.workflow.node_data import BaseNodeData
from api.services.workflow.node_specs import (
NodeSpec,
PropertyRendererOptions,
PropertySpec,
PropertyType,
all_specs,
@ -296,6 +298,13 @@ def test_all_registered_node_models_inherit_base_node_data():
"tuner_agent_id",
"tuner_workspace_id",
"tuner_api_key",
"cost_calculation_enabled",
"cost_llm_input_rate",
"cost_llm_cached_input_rate",
"cost_llm_output_rate",
"cost_tts_rate",
"cost_stt_rate",
"cost_telephony_rate",
],
),
],
@ -305,6 +314,33 @@ def test_node_spec_property_order_stable(spec_name: str, expected_order: list[st
assert [prop.name for prop in spec.properties] == expected_order
def test_tuner_cost_rate_fields_use_typed_renderer_options():
spec = next(spec for spec in all_specs() if spec.name == "tuner")
cost_rate_props = [
prop
for prop in spec.properties
if prop.name.startswith("cost_") and prop.name.endswith("_rate")
]
assert len(cost_rate_props) == 6
assert all(prop.renderer_options is not None for prop in cost_rate_props)
assert all(
prop.renderer_options.layout is not None
and prop.renderer_options.layout.column_span == 6
for prop in cost_rate_props
)
assert all(
prop.renderer_options.number_input is not None
and prop.renderer_options.number_input.fractional is True
for prop in cost_rate_props
)
def test_property_renderer_options_reject_unknown_hints():
with pytest.raises(ValidationError):
PropertyRendererOptions.model_validate({"layout": {"width": "half"}})
# ─────────────────────────────────────────────────────────────────────────
# `to_mcp_dict` projection — the lean view served by the `get_node_type`
# MCP tool. UI-only metadata is dropped so it doesn't poison LLM context;
@ -322,7 +358,7 @@ _UI_ONLY_KEYS = frozenset(
"placeholder",
"display_options",
"editor",
"extra",
"renderer_options",
"label", # PropertyOption display string
}
)

File diff suppressed because one or more lines are too long

View file

@ -3711,6 +3711,20 @@ export type NodeTypesResponse = {
node_types: Array<NodeSpec>;
};
/**
* NumberInputOptions
*
* Renderer hints for numeric inputs.
*/
export type NumberInputOptions = {
/**
* Fractional
*
* Allow arbitrary fractional values via step='any'.
*/
fractional?: boolean;
};
/**
* OnboardingState
*
@ -4353,6 +4367,20 @@ export type ProcessDocumentRequestSchema = {
retrieval_mode?: string;
};
/**
* PropertyLayoutOptions
*
* Renderer layout hints for a property in the node editor.
*/
export type PropertyLayoutOptions = {
/**
* Column Span
*
* Number of columns to occupy in the editor's 12-column grid.
*/
column_span?: number | null;
};
/**
* PropertyOption
*
@ -4373,6 +4401,18 @@ export type PropertyOption = {
description?: string | null;
};
/**
* PropertyRendererOptions
*
* Typed renderer metadata for node properties.
*
* Add new renderer behavior here instead of using free-form property metadata.
*/
export type PropertyRendererOptions = {
layout?: PropertyLayoutOptions | null;
number_input?: NumberInputOptions | null;
};
/**
* PropertySpec
*
@ -4454,12 +4494,7 @@ export type PropertySpec = {
* Editor
*/
editor?: string | null;
/**
* Extra
*/
extra?: {
[key: string]: unknown;
};
renderer_options?: PropertyRendererOptions | null;
};
/**

View file

@ -4,6 +4,7 @@ import type { NodeSpec } from "@/client/types.gen";
import { evaluateDisplayOptions } from "./displayOptions";
import { PropertyInput, type RendererContext } from "./PropertyInput";
import { getPropertyColumnSpan } from "./propertyRendererOptions";
export interface NodeEditFormProps {
spec: NodeSpec;
@ -13,6 +14,21 @@ export interface NodeEditFormProps {
context: RendererContext;
}
const COLUMN_SPAN_CLASS: Record<number, string> = {
1: "sm:col-span-1",
2: "sm:col-span-2",
3: "sm:col-span-3",
4: "sm:col-span-4",
5: "sm:col-span-5",
6: "sm:col-span-6",
7: "sm:col-span-7",
8: "sm:col-span-8",
9: "sm:col-span-9",
10: "sm:col-span-10",
11: "sm:col-span-11",
12: "sm:col-span-12",
};
/**
* Generic node-edit form. Walks `spec.properties` once, evaluates each
* property's `display_options` against current values, and renders the
@ -31,18 +47,25 @@ export function NodeEditForm({ spec, values, onChange, context }: NodeEditFormPr
);
return (
<div className="grid gap-3">
<div className="grid grid-cols-12 gap-3">
{spec.properties
.filter((p) => evaluateDisplayOptions(p.display_options, values))
.map((p) => (
<PropertyInput
key={p.name}
spec={p}
value={values[p.name]}
onChange={(v) => setProp(p.name, v)}
context={context}
/>
))}
.map((p) => {
const columnSpan = getPropertyColumnSpan(p.renderer_options);
return (
<div
key={p.name}
className={`col-span-12 ${COLUMN_SPAN_CLASS[columnSpan]}`}
>
<PropertyInput
spec={p}
value={values[p.name]}
onChange={(v) => setProp(p.name, v)}
context={context}
/>
</div>
);
})}
</div>
);
}

View file

@ -18,6 +18,10 @@ import { Switch } from "@/components/ui/switch";
import { Textarea } from "@/components/ui/textarea";
import { evaluateDisplayOptions } from "./displayOptions";
import {
getPropertyColumnSpan,
isFractionalNumberInput,
} from "./propertyRendererOptions";
export interface RendererContext {
tools: ToolResponse[];
@ -175,13 +179,21 @@ function StringWidget({ spec, value, onChange }: WidgetProps) {
function NumberWidget({ spec, value, onChange }: WidgetProps) {
const v = (value as number | undefined) ?? "";
const isCompact = getPropertyColumnSpan(spec.renderer_options) < 12;
const isFractional = isFractionalNumberInput(spec.renderer_options);
return (
<div className="grid gap-2">
<StackedLabel spec={spec} />
<Input
type="number"
value={v as number | string}
step={spec.min_value && spec.min_value < 1 ? 0.1 : 1}
step={
isFractional
? "any"
: spec.min_value && spec.min_value < 1
? 0.1
: 1
}
min={spec.min_value ?? undefined}
max={spec.max_value ?? undefined}
onChange={(e) => {
@ -189,7 +201,7 @@ function NumberWidget({ spec, value, onChange }: WidgetProps) {
onChange(next === "" ? undefined : parseFloat(next));
}}
placeholder={spec.placeholder ?? undefined}
className="w-32"
className={isCompact ? "w-full" : "w-32"}
/>
</div>
);

View file

@ -0,0 +1,17 @@
import type { PropertyRendererOptions } from "@/client/types.gen";
export function getPropertyColumnSpan(
rendererOptions: PropertyRendererOptions | null | undefined,
): number {
const value = rendererOptions?.layout?.column_span;
if (typeof value !== "number" || !Number.isFinite(value)) {
return 12;
}
return Math.min(Math.max(Math.trunc(value), 1), 12);
}
export function isFractionalNumberInput(
rendererOptions: PropertyRendererOptions | null | undefined,
): boolean {
return rendererOptions?.number_input?.fractional === true;
}