mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
405 lines
12 KiB
Python
405 lines
12 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from dataclasses import field as dataclass_field
|
||
|
|
from enum import Enum
|
||
|
|
from types import NoneType
|
||
|
|
from typing import Any, Callable, Literal, get_args, get_origin
|
||
|
|
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
from pydantic.fields import FieldInfo, PydanticUndefined
|
||
|
|
|
||
|
|
from api.services.workflow.node_specs._base import (
|
||
|
|
DisplayOptions,
|
||
|
|
GraphConstraints,
|
||
|
|
NodeCategory,
|
||
|
|
NodeExample,
|
||
|
|
NodeSpec,
|
||
|
|
PropertyOption,
|
||
|
|
PropertySpec,
|
||
|
|
PropertyType,
|
||
|
|
)
|
||
|
|
|
||
|
|
_SPEC_FIELD_META_KEY = "__dograh_spec_field__"
|
||
|
|
_UNSET = object()
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
|
||
|
|
class NodeSpecMetadata:
|
||
|
|
name: str
|
||
|
|
display_name: str
|
||
|
|
description: str
|
||
|
|
category: NodeCategory
|
||
|
|
icon: str
|
||
|
|
llm_hint: str | None = None
|
||
|
|
version: str = "1.0.0"
|
||
|
|
examples: tuple[NodeExample, ...] = ()
|
||
|
|
graph_constraints: GraphConstraints | None = None
|
||
|
|
property_order: tuple[str, ...] = ()
|
||
|
|
field_overrides: dict[str, dict[str, Any]] = dataclass_field(default_factory=dict)
|
||
|
|
|
||
|
|
|
||
|
|
def spec_field(
|
||
|
|
*field_args: Any,
|
||
|
|
ui_type: PropertyType | str | None = None,
|
||
|
|
display_name: str | None = None,
|
||
|
|
llm_hint: str | None = None,
|
||
|
|
required: bool | None = None,
|
||
|
|
spec_default: Any = _UNSET,
|
||
|
|
placeholder: str | None = None,
|
||
|
|
display_options: DisplayOptions | None = None,
|
||
|
|
options: list[PropertyOption] | None = None,
|
||
|
|
editor: str | None = None,
|
||
|
|
extra: dict[str, Any] | None = None,
|
||
|
|
spec_exclude: bool = False,
|
||
|
|
min_value: float | None = None,
|
||
|
|
max_value: float | None = None,
|
||
|
|
min_length: int | None = None,
|
||
|
|
max_length: int | None = None,
|
||
|
|
pattern: str | None = None,
|
||
|
|
**field_kwargs: Any,
|
||
|
|
):
|
||
|
|
json_schema_extra = dict(field_kwargs.pop("json_schema_extra", {}) or {})
|
||
|
|
json_schema_extra[_SPEC_FIELD_META_KEY] = {
|
||
|
|
"ui_type": ui_type.value if isinstance(ui_type, PropertyType) else ui_type,
|
||
|
|
"display_name": display_name,
|
||
|
|
"llm_hint": llm_hint,
|
||
|
|
"required": required,
|
||
|
|
"placeholder": placeholder,
|
||
|
|
"display_options": display_options,
|
||
|
|
"options": options,
|
||
|
|
"editor": editor,
|
||
|
|
"extra": extra or {},
|
||
|
|
"spec_exclude": spec_exclude,
|
||
|
|
"min_value": min_value,
|
||
|
|
"max_value": max_value,
|
||
|
|
"min_length": min_length,
|
||
|
|
"max_length": max_length,
|
||
|
|
"pattern": pattern,
|
||
|
|
}
|
||
|
|
if spec_default is not _UNSET:
|
||
|
|
json_schema_extra[_SPEC_FIELD_META_KEY]["spec_default"] = spec_default
|
||
|
|
return Field(*field_args, json_schema_extra=json_schema_extra, **field_kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
def node_spec(
|
||
|
|
*,
|
||
|
|
name: str,
|
||
|
|
display_name: str,
|
||
|
|
description: str,
|
||
|
|
category: NodeCategory,
|
||
|
|
icon: str,
|
||
|
|
llm_hint: str | None = None,
|
||
|
|
version: str = "1.0.0",
|
||
|
|
examples: list[NodeExample] | tuple[NodeExample, ...] = (),
|
||
|
|
graph_constraints: GraphConstraints | None = None,
|
||
|
|
property_order: list[str] | tuple[str, ...] = (),
|
||
|
|
field_overrides: dict[str, dict[str, Any]] | None = None,
|
||
|
|
) -> Callable[[type[BaseModel]], type[BaseModel]]:
|
||
|
|
metadata = NodeSpecMetadata(
|
||
|
|
name=name,
|
||
|
|
display_name=display_name,
|
||
|
|
description=description,
|
||
|
|
category=category,
|
||
|
|
icon=icon,
|
||
|
|
llm_hint=llm_hint,
|
||
|
|
version=version,
|
||
|
|
examples=tuple(examples),
|
||
|
|
graph_constraints=graph_constraints,
|
||
|
|
property_order=tuple(property_order),
|
||
|
|
field_overrides=field_overrides or {},
|
||
|
|
)
|
||
|
|
|
||
|
|
def decorator(model_cls: type[BaseModel]) -> type[BaseModel]:
|
||
|
|
setattr(model_cls, "__node_spec_metadata__", metadata)
|
||
|
|
return model_cls
|
||
|
|
|
||
|
|
return decorator
|
||
|
|
|
||
|
|
|
||
|
|
def build_spec(model_cls: type[BaseModel]) -> NodeSpec:
|
||
|
|
metadata: NodeSpecMetadata | None = getattr(
|
||
|
|
model_cls, "__node_spec_metadata__", None
|
||
|
|
)
|
||
|
|
if metadata is None:
|
||
|
|
raise ValueError(f"{model_cls.__name__} is missing __node_spec_metadata__")
|
||
|
|
|
||
|
|
properties: list[PropertySpec] = []
|
||
|
|
for name, field in model_cls.model_fields.items():
|
||
|
|
prop = _build_property_spec(model_cls, name, field)
|
||
|
|
if prop is not None:
|
||
|
|
properties.append(prop)
|
||
|
|
properties = _sort_properties(metadata.name, properties, metadata.property_order)
|
||
|
|
|
||
|
|
return NodeSpec(
|
||
|
|
name=metadata.name,
|
||
|
|
display_name=metadata.display_name,
|
||
|
|
description=metadata.description,
|
||
|
|
llm_hint=metadata.llm_hint,
|
||
|
|
category=metadata.category,
|
||
|
|
icon=metadata.icon,
|
||
|
|
version=metadata.version,
|
||
|
|
properties=properties,
|
||
|
|
examples=list(metadata.examples),
|
||
|
|
graph_constraints=metadata.graph_constraints,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _sort_properties(
|
||
|
|
spec_name: str,
|
||
|
|
properties: list[PropertySpec],
|
||
|
|
property_order: tuple[str, ...],
|
||
|
|
) -> list[PropertySpec]:
|
||
|
|
if not property_order:
|
||
|
|
return properties
|
||
|
|
|
||
|
|
property_names = {prop.name for prop in properties}
|
||
|
|
missing = [name for name in property_order if name not in property_names]
|
||
|
|
if missing:
|
||
|
|
raise ValueError(
|
||
|
|
f"{spec_name}: property_order references unknown properties: {missing}"
|
||
|
|
)
|
||
|
|
|
||
|
|
order_map = {name: idx for idx, name in enumerate(property_order)}
|
||
|
|
ordered = sorted(
|
||
|
|
enumerate(properties),
|
||
|
|
key=lambda item: (order_map.get(item[1].name, len(order_map)), item[0]),
|
||
|
|
)
|
||
|
|
return [prop for _, prop in ordered]
|
||
|
|
|
||
|
|
|
||
|
|
def _build_property_spec(
|
||
|
|
owner_cls: type[BaseModel],
|
||
|
|
field_name: str,
|
||
|
|
field: FieldInfo,
|
||
|
|
) -> PropertySpec | None:
|
||
|
|
meta = _merged_field_meta(owner_cls, field_name, field)
|
||
|
|
if meta.get("spec_exclude"):
|
||
|
|
return None
|
||
|
|
|
||
|
|
prop_type = _resolve_property_type(field.annotation, meta)
|
||
|
|
nested_properties = _resolve_nested_properties(field.annotation, prop_type)
|
||
|
|
options = _resolve_options(field.annotation, meta, prop_type)
|
||
|
|
min_value, max_value, min_length, max_length, pattern = _resolve_constraints(
|
||
|
|
field, meta
|
||
|
|
)
|
||
|
|
|
||
|
|
description = meta.get("description") or field.description
|
||
|
|
if not description:
|
||
|
|
raise ValueError(f"{owner_cls.__name__}.{field_name} is missing a description")
|
||
|
|
|
||
|
|
return PropertySpec(
|
||
|
|
name=field_name,
|
||
|
|
type=prop_type,
|
||
|
|
display_name=meta.get("display_name") or _humanize_identifier(field_name),
|
||
|
|
description=description,
|
||
|
|
llm_hint=meta.get("llm_hint"),
|
||
|
|
default=_resolve_default(field, meta),
|
||
|
|
required=_resolve_required(field, meta),
|
||
|
|
placeholder=meta.get("placeholder"),
|
||
|
|
display_options=meta.get("display_options"),
|
||
|
|
options=options,
|
||
|
|
properties=nested_properties,
|
||
|
|
min_value=min_value,
|
||
|
|
max_value=max_value,
|
||
|
|
min_length=min_length,
|
||
|
|
max_length=max_length,
|
||
|
|
pattern=pattern,
|
||
|
|
editor=meta.get("editor"),
|
||
|
|
extra=meta.get("extra") or {},
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _merged_field_meta(
|
||
|
|
owner_cls: type[BaseModel],
|
||
|
|
field_name: str,
|
||
|
|
field: FieldInfo,
|
||
|
|
) -> dict[str, Any]:
|
||
|
|
field_meta = {}
|
||
|
|
if isinstance(field.json_schema_extra, dict):
|
||
|
|
field_meta = dict(field.json_schema_extra.get(_SPEC_FIELD_META_KEY, {}) or {})
|
||
|
|
metadata: NodeSpecMetadata | None = getattr(
|
||
|
|
owner_cls, "__node_spec_metadata__", None
|
||
|
|
)
|
||
|
|
override = (
|
||
|
|
dict(metadata.field_overrides.get(field_name, {}) or {})
|
||
|
|
if metadata is not None
|
||
|
|
else {}
|
||
|
|
)
|
||
|
|
merged = dict(field_meta)
|
||
|
|
merged.update(override)
|
||
|
|
return merged
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_property_type(annotation: Any, meta: dict[str, Any]) -> PropertyType:
|
||
|
|
ui_type = meta.get("ui_type")
|
||
|
|
if ui_type:
|
||
|
|
return PropertyType(ui_type)
|
||
|
|
|
||
|
|
inner = _strip_optional(annotation)
|
||
|
|
origin = get_origin(inner)
|
||
|
|
args = get_args(inner)
|
||
|
|
|
||
|
|
if origin is list:
|
||
|
|
item_type = _strip_optional(args[0]) if args else Any
|
||
|
|
if isinstance(item_type, type) and issubclass(item_type, BaseModel):
|
||
|
|
return PropertyType.fixed_collection
|
||
|
|
raise ValueError(
|
||
|
|
"List-valued fields must declare an explicit ui_type unless they wrap a "
|
||
|
|
f"BaseModel row type (field annotation: {annotation!r})."
|
||
|
|
)
|
||
|
|
|
||
|
|
if _is_enum(inner) or _is_literal(inner):
|
||
|
|
return PropertyType.options
|
||
|
|
|
||
|
|
if inner in (str,):
|
||
|
|
return PropertyType.string
|
||
|
|
if inner in (int, float):
|
||
|
|
return PropertyType.number
|
||
|
|
if inner is bool:
|
||
|
|
return PropertyType.boolean
|
||
|
|
if inner in (dict, Any) or origin is dict:
|
||
|
|
return PropertyType.json
|
||
|
|
|
||
|
|
raise ValueError(f"Unable to derive PropertyType for annotation {annotation!r}")
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_nested_properties(
|
||
|
|
annotation: Any,
|
||
|
|
prop_type: PropertyType,
|
||
|
|
) -> list[PropertySpec] | None:
|
||
|
|
if prop_type != PropertyType.fixed_collection:
|
||
|
|
return None
|
||
|
|
|
||
|
|
inner = _strip_optional(annotation)
|
||
|
|
args = get_args(inner)
|
||
|
|
if not args:
|
||
|
|
raise ValueError(
|
||
|
|
f"fixed_collection field annotation is missing row type: {annotation!r}"
|
||
|
|
)
|
||
|
|
row_type = _strip_optional(args[0])
|
||
|
|
if not isinstance(row_type, type) or not issubclass(row_type, BaseModel):
|
||
|
|
raise ValueError(
|
||
|
|
f"fixed_collection rows must be BaseModel subclasses: {annotation!r}"
|
||
|
|
)
|
||
|
|
|
||
|
|
properties: list[PropertySpec] = []
|
||
|
|
for field_name, field in row_type.model_fields.items():
|
||
|
|
prop = _build_property_spec(row_type, field_name, field)
|
||
|
|
if prop is not None:
|
||
|
|
properties.append(prop)
|
||
|
|
return properties
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_options(
|
||
|
|
annotation: Any,
|
||
|
|
meta: dict[str, Any],
|
||
|
|
prop_type: PropertyType,
|
||
|
|
) -> list[PropertyOption] | None:
|
||
|
|
if prop_type not in (PropertyType.options, PropertyType.multi_options):
|
||
|
|
return meta.get("options")
|
||
|
|
|
||
|
|
if meta.get("options"):
|
||
|
|
return meta["options"]
|
||
|
|
|
||
|
|
inner = _strip_optional(annotation)
|
||
|
|
if prop_type == PropertyType.multi_options:
|
||
|
|
inner = _strip_optional(get_args(inner)[0])
|
||
|
|
|
||
|
|
if _is_enum(inner):
|
||
|
|
return [
|
||
|
|
PropertyOption(
|
||
|
|
value=member.value, label=_humanize_option_label(member.value)
|
||
|
|
)
|
||
|
|
for member in inner
|
||
|
|
]
|
||
|
|
if _is_literal(inner):
|
||
|
|
return [
|
||
|
|
PropertyOption(value=value, label=_humanize_option_label(value))
|
||
|
|
for value in get_args(inner)
|
||
|
|
if value is not None
|
||
|
|
]
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_constraints(
|
||
|
|
field: FieldInfo,
|
||
|
|
meta: dict[str, Any],
|
||
|
|
) -> tuple[float | None, float | None, int | None, int | None, str | None]:
|
||
|
|
min_value = meta.get("min_value")
|
||
|
|
max_value = meta.get("max_value")
|
||
|
|
min_length = meta.get("min_length")
|
||
|
|
max_length = meta.get("max_length")
|
||
|
|
pattern = meta.get("pattern")
|
||
|
|
|
||
|
|
for item in field.metadata:
|
||
|
|
if min_value is None:
|
||
|
|
if hasattr(item, "ge") and item.ge is not None:
|
||
|
|
min_value = item.ge
|
||
|
|
elif hasattr(item, "gt") and item.gt is not None:
|
||
|
|
min_value = item.gt
|
||
|
|
if max_value is None:
|
||
|
|
if hasattr(item, "le") and item.le is not None:
|
||
|
|
max_value = item.le
|
||
|
|
elif hasattr(item, "lt") and item.lt is not None:
|
||
|
|
max_value = item.lt
|
||
|
|
if (
|
||
|
|
min_length is None
|
||
|
|
and hasattr(item, "min_length")
|
||
|
|
and item.min_length is not None
|
||
|
|
):
|
||
|
|
min_length = item.min_length
|
||
|
|
if (
|
||
|
|
max_length is None
|
||
|
|
and hasattr(item, "max_length")
|
||
|
|
and item.max_length is not None
|
||
|
|
):
|
||
|
|
max_length = item.max_length
|
||
|
|
if pattern is None and hasattr(item, "pattern") and item.pattern is not None:
|
||
|
|
pattern = item.pattern
|
||
|
|
|
||
|
|
return min_value, max_value, min_length, max_length, pattern
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_default(field: FieldInfo, meta: dict[str, Any]) -> Any:
|
||
|
|
if "spec_default" in meta:
|
||
|
|
return meta["spec_default"]
|
||
|
|
if field.default is not PydanticUndefined:
|
||
|
|
return field.default
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_required(field: FieldInfo, meta: dict[str, Any]) -> bool:
|
||
|
|
if meta.get("required") is not None:
|
||
|
|
return bool(meta["required"])
|
||
|
|
return bool(field.is_required())
|
||
|
|
|
||
|
|
|
||
|
|
def _strip_optional(annotation: Any) -> Any:
|
||
|
|
origin = get_origin(annotation)
|
||
|
|
if origin is None:
|
||
|
|
return annotation
|
||
|
|
|
||
|
|
args = [arg for arg in get_args(annotation) if arg is not NoneType]
|
||
|
|
if len(args) == 1 and len(args) != len(get_args(annotation)):
|
||
|
|
return args[0]
|
||
|
|
return annotation
|
||
|
|
|
||
|
|
|
||
|
|
def _is_enum(annotation: Any) -> bool:
|
||
|
|
return isinstance(annotation, type) and issubclass(annotation, Enum)
|
||
|
|
|
||
|
|
|
||
|
|
def _is_literal(annotation: Any) -> bool:
|
||
|
|
return get_origin(annotation) is Literal
|
||
|
|
|
||
|
|
|
||
|
|
def _humanize_identifier(name: str) -> str:
|
||
|
|
return name.replace("_", " ").strip().title()
|
||
|
|
|
||
|
|
|
||
|
|
def _humanize_option_label(value: Any) -> str:
|
||
|
|
if isinstance(value, str):
|
||
|
|
return value.replace("_", " ").replace("-", " ").strip().title()
|
||
|
|
return str(value)
|