mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
1445 lines
58 KiB
Python
1445 lines
58 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from collections import Counter
|
|
|
|
import sqlglot
|
|
from sqlglot import exp
|
|
|
|
from semantic_layer.graph import JoinGraph
|
|
from semantic_layer.models import (
|
|
ColumnVisibility,
|
|
MeasureDefinition,
|
|
MeasureGroup,
|
|
OrderByClause,
|
|
Provenance,
|
|
QueryDimension,
|
|
ResolvedColumn,
|
|
ResolvedJoin,
|
|
ResolvedMeasure,
|
|
ResolvedPlan,
|
|
SemanticQuery,
|
|
SourceDefinition,
|
|
)
|
|
from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers
|
|
|
|
# DIALECT CONVENTION:
|
|
# User-authored measure `expr`, `filter`, and computed-column fragments must
|
|
# be parsed with `read=self.dialect`. Authors write in the connection's
|
|
# native dialect (per sl_capture); parsing as postgres silently drops
|
|
# dialect-specific tokens. When re-emitting ASTs as strings for later
|
|
# composition, use `sql(dialect=self.dialect)` so dialect-specific
|
|
# functions (e.g. BigQuery `TIMESTAMP_SUB`, Snowflake `DATEADD`) survive.
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class QueryPlanner:
|
|
def __init__(
|
|
self,
|
|
sources: dict[str, SourceDefinition],
|
|
graph: JoinGraph,
|
|
*,
|
|
dialect: str = "postgres",
|
|
):
|
|
self.sources = sources
|
|
self.graph = graph
|
|
self.dialect = dialect
|
|
self.parser = ExpressionParser(dialect=dialect)
|
|
|
|
def plan(self, query: SemanticQuery) -> ResolvedPlan:
|
|
# 0. Validate column visibility
|
|
self._validate_visibility(query)
|
|
|
|
# 1. Resolve dimensions
|
|
dimensions = self._resolve_dimensions(query.dimensions)
|
|
|
|
# 2. Resolve measures (parse, look up pre-defined, classify)
|
|
raw_measures = self._resolve_measures(query.measures)
|
|
|
|
# 3. Topological sort for derived measures
|
|
measures = self._topological_sort_measures(raw_measures)
|
|
|
|
# 3a. Apply query-time segments (AND each into matching measures' filter)
|
|
measures = self._apply_query_segments(measures, query.segments)
|
|
|
|
# 3b. Validate column references exist
|
|
self._validate_column_refs(measures, dimensions, query.filters)
|
|
|
|
# 4. Collect all referenced sources
|
|
source_refs: set[str] = set()
|
|
for m in measures:
|
|
if not m.is_derived:
|
|
source_refs.add(m.source_name)
|
|
source_refs.update(self.parser.extract_source_refs(m.expr))
|
|
if m.filter:
|
|
source_refs.update(self.parser.extract_source_refs(m.filter))
|
|
for d in dimensions:
|
|
refs = self.parser.extract_source_refs(d.field)
|
|
source_refs.update(refs)
|
|
for f in query.filters:
|
|
source_refs.update(self.parser.extract_source_refs(f))
|
|
|
|
if not source_refs:
|
|
raise ValueError("Query does not reference any sources")
|
|
|
|
# 5. Determine anchor source (must happen BEFORE resolve_join_tree)
|
|
anchor_source = self._pick_anchor(
|
|
measures,
|
|
dimensions,
|
|
source_refs,
|
|
include_empty=query.include_empty,
|
|
)
|
|
|
|
# 6. Resolve join tree, rooted at the anchor
|
|
tree = self.graph.resolve_join_tree(source_refs, root=anchor_source)
|
|
|
|
# 7. Build structured joins from tree edges
|
|
joins = [
|
|
ResolvedJoin(
|
|
from_source=e.from_source,
|
|
to_source=e.to_source,
|
|
from_column=e.from_column,
|
|
to_column=e.to_column,
|
|
relationship=e.relationship,
|
|
)
|
|
for e in tree.edges
|
|
]
|
|
|
|
# 8. Detect fanout / chasm trap
|
|
has_fan_out, measure_groups, fan_out_desc, locality_descs = (
|
|
self._detect_fan_out(measures, dimensions, tree, filters=query.filters)
|
|
)
|
|
|
|
# 9. Classify filters
|
|
where_filters, having_filters = self._classify_filters(query.filters, measures)
|
|
|
|
# 10. Compute anchor grain
|
|
dim_sources = set()
|
|
for d in dimensions:
|
|
refs = self.parser.extract_source_refs(d.field)
|
|
dim_sources.update(refs)
|
|
anchor_grain = []
|
|
for d in dimensions:
|
|
anchor_grain.append(d.field)
|
|
|
|
# 11. Build resolved columns
|
|
columns = self._build_columns(measures, dimensions)
|
|
|
|
# 12. Build join path descriptions
|
|
join_paths = []
|
|
for j in joins:
|
|
from_cols = [c.strip() for c in j.from_column.split(",")]
|
|
to_cols = [c.strip() for c in j.to_column.split(",")]
|
|
conditions = " AND ".join(
|
|
f"{j.from_source}.{fc} = {j.to_source}.{tc}"
|
|
for fc, tc in zip(from_cols, to_cols)
|
|
)
|
|
join_paths.append(f"{conditions} ({j.relationship})")
|
|
|
|
# 13. Resolve order_by
|
|
order_by_clauses = []
|
|
for ob in query.order_by:
|
|
if isinstance(ob, dict):
|
|
order_by_clauses.append(OrderByClause(**ob))
|
|
elif isinstance(ob, str):
|
|
order_by_clauses.append(OrderByClause(field=ob))
|
|
else:
|
|
order_by_clauses.append(ob)
|
|
|
|
return ResolvedPlan(
|
|
sources_used=sorted(tree.sources),
|
|
join_paths=join_paths,
|
|
joins=joins,
|
|
anchor_source=anchor_source,
|
|
anchor_grain=anchor_grain,
|
|
fan_out_description=fan_out_desc,
|
|
has_fan_out=has_fan_out,
|
|
measure_groups=measure_groups,
|
|
aggregate_locality=locality_descs,
|
|
where_filters=where_filters,
|
|
having_filters=having_filters,
|
|
columns=columns,
|
|
measures=measures,
|
|
dimensions=dimensions,
|
|
order_by=order_by_clauses,
|
|
limit=query.limit,
|
|
include_empty=query.include_empty,
|
|
)
|
|
|
|
def _resolve_dimensions(self, dims: list[str | dict]) -> list[QueryDimension]:
|
|
result = []
|
|
seen: set[tuple[str, str | None]] = set()
|
|
for d in dims:
|
|
if isinstance(d, str):
|
|
dim = QueryDimension(field=self._qualify_bare_column(d))
|
|
elif isinstance(d, dict):
|
|
field = d.get("field", "")
|
|
dim = QueryDimension(**{**d, "field": self._qualify_bare_column(field)})
|
|
else:
|
|
continue
|
|
key = (dim.field, dim.granularity)
|
|
if key not in seen:
|
|
seen.add(key)
|
|
result.append(dim)
|
|
return result
|
|
|
|
def _qualify_bare_column(self, field: str) -> str:
|
|
"""Qualify a bare column name to source.column if unambiguous."""
|
|
if "." in field or not field.strip().isidentifier():
|
|
return field
|
|
bare = field.strip()
|
|
matches: list[str] = []
|
|
for source_name, source in self.sources.items():
|
|
if any(c.name == bare for c in source.columns):
|
|
matches.append(source_name)
|
|
if len(matches) == 1:
|
|
return f"{matches[0]}.{bare}"
|
|
if len(matches) > 1:
|
|
raise ValueError(
|
|
f"Column '{bare}' is ambiguous: it exists in multiple sources "
|
|
f"({', '.join(sorted(matches))}). Use a qualified name like "
|
|
f"'{matches[0]}.{bare}' to disambiguate."
|
|
)
|
|
return field # not found — leave as-is, downstream will error
|
|
|
|
def _resolve_measures(self, raw: list[str | dict]) -> list[ResolvedMeasure]:
|
|
measures: list[ResolvedMeasure] = []
|
|
# Collect all named measures for dependency detection
|
|
named: set[str] = set()
|
|
for m in raw:
|
|
if isinstance(m, dict) and m.get("name"):
|
|
named.add(m["name"])
|
|
colliding_predefined_names = self._collect_colliding_predefined_names(raw)
|
|
|
|
for m in raw:
|
|
if isinstance(m, str):
|
|
measures.append(
|
|
self._resolve_measure_str(m, colliding_predefined_names)
|
|
)
|
|
elif isinstance(m, dict):
|
|
measures.append(
|
|
self._resolve_measure_dict(
|
|
m,
|
|
named,
|
|
colliding_predefined_names,
|
|
)
|
|
)
|
|
|
|
# Expand pre-defined measure chains (e.g., profit = revenue - total_cost)
|
|
measures = self._expand_predefined_chains(measures)
|
|
# Auto-add predefined measures referenced by derived measures
|
|
measures = self._auto_add_predefined_deps(measures)
|
|
# Qualify duplicate measure names across sources
|
|
measures = self._qualify_duplicate_names(measures)
|
|
return measures
|
|
|
|
def _collect_colliding_predefined_names(self, raw: list[str | dict]) -> set[str]:
|
|
counts: Counter[str] = Counter()
|
|
for item in raw:
|
|
if isinstance(item, str):
|
|
ref = self._match_predefined_ref(item)
|
|
if ref:
|
|
_, measure_name = ref
|
|
counts[measure_name] += 1
|
|
else:
|
|
bare = item.strip()
|
|
if bare.isidentifier():
|
|
try:
|
|
unq = self._resolve_unqualified_measure(bare)
|
|
if unq:
|
|
counts[unq[1]] += 1
|
|
except ValueError:
|
|
pass # ambiguous — caught later during resolution
|
|
elif isinstance(item, dict):
|
|
expr = item.get("expr", "")
|
|
for _, measure_name in self._extract_predefined_refs(expr):
|
|
counts[measure_name] += 1
|
|
return {name for name, count in counts.items() if count > 1}
|
|
|
|
def _extract_predefined_refs(self, expr: str) -> list[tuple[str, str]]:
|
|
refs: list[tuple[str, str]] = []
|
|
parsed = self.parser.parse(expr)
|
|
for ref in parsed.column_refs:
|
|
parts = ref.split(".", 1)
|
|
if len(parts) != 2:
|
|
continue
|
|
src_name, measure_name = parts
|
|
actual_src_name = self.graph.alias_map.get(src_name, src_name)
|
|
src = self.sources.get(actual_src_name)
|
|
if not src:
|
|
continue
|
|
if any(md.name == measure_name for md in src.measures) and not any(
|
|
c.name == measure_name for c in src.columns
|
|
):
|
|
refs.append((src_name, measure_name))
|
|
return refs
|
|
|
|
def _match_predefined_ref(self, expr: str) -> tuple[str, str] | None:
|
|
parsed = self.parser.parse(expr)
|
|
if parsed.is_aggregate or len(parsed.column_refs) != 1:
|
|
return None
|
|
ref = next(iter(parsed.column_refs))
|
|
parts = ref.split(".", 1)
|
|
if len(parts) != 2:
|
|
return None
|
|
source_name, measure_name = parts
|
|
actual_source_name = self.graph.alias_map.get(source_name, source_name)
|
|
source = self.sources.get(actual_source_name)
|
|
if not source:
|
|
return None
|
|
if any(md.name == measure_name for md in source.measures):
|
|
return source_name, measure_name
|
|
return None
|
|
|
|
def _resolve_unqualified_measure(self, bare_name: str) -> tuple[str, str] | None:
|
|
"""Find a unique predefined measure matching a bare (unqualified) name.
|
|
|
|
Returns (source_name, measure_name) if exactly one source defines it.
|
|
Raises ValueError if ambiguous (multiple sources).
|
|
"""
|
|
matches: list[str] = []
|
|
for source_name, source in self.sources.items():
|
|
if any(md.name == bare_name for md in source.measures):
|
|
matches.append(source_name)
|
|
if len(matches) == 0:
|
|
return None
|
|
if len(matches) == 1:
|
|
return matches[0], bare_name
|
|
raise ValueError(
|
|
f"Measure '{bare_name}' is ambiguous: it exists in multiple sources "
|
|
f"({', '.join(sorted(matches))}). Use a qualified name like "
|
|
f"'{matches[0]}.{bare_name}' to disambiguate."
|
|
)
|
|
|
|
@staticmethod
|
|
def _qualified_measure_name(source_name: str, measure_name: str) -> str:
|
|
return f"{source_name}_{measure_name}"
|
|
|
|
@staticmethod
|
|
def _auto_measure_name(expr: str) -> str:
|
|
normalized = expr.replace(".", "_").strip().lower()
|
|
normalized = re.sub(r"[^a-z0-9_]+", "_", normalized)
|
|
normalized = re.sub(r"_+", "_", normalized).strip("_")
|
|
if not normalized:
|
|
return "measure"
|
|
if normalized[0].isdigit():
|
|
return f"m_{normalized}"
|
|
return normalized
|
|
|
|
def _measure_definition_for_resolved(
|
|
self,
|
|
source: SourceDefinition,
|
|
source_name: str,
|
|
resolved_name: str,
|
|
original_name: str | None = None,
|
|
):
|
|
for candidate in (original_name, resolved_name):
|
|
if not candidate:
|
|
continue
|
|
mdef = next((md for md in source.measures if md.name == candidate), None)
|
|
if mdef:
|
|
return mdef
|
|
for mdef in source.measures:
|
|
if resolved_name == self._qualified_measure_name(source_name, mdef.name):
|
|
return mdef
|
|
return None
|
|
|
|
def _split_qualified_dep_token(self, token: str) -> tuple[str, str] | None:
|
|
for source_name, source in self.sources.items():
|
|
prefix = f"{source_name}_"
|
|
if not token.startswith(prefix):
|
|
continue
|
|
measure_name = token[len(prefix) :]
|
|
if any(md.name == measure_name for md in source.measures):
|
|
return source_name, measure_name
|
|
return None
|
|
|
|
def _auto_add_predefined_deps(
|
|
self, measures: list[ResolvedMeasure]
|
|
) -> list[ResolvedMeasure]:
|
|
"""Auto-add predefined measures that derived measures depend on but aren't in the list."""
|
|
existing_names = {m.name for m in measures}
|
|
extra: list[ResolvedMeasure] = []
|
|
for m in measures:
|
|
if not m.is_derived:
|
|
continue
|
|
for dep in m.depends_on:
|
|
if dep in existing_names:
|
|
continue
|
|
exact = self._split_qualified_dep_token(dep)
|
|
if exact:
|
|
src_name, measure_name = exact
|
|
resolved = self._resolve_measure_str(
|
|
f"{src_name}.{measure_name}",
|
|
set(),
|
|
)
|
|
if resolved.name != dep:
|
|
resolved = resolved.model_copy(update={"name": dep})
|
|
extra.append(resolved)
|
|
existing_names.add(dep)
|
|
continue
|
|
# Try to resolve as a predefined measure from any source
|
|
for src in self.sources.values():
|
|
mdef = next((md for md in src.measures if md.name == dep), None)
|
|
if mdef:
|
|
extra.append(
|
|
self._resolve_measure_str(f"{src.name}.{dep}", set())
|
|
)
|
|
existing_names.add(dep)
|
|
break
|
|
if extra:
|
|
# Prepend extras so dependencies come before derived measures
|
|
measures = extra + measures
|
|
return measures
|
|
|
|
def _qualify_duplicate_names(
|
|
self, measures: list[ResolvedMeasure]
|
|
) -> list[ResolvedMeasure]:
|
|
"""Qualify measure names that collide across different sources."""
|
|
name_counts = Counter(m.name for m in measures)
|
|
colliding = {name for name, count in name_counts.items() if count > 1}
|
|
if not colliding:
|
|
return measures
|
|
result = []
|
|
for m in measures:
|
|
if m.name in colliding and m.source_name != "__derived__":
|
|
result.append(
|
|
m.model_copy(update={"name": f"{m.source_name}_{m.name}"})
|
|
)
|
|
else:
|
|
result.append(m)
|
|
return result
|
|
|
|
def _expand_predefined_chains(
|
|
self, measures: list[ResolvedMeasure]
|
|
) -> list[ResolvedMeasure]:
|
|
"""Expand pre-defined measures that reference other pre-defined measures.
|
|
|
|
Fully recursive: handles chains of arbitrary depth (e.g.,
|
|
margin = net_profit / revenue, where net_profit = gross_profit - tax,
|
|
where gross_profit = revenue - cost).
|
|
"""
|
|
existing_names = {m.name for m in measures}
|
|
extra_measures: list[ResolvedMeasure] = []
|
|
updated: list[ResolvedMeasure] = []
|
|
# Track already-expanded deps to avoid duplicates
|
|
expanded: set[str] = set()
|
|
|
|
def _ensure_dep(
|
|
dep_name: str, source: SourceDefinition, source_name: str
|
|
) -> None:
|
|
"""Recursively ensure a dependency measure is added."""
|
|
if dep_name in existing_names or dep_name in expanded:
|
|
return
|
|
dep_mdef = next((md for md in source.measures if md.name == dep_name), None)
|
|
if not dep_mdef:
|
|
return
|
|
|
|
dep_other = {md.name for md in source.measures if md.name != dep_name}
|
|
dep_parsed = self.parser.parse(dep_mdef.expr, known_measure_names=dep_other)
|
|
|
|
if dep_parsed.depends_on_measures:
|
|
# Recursively add sub-dependencies first
|
|
for sub_dep in sorted(dep_parsed.depends_on_measures):
|
|
_ensure_dep(sub_dep, source, source_name)
|
|
# This dependency is itself derived
|
|
extra_measures.append(
|
|
ResolvedMeasure(
|
|
name=dep_name,
|
|
original_name=dep_name,
|
|
expr=dep_mdef.expr,
|
|
source_name="__derived__",
|
|
provenance=Provenance.VERIFIED,
|
|
is_derived=True,
|
|
depends_on=sorted(dep_parsed.depends_on_measures),
|
|
description=dep_mdef.description,
|
|
)
|
|
)
|
|
else:
|
|
# Leaf dependency: qualify and add as concrete measure
|
|
extra_measures.append(
|
|
ResolvedMeasure(
|
|
name=dep_name,
|
|
original_name=dep_name,
|
|
qualified_ref=f"{source_name}.{dep_name}",
|
|
expr=self._qualify_predefined_expr(dep_mdef.expr, source_name),
|
|
source_name=source_name,
|
|
filter=self._compose_measure_filter(dep_mdef, source_name),
|
|
provenance=Provenance.VERIFIED,
|
|
description=dep_mdef.description,
|
|
)
|
|
)
|
|
existing_names.add(dep_name)
|
|
expanded.add(dep_name)
|
|
|
|
for m in measures:
|
|
if m.provenance != Provenance.VERIFIED or m.is_derived:
|
|
updated.append(m)
|
|
continue
|
|
|
|
actual_source_name = self.graph.alias_map.get(m.source_name, m.source_name)
|
|
source = self.sources.get(actual_source_name)
|
|
if not source:
|
|
updated.append(m)
|
|
continue
|
|
|
|
mdef = self._measure_definition_for_resolved(
|
|
source, m.source_name, m.name, m.original_name
|
|
)
|
|
if not mdef:
|
|
updated.append(m)
|
|
continue
|
|
|
|
other_measure_names = {
|
|
md.name for md in source.measures if md.name != mdef.name
|
|
}
|
|
parsed = self.parser.parse(
|
|
mdef.expr, known_measure_names=other_measure_names
|
|
)
|
|
|
|
if not parsed.depends_on_measures:
|
|
updated.append(m)
|
|
continue
|
|
|
|
# Recursively add all dependencies
|
|
for dep_name in sorted(parsed.depends_on_measures):
|
|
_ensure_dep(dep_name, source, m.source_name)
|
|
|
|
# Convert this measure to derived
|
|
updated.append(
|
|
m.model_copy(
|
|
update={
|
|
"expr": mdef.expr,
|
|
"source_name": "__derived__",
|
|
"is_derived": True,
|
|
"depends_on": sorted(parsed.depends_on_measures),
|
|
"filter": None,
|
|
}
|
|
)
|
|
)
|
|
|
|
return extra_measures + updated
|
|
|
|
def _resolve_measure_str(
|
|
self,
|
|
s: str,
|
|
colliding_predefined_names: set[str],
|
|
) -> ResolvedMeasure:
|
|
"""
|
|
"orders.revenue" → pre-defined lookup
|
|
"sum(orders.amount)" → runtime expression
|
|
"""
|
|
parsed = self.parser.parse(s)
|
|
|
|
# Reject window functions in measures
|
|
if parsed.has_window_function:
|
|
raise ValueError(
|
|
f"Window functions (OVER clause) are not supported in measures: '{s}'. "
|
|
f"Window functions require row-level context and cannot be combined with "
|
|
f"GROUP BY aggregation."
|
|
)
|
|
|
|
predefined_ref = self._match_predefined_ref(s)
|
|
|
|
# Try unqualified resolution for bare identifiers (e.g. "revenue" → "orders.revenue")
|
|
if predefined_ref is None and not parsed.is_aggregate:
|
|
bare = s.strip()
|
|
if bare.isidentifier():
|
|
unqualified = self._resolve_unqualified_measure(bare)
|
|
if unqualified:
|
|
source_name, measure_name = unqualified
|
|
qualified = f"{source_name}.{measure_name}"
|
|
logger.info(
|
|
"Resolved unqualified measure '%s' to '%s'",
|
|
bare,
|
|
qualified,
|
|
)
|
|
return self._resolve_measure_str(
|
|
qualified, colliding_predefined_names
|
|
)
|
|
|
|
if predefined_ref:
|
|
source_name, measure_name = predefined_ref
|
|
actual_source_name = self.graph.alias_map.get(source_name, source_name)
|
|
source = self.sources[actual_source_name]
|
|
for mdef in source.measures:
|
|
if mdef.name == measure_name:
|
|
resolved_name = measure_name
|
|
if measure_name in colliding_predefined_names:
|
|
resolved_name = self._qualified_measure_name(
|
|
source_name, measure_name
|
|
)
|
|
return ResolvedMeasure(
|
|
name=resolved_name,
|
|
original_name=measure_name,
|
|
qualified_ref=f"{source_name}.{measure_name}",
|
|
expr=self._qualify_predefined_expr(mdef.expr, source_name),
|
|
source_name=source_name,
|
|
filter=self._compose_measure_filter(mdef, source_name),
|
|
provenance=Provenance.VERIFIED,
|
|
description=mdef.description,
|
|
)
|
|
|
|
# Bare column reference without aggregation — invalid as a measure
|
|
if not parsed.is_aggregate:
|
|
if parsed.column_refs:
|
|
ref = next(iter(parsed.column_refs))
|
|
src, col = ref.split(".", 1)
|
|
raise ValueError(
|
|
f"Measure '{s}' is not a pre-defined measure on source '{src}' "
|
|
f"and has no aggregate function. Use an aggregate like "
|
|
f"sum({s}), count({s}), avg({s}), etc."
|
|
)
|
|
raise ValueError(f"Measure '{s}' does not reference any source")
|
|
|
|
# Runtime expression
|
|
if not parsed.source_refs:
|
|
raise ValueError(f"Measure '{s}' does not reference any source")
|
|
|
|
# Reject nested aggregation (e.g., avg(sum(orders.amount)))
|
|
self._check_nested_aggregation(s)
|
|
|
|
source_name = sorted(parsed.source_refs)[0]
|
|
name = self._auto_measure_name(s)
|
|
return ResolvedMeasure(
|
|
name=name,
|
|
original_name=name,
|
|
expr=s,
|
|
source_name=source_name,
|
|
provenance=Provenance.COMPOSED,
|
|
)
|
|
|
|
def _resolve_measure_dict(
|
|
self,
|
|
d: dict,
|
|
named: set[str],
|
|
colliding_predefined_names: set[str],
|
|
) -> ResolvedMeasure:
|
|
expr = d.get("expr", "")
|
|
name = d.get("name", expr)
|
|
parsed = self.parser.parse(expr, known_measure_names=named)
|
|
|
|
# Reject window functions
|
|
if parsed.has_window_function:
|
|
raise ValueError(
|
|
f"Window functions (OVER clause) are not supported in measures: '{expr}'. "
|
|
f"Window functions require row-level context and cannot be combined with "
|
|
f"GROUP BY aggregation."
|
|
)
|
|
|
|
# Check if any column_refs match predefined measures (e.g., "orders.revenue")
|
|
predefined_deps: list[tuple[str, str, str]] = []
|
|
for src_name, measure_name in self._extract_predefined_refs(expr):
|
|
predefined_deps.append(
|
|
(f"{src_name}.{measure_name}", src_name, measure_name)
|
|
)
|
|
|
|
# Merge bare measure deps + qualified predefined deps
|
|
all_dep_names: set[str] = set(parsed.depends_on_measures)
|
|
rewritten_expr = expr
|
|
|
|
if predefined_deps:
|
|
replacement_map: dict[str, str] = {}
|
|
for ref, src_name, measure_name in predefined_deps:
|
|
dep_name = measure_name
|
|
if measure_name in colliding_predefined_names:
|
|
dep_name = self._qualified_measure_name(src_name, measure_name)
|
|
replacement_map[ref] = dep_name
|
|
all_dep_names.add(dep_name)
|
|
named.add(dep_name)
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT {quote_reserved_identifiers(expr)}", dialect=self.dialect
|
|
)
|
|
|
|
def _replace(node):
|
|
if isinstance(node, exp.Column) and node.table:
|
|
ref = f"{node.table}.{node.name}"
|
|
if ref in replacement_map:
|
|
return exp.Column(this=exp.to_identifier(replacement_map[ref]))
|
|
return node
|
|
|
|
rewritten_expr = (
|
|
tree.transform(_replace).expressions[0].sql(dialect=self.dialect)
|
|
)
|
|
|
|
if all_dep_names:
|
|
return ResolvedMeasure(
|
|
name=name,
|
|
original_name=name,
|
|
expr=rewritten_expr,
|
|
source_name="__derived__",
|
|
provenance=Provenance.COMPOSED,
|
|
is_derived=True,
|
|
depends_on=sorted(all_dep_names),
|
|
)
|
|
|
|
if not parsed.source_refs:
|
|
raise ValueError(f"Measure expr '{expr}' does not reference any source")
|
|
|
|
# Reject nested aggregation (e.g., avg(sum(orders.amount)))
|
|
self._check_nested_aggregation(expr)
|
|
|
|
source_name = sorted(parsed.source_refs)[0]
|
|
return ResolvedMeasure(
|
|
name=name,
|
|
original_name=name,
|
|
expr=expr,
|
|
source_name=source_name,
|
|
provenance=Provenance.COMPOSED,
|
|
)
|
|
|
|
def _check_nested_aggregation(self, expr: str) -> None:
|
|
"""Reject expressions with nested aggregate functions (e.g., avg(sum(x)))."""
|
|
try:
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT {quote_reserved_identifiers(expr)}", dialect=self.dialect
|
|
)
|
|
for agg_node in tree.find_all(exp.AggFunc):
|
|
# Check if this aggregate contains another aggregate inside
|
|
for inner in agg_node.find_all(exp.AggFunc):
|
|
if inner is not agg_node:
|
|
raise ValueError(
|
|
f"Nested aggregation is not supported: '{expr}'. "
|
|
f"Use a derived measure to combine aggregates "
|
|
f"(e.g., define sum_amount first, then avg it as a derived measure)."
|
|
)
|
|
except ValueError:
|
|
raise
|
|
except Exception:
|
|
logger.debug("Failed to check nested aggregation for: %s", expr)
|
|
|
|
def _topological_sort_measures(
|
|
self, measures: list[ResolvedMeasure]
|
|
) -> list[ResolvedMeasure]:
|
|
by_name = {m.name: m for m in measures}
|
|
visited: set[str] = set()
|
|
in_stack: set[str] = set()
|
|
result: list[ResolvedMeasure] = []
|
|
|
|
def visit(m: ResolvedMeasure) -> None:
|
|
if m.name in in_stack:
|
|
raise ValueError(f"Circular dependency detected: {m.name}")
|
|
if m.name in visited:
|
|
return
|
|
in_stack.add(m.name)
|
|
for dep_name in m.depends_on:
|
|
if dep_name in by_name:
|
|
visit(by_name[dep_name])
|
|
in_stack.discard(m.name)
|
|
visited.add(m.name)
|
|
result.append(m)
|
|
|
|
for m in measures:
|
|
visit(m)
|
|
return result
|
|
|
|
def _pick_anchor(
|
|
self,
|
|
measures: list[ResolvedMeasure],
|
|
dimensions: list[QueryDimension],
|
|
source_refs: set[str],
|
|
include_empty: bool,
|
|
) -> str:
|
|
if include_empty:
|
|
for d in dimensions:
|
|
refs = self.parser.extract_source_refs(d.field)
|
|
if refs:
|
|
return sorted(refs)[0]
|
|
# Prefer the first non-derived measure's source
|
|
for m in measures:
|
|
if not m.is_derived and m.source_name in self.sources:
|
|
return m.source_name
|
|
# Fallback to first dimension's source
|
|
for d in dimensions:
|
|
refs = self.parser.extract_source_refs(d.field)
|
|
if refs:
|
|
return sorted(refs)[0]
|
|
return sorted(source_refs)[0]
|
|
|
|
def _compose_measure_filter(
|
|
self, mdef: MeasureDefinition, source_name: str
|
|
) -> str | None:
|
|
"""Compose mdef.filter with mdef.segments[*].expr into a single AND-ed,
|
|
qualified predicate. Returns None if neither contributes.
|
|
|
|
Segments are bare names resolved against the measure's own source.
|
|
Unknown names raise at plan time.
|
|
"""
|
|
parts: list[str] = []
|
|
if mdef.filter:
|
|
parts.append(self._qualify_predefined_expr(mdef.filter, source_name))
|
|
if mdef.segments:
|
|
actual_source_name = self.graph.alias_map.get(source_name, source_name)
|
|
source = self.sources.get(actual_source_name)
|
|
seg_by_name = {s.name: s for s in (source.segments if source else [])}
|
|
for seg_name in mdef.segments:
|
|
seg = seg_by_name.get(seg_name)
|
|
if not seg:
|
|
available = ", ".join(sorted(seg_by_name)) or "(none)"
|
|
raise ValueError(
|
|
f"Measure '{mdef.name}' on source '{actual_source_name}' "
|
|
f"references unknown segment '{seg_name}'. "
|
|
f"Available segments: {available}."
|
|
)
|
|
parts.append(self._qualify_predefined_expr(seg.expr, source_name))
|
|
if not parts:
|
|
return None
|
|
if len(parts) == 1:
|
|
return parts[0]
|
|
return " AND ".join(f"({p})" for p in parts)
|
|
|
|
def _apply_query_segments(
|
|
self,
|
|
measures: list[ResolvedMeasure],
|
|
query_segments: list[str],
|
|
) -> list[ResolvedMeasure]:
|
|
"""AND each query-time segment into the filter of every measure whose
|
|
base source matches the segment's source.
|
|
|
|
Errors:
|
|
- Segment string isn't dotted source.name
|
|
- Source or segment doesn't exist
|
|
- No measure in the query has the segment's source as its base source
|
|
"""
|
|
if not query_segments:
|
|
return measures
|
|
|
|
segs_by_source: dict[str, list[str]] = {}
|
|
for raw in query_segments:
|
|
if "." not in raw:
|
|
raise ValueError(
|
|
f"Query-time segment '{raw}' must be a dotted "
|
|
f"'source.segment_name' reference."
|
|
)
|
|
src_name, seg_name = raw.split(".", 1)
|
|
actual = self.graph.alias_map.get(src_name, src_name)
|
|
source = self.sources.get(actual)
|
|
if not source:
|
|
raise ValueError(
|
|
f"Query-time segment '{raw}' references unknown source "
|
|
f"'{src_name}'."
|
|
)
|
|
seg = next((s for s in source.segments if s.name == seg_name), None)
|
|
if not seg:
|
|
avail = ", ".join(sorted(s.name for s in source.segments)) or "(none)"
|
|
raise ValueError(
|
|
f"Query-time segment '{raw}' references unknown segment "
|
|
f"'{seg_name}' on source '{src_name}'. Available: {avail}."
|
|
)
|
|
qualified = self._qualify_predefined_expr(seg.expr, src_name)
|
|
segs_by_source.setdefault(src_name, []).append(qualified)
|
|
|
|
updated: list[ResolvedMeasure] = []
|
|
matched_sources: set[str] = set()
|
|
for m in measures:
|
|
if m.is_derived or m.source_name not in segs_by_source:
|
|
updated.append(m)
|
|
continue
|
|
matched_sources.add(m.source_name)
|
|
new_parts: list[str] = []
|
|
if m.filter:
|
|
new_parts.append(m.filter)
|
|
new_parts.extend(segs_by_source[m.source_name])
|
|
composed = (
|
|
new_parts[0]
|
|
if len(new_parts) == 1
|
|
else " AND ".join(f"({p})" for p in new_parts)
|
|
)
|
|
updated.append(m.model_copy(update={"filter": composed}))
|
|
|
|
for src in segs_by_source:
|
|
if src not in matched_sources:
|
|
raise ValueError(
|
|
f"Query-time segment(s) on source '{src}' have no matching "
|
|
f"measure in the query. A query-time segment only applies to "
|
|
f"measures whose base source matches the segment's source."
|
|
)
|
|
|
|
return updated
|
|
|
|
def _qualify_predefined_expr(self, expr: str, source_name: str) -> str:
|
|
"""Qualify bare column references in predefined measure expressions using sqlglot AST.
|
|
|
|
BFS-traverses many_to_one/one_to_one joins from the measure's source to find
|
|
columns on transitively reachable sources. This handles measure filters that
|
|
reference joined-source columns (e.g., filter: "level = 'premium'" where
|
|
'level' is on a 'tiers' table reachable via orders → customers → tiers).
|
|
"""
|
|
actual_source_name = self.graph.alias_map.get(source_name, source_name)
|
|
source = self.sources.get(actual_source_name)
|
|
if not source:
|
|
return expr
|
|
|
|
# BFS through m2o/o2o joins to build column->source mapping
|
|
col_to_source: dict[str, str] = {}
|
|
visited: set[str] = set()
|
|
queue = [actual_source_name]
|
|
while queue:
|
|
current_name = queue.pop(0)
|
|
if current_name in visited:
|
|
continue
|
|
visited.add(current_name)
|
|
current_src = self.sources.get(current_name)
|
|
if not current_src:
|
|
continue
|
|
# Add columns from this source (first-discovered wins for ambiguity)
|
|
for c in current_src.columns:
|
|
if c.name not in col_to_source:
|
|
current_ref = (
|
|
source_name
|
|
if current_name == actual_source_name
|
|
else current_name
|
|
)
|
|
col_to_source[c.name] = current_ref
|
|
# Traverse m2o/o2o joins
|
|
for join_decl in current_src.joins:
|
|
if join_decl.relationship in ("many_to_one", "one_to_one"):
|
|
target = join_decl.alias or join_decl.to
|
|
actual = join_decl.to
|
|
if actual not in visited:
|
|
queue.append(actual)
|
|
# Map columns using alias if present
|
|
joined_src = self.sources.get(actual)
|
|
if joined_src:
|
|
for c in joined_src.columns:
|
|
if c.name not in col_to_source:
|
|
col_to_source[c.name] = target
|
|
# Own columns always take highest priority
|
|
for c in source.columns:
|
|
col_to_source[c.name] = source_name
|
|
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
|
|
)
|
|
|
|
def _qualify_column(node):
|
|
if (
|
|
isinstance(node, exp.Column)
|
|
and not node.table
|
|
and node.name in col_to_source
|
|
):
|
|
target_source = col_to_source[node.name]
|
|
return exp.Column(
|
|
this=node.this.copy(), table=exp.to_identifier(target_source)
|
|
)
|
|
return node
|
|
|
|
transformed = tree.transform(_qualify_column)
|
|
return transformed.expressions[0].sql(dialect=self.dialect)
|
|
|
|
def _detect_fan_out(
|
|
self,
|
|
measures: list[ResolvedMeasure],
|
|
dimensions: list[QueryDimension],
|
|
tree,
|
|
filters: list[str] | None = None,
|
|
) -> tuple[bool, list[MeasureGroup], str, list[str]]:
|
|
"""
|
|
Detect fanout and chasm traps. Group measures by source.
|
|
If multiple measure sources exist, each needs its own pre-aggregation CTE.
|
|
Also checks filter sources — a filter forcing a one_to_many join from the
|
|
measure source is an error (cannot be safely pre-aggregated).
|
|
"""
|
|
# Group non-derived measures by source
|
|
groups: dict[str, list[ResolvedMeasure]] = {}
|
|
for m in measures:
|
|
if m.is_derived:
|
|
continue
|
|
groups.setdefault(m.source_name, []).append(m)
|
|
|
|
# Validate multi-source aggregate expressions: if a non-derived measure
|
|
# references sources from multiple groups, it can't be safely placed in
|
|
# a single CTE (the other source won't be available in the CTE scope).
|
|
if len(groups) > 1:
|
|
for m in measures:
|
|
if m.is_derived:
|
|
continue
|
|
measure_source_refs = self.parser.extract_source_refs(m.expr)
|
|
other_group_refs = measure_source_refs - {m.source_name}
|
|
for ref in other_group_refs:
|
|
ref_actual = self.graph.alias_map.get(ref, ref)
|
|
source_actual = self.graph.alias_map.get(
|
|
m.source_name, m.source_name
|
|
)
|
|
if ref_actual == source_actual:
|
|
continue
|
|
if ref in groups and ref != m.source_name:
|
|
raise ValueError(
|
|
f"Measure '{m.name}' references multiple independent "
|
|
f"sources ({m.source_name}, {ref}) that are in separate "
|
|
f"measure groups. In aggregate locality mode, each CTE "
|
|
f"can only access its own source's tables. Decompose "
|
|
f"the expression into separate named measures and combine "
|
|
f"as a derived measure: e.g., "
|
|
f'{{"expr": "part1", "name": "a"}}, '
|
|
f'{{"expr": "part2", "name": "b"}}, '
|
|
f'{{"expr": "a / b", "name": "{m.name}"}}'
|
|
)
|
|
|
|
# Collect dimension sources
|
|
dim_sources: set[str] = set()
|
|
for d in dimensions:
|
|
refs = self.parser.extract_source_refs(d.field)
|
|
dim_sources.update(refs)
|
|
|
|
# Collect filter sources
|
|
filter_sources: set[str] = set()
|
|
for f in filters or []:
|
|
filter_sources.update(self.parser.extract_source_refs(f))
|
|
|
|
if len(groups) <= 1:
|
|
# Single measure group: check the path FROM measure source TO dimension sources.
|
|
# Only flag fanout if those specific paths have one_to_many edges.
|
|
if groups:
|
|
source_name = next(iter(groups))
|
|
source_actual = self.graph.alias_map.get(source_name, source_name)
|
|
has_o2m = False
|
|
for dim_src in dim_sources:
|
|
if dim_src == source_name:
|
|
continue
|
|
# Skip alias siblings (same underlying source — no fanout)
|
|
dim_actual = self.graph.alias_map.get(dim_src, dim_src)
|
|
if dim_actual == source_actual:
|
|
continue
|
|
path = self.graph.find_path(source_name, dim_src)
|
|
if path and path.has_one_to_many:
|
|
has_o2m = True
|
|
break
|
|
|
|
# Also check filter sources for one_to_many fanout
|
|
if not has_o2m:
|
|
for filter_src in filter_sources - dim_sources - {source_name}:
|
|
filter_actual = self.graph.alias_map.get(filter_src, filter_src)
|
|
if filter_actual == source_actual:
|
|
continue
|
|
path = self.graph.find_path(source_name, filter_src)
|
|
if path and path.has_one_to_many:
|
|
raise ValueError(
|
|
f"Filter on '{filter_src}' requires a one_to_many join "
|
|
f"from measure source '{source_name}', which would cause "
|
|
f"incorrect aggregation (fanout). Consider rewriting the "
|
|
f"filter as a subquery or adding the filter source as a "
|
|
f"dimension source."
|
|
)
|
|
|
|
if has_o2m:
|
|
measure_groups = [
|
|
MeasureGroup(
|
|
source_name=source_name, measures=groups[source_name]
|
|
)
|
|
]
|
|
return (
|
|
True,
|
|
measure_groups,
|
|
f"Fanout detected: one_to_many edges from {source_name} to dimensions",
|
|
[f"Pre-aggregate {source_name} measures before joining"],
|
|
)
|
|
return False, [], "No fanout", []
|
|
|
|
# Multiple measure sources. Only merge groups that are provably row-safe
|
|
# (alias siblings or pure one_to_one chains). many_to_one chains are not
|
|
# safe to flatten because the "one" side measure is duplicated by the
|
|
# "many" side rows.
|
|
merged_groups = self._merge_safe_measure_groups(groups, dim_sources)
|
|
|
|
if len(merged_groups) <= 1:
|
|
# All measure sources are on the same safe join chain
|
|
if merged_groups:
|
|
mg_name, mg_measures = next(iter(merged_groups.items()))
|
|
# Still check if there's fanout to dimension sources
|
|
has_o2m = False
|
|
for dim_src in dim_sources:
|
|
if dim_src == mg_name:
|
|
continue
|
|
path = self.graph.find_path(mg_name, dim_src)
|
|
if path and path.has_one_to_many:
|
|
has_o2m = True
|
|
break
|
|
if has_o2m:
|
|
return (
|
|
True,
|
|
[MeasureGroup(source_name=mg_name, measures=mg_measures)],
|
|
f"Fanout detected: one_to_many edges from {mg_name} to dimensions",
|
|
[f"Pre-aggregate {mg_name} measures before joining"],
|
|
)
|
|
return False, [], "No fanout", []
|
|
|
|
# True chasm trap — independent measure sources that can't be safely merged.
|
|
# Before building groups, validate that all filter sources are reachable
|
|
# from at least one measure source without traversing one_to_many edges.
|
|
# If not, the filter would be silently dropped during CTE generation.
|
|
for filter_src in filter_sources - dim_sources:
|
|
reachable_from_any = False
|
|
for source_name in merged_groups:
|
|
if filter_src == source_name:
|
|
reachable_from_any = True
|
|
break
|
|
filter_actual = self.graph.alias_map.get(filter_src, filter_src)
|
|
source_actual = self.graph.alias_map.get(source_name, source_name)
|
|
if filter_actual == source_actual:
|
|
reachable_from_any = True
|
|
break
|
|
path = self.graph.find_path(source_name, filter_src)
|
|
if path and not path.has_one_to_many:
|
|
reachable_from_any = True
|
|
break
|
|
if not reachable_from_any:
|
|
raise ValueError(
|
|
f"Filter on '{filter_src}' is not reachable via many_to_one/one_to_one "
|
|
f"edges from any measure source ({', '.join(merged_groups.keys())}). "
|
|
f"The filter would be silently dropped in aggregate locality mode. "
|
|
f"Consider moving the filter condition into a SQL source or removing it."
|
|
)
|
|
|
|
measure_groups = []
|
|
locality_descs = []
|
|
for source_name, group_measures in merged_groups.items():
|
|
mg = MeasureGroup(source_name=source_name, measures=group_measures)
|
|
measure_groups.append(mg)
|
|
measure_names = ", ".join(m.name for m in group_measures)
|
|
locality_descs.append(
|
|
f"Pre-aggregate {source_name} ({measure_names}) by dimension keys"
|
|
)
|
|
|
|
return (
|
|
True,
|
|
measure_groups,
|
|
f"Chasm trap: {len(merged_groups)} independent measure sources ({', '.join(merged_groups.keys())})",
|
|
locality_descs,
|
|
)
|
|
|
|
def _merge_safe_measure_groups(
|
|
self,
|
|
groups: dict[str, list[ResolvedMeasure]],
|
|
dim_sources: set[str],
|
|
) -> dict[str, list[ResolvedMeasure]]:
|
|
"""Merge only row-safe measure groups.
|
|
|
|
Alias siblings are kept together to avoid false chasm detection for role-
|
|
based aliases, and pure one_to_one chains can be flattened safely.
|
|
many_to_one chains are intentionally not merged because measures from the
|
|
"one" side are duplicated by the "many" side rows.
|
|
"""
|
|
names = list(groups.keys())
|
|
|
|
# First pass: merge aliases of the same underlying source.
|
|
# Pick one representative per underlying source.
|
|
alias_groups: dict[str, list[str]] = {}
|
|
for name in names:
|
|
actual = self.graph.alias_map.get(name, name)
|
|
alias_groups.setdefault(actual, []).append(name)
|
|
|
|
merged: dict[str, list[ResolvedMeasure]] = {}
|
|
assigned: dict[str, str] = {} # source_name → merged anchor
|
|
|
|
# Merge alias siblings into the first alias name
|
|
for actual, siblings in alias_groups.items():
|
|
anchor = siblings[0]
|
|
merged[anchor] = []
|
|
for sib in siblings:
|
|
merged[anchor].extend(groups[sib])
|
|
assigned[sib] = anchor
|
|
|
|
def _edge_is_grain_safe(edge) -> bool:
|
|
if edge.relationship == "one_to_one":
|
|
return True
|
|
if edge.relationship != "many_to_one":
|
|
return False
|
|
actual_source = self.graph.alias_map.get(edge.from_source, edge.from_source)
|
|
source = self.sources.get(actual_source)
|
|
if not source:
|
|
return False
|
|
from_cols = {c.strip() for c in edge.from_column.split(",")}
|
|
grain_cols = {c.strip() for c in source.grain}
|
|
return from_cols == grain_cols
|
|
|
|
def _path_is_grain_safe(path) -> bool:
|
|
return bool(path) and all(_edge_is_grain_safe(edge) for edge in path.edges)
|
|
|
|
# Second pass: check pairwise one_to_one reachability between merged groups
|
|
merged_names = list(merged.keys())
|
|
final: dict[str, list[ResolvedMeasure]] = {}
|
|
final_assigned: dict[str, str] = {}
|
|
|
|
for name in merged_names:
|
|
if name in final_assigned:
|
|
continue
|
|
final[name] = list(merged[name])
|
|
final_assigned[name] = name
|
|
|
|
for other in merged_names:
|
|
if other == name or other in final_assigned:
|
|
continue
|
|
path_fwd = self.graph.find_path(name, other)
|
|
path_rev = self.graph.find_path(other, name)
|
|
if _path_is_grain_safe(path_fwd):
|
|
final[name].extend(merged[other])
|
|
final_assigned[other] = name
|
|
elif _path_is_grain_safe(path_rev):
|
|
final.setdefault(other, []).extend(final.pop(name, []))
|
|
final[other].extend(merged[other])
|
|
for k, v in final_assigned.items():
|
|
if v == name:
|
|
final_assigned[k] = other
|
|
final_assigned[name] = other
|
|
final_assigned[other] = other
|
|
break
|
|
|
|
return final
|
|
|
|
def _classify_filter_clause(
|
|
self,
|
|
clause: str,
|
|
measure_names: set[str],
|
|
predefined_refs: set[str],
|
|
) -> str:
|
|
"""Classify a single filter clause as 'where' or 'having'."""
|
|
parsed = self.parser.parse(clause, known_measure_names=measure_names)
|
|
if parsed.is_aggregate or parsed.depends_on_measures:
|
|
return "having"
|
|
if parsed.column_refs & predefined_refs:
|
|
matching_refs = parsed.column_refs & predefined_refs
|
|
all_are_columns = True
|
|
for ref in matching_refs:
|
|
src_name, col_name = ref.split(".", 1)
|
|
src = self.sources.get(src_name)
|
|
if not src or not any(c.name == col_name for c in src.columns):
|
|
all_are_columns = False
|
|
break
|
|
return "where" if all_are_columns else "having"
|
|
return "where"
|
|
|
|
def _classify_filters(
|
|
self, filters: list[str], measures: list[ResolvedMeasure]
|
|
) -> tuple[list[str], list[str]]:
|
|
measure_names = {m.name for m in measures}
|
|
where_filters = []
|
|
having_filters = []
|
|
|
|
# Build set of qualified pre-defined measure refs (e.g. "orders.revenue")
|
|
predefined_refs: set[str] = set()
|
|
for src in self.sources.values():
|
|
for mdef in src.measures:
|
|
predefined_refs.add(f"{src.name}.{mdef.name}")
|
|
|
|
for f in filters:
|
|
if not f or not f.strip():
|
|
continue
|
|
# Split compound AND expressions so each clause is classified independently.
|
|
# e.g. "sum(x) > 100 AND status = 'active'" → HAVING + WHERE
|
|
clauses = self._split_top_level_and(f)
|
|
for clause in clauses:
|
|
kind = self._classify_filter_clause(
|
|
clause, measure_names, predefined_refs
|
|
)
|
|
if kind == "having":
|
|
# Validate: if an OR expression mixes aggregate and non-aggregate
|
|
# sub-expressions, it cannot be split and would produce invalid SQL.
|
|
self._validate_or_filter_consistency(
|
|
clause, measure_names, predefined_refs
|
|
)
|
|
having_filters.append(clause)
|
|
else:
|
|
where_filters.append(clause)
|
|
|
|
return where_filters, having_filters
|
|
|
|
def _validate_or_filter_consistency(
|
|
self,
|
|
clause: str,
|
|
measure_names: set[str],
|
|
predefined_refs: set[str],
|
|
) -> None:
|
|
"""Raise an error if an OR expression mixes WHERE and HAVING conditions."""
|
|
try:
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT * WHERE {quote_reserved_identifiers(clause)}",
|
|
dialect=self.dialect,
|
|
)
|
|
where = tree.find(exp.Where)
|
|
if not where:
|
|
return
|
|
inner = where.this
|
|
# Only check if the top level contains OR
|
|
or_parts: list[str] = []
|
|
|
|
def _collect_or(node):
|
|
if isinstance(node, exp.Or):
|
|
_collect_or(node.left)
|
|
_collect_or(node.right)
|
|
else:
|
|
or_parts.append(node.sql(dialect=self.dialect))
|
|
|
|
_collect_or(inner)
|
|
if len(or_parts) <= 1:
|
|
return
|
|
# Classify each OR branch independently
|
|
kinds = set()
|
|
for part in or_parts:
|
|
kinds.add(
|
|
self._classify_filter_clause(part, measure_names, predefined_refs)
|
|
)
|
|
if kinds == {"where", "having"}:
|
|
raise ValueError(
|
|
f"Filter '{clause}' mixes aggregate and non-aggregate conditions "
|
|
f"with OR, which cannot be split into WHERE and HAVING. "
|
|
f"Rewrite as separate filters or use a subquery."
|
|
)
|
|
except ValueError:
|
|
raise
|
|
except Exception:
|
|
logger.debug("Failed to validate OR filter consistency for: %s", clause)
|
|
|
|
def _split_top_level_and(self, expr: str) -> list[str]:
|
|
"""Split a filter expression on top-level AND (not inside parentheses or strings)."""
|
|
try:
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT * WHERE {quote_reserved_identifiers(expr)}",
|
|
dialect=self.dialect,
|
|
)
|
|
where = tree.find(exp.Where)
|
|
if not where:
|
|
return [expr]
|
|
inner = where.this
|
|
parts: list[str] = []
|
|
|
|
def _collect_and(node):
|
|
if isinstance(node, exp.And):
|
|
_collect_and(node.left)
|
|
_collect_and(node.right)
|
|
else:
|
|
parts.append(node.sql(dialect=self.dialect))
|
|
|
|
_collect_and(inner)
|
|
if len(parts) > 1:
|
|
return parts
|
|
except Exception:
|
|
logger.debug("Failed to split top-level AND in filter: %s", expr)
|
|
return [expr]
|
|
|
|
def _validate_column_refs(
|
|
self,
|
|
measures: list[ResolvedMeasure],
|
|
dimensions: list[QueryDimension],
|
|
filters: list[str],
|
|
) -> None:
|
|
"""Validate that referenced columns exist in their source definitions."""
|
|
# Build separate column and measure name sets per source
|
|
valid_cols: dict[str, set[str]] = {}
|
|
valid_measure_names: dict[str, set[str]] = {}
|
|
for src in self.sources.values():
|
|
valid_cols[src.name] = {c.name for c in src.columns}
|
|
valid_measure_names[src.name] = {m.name for m in src.measures}
|
|
|
|
def _check_refs(expr: str, allow_measures: bool) -> None:
|
|
parsed = self.parser.parse(expr)
|
|
for col_ref in parsed.column_refs:
|
|
parts = col_ref.split(".", 1)
|
|
if len(parts) != 2:
|
|
continue
|
|
source_name, col_name = parts
|
|
resolved = self.graph.alias_map.get(source_name, source_name)
|
|
if resolved not in valid_cols:
|
|
continue # unknown source — handled elsewhere
|
|
if allow_measures:
|
|
all_valid = valid_cols[resolved] | valid_measure_names.get(
|
|
resolved, set()
|
|
)
|
|
else:
|
|
all_valid = valid_cols[resolved]
|
|
if col_name not in all_valid:
|
|
available = sorted(
|
|
valid_cols[resolved] | valid_measure_names.get(resolved, set())
|
|
)
|
|
raise ValueError(
|
|
f"Column '{col_name}' does not exist in source '{source_name}'. "
|
|
f"Available: {', '.join(available)}"
|
|
)
|
|
|
|
# Dimension refs: only columns allowed (not measure names)
|
|
for d in dimensions:
|
|
_check_refs(d.field, allow_measures=False)
|
|
|
|
# Measure/filter refs: columns + measure names allowed
|
|
for m in measures:
|
|
if not m.is_derived:
|
|
_check_refs(m.expr, allow_measures=True)
|
|
for f in filters:
|
|
if f and f.strip():
|
|
_check_refs(f, allow_measures=True)
|
|
|
|
def _validate_visibility(self, query: SemanticQuery) -> None:
|
|
"""Reject queries that reference hidden columns."""
|
|
# Build a set of hidden columns: {source_name: {col_name, ...}}
|
|
hidden: dict[str, set[str]] = {}
|
|
for source in self.sources.values():
|
|
for col in source.columns:
|
|
if col.visibility == ColumnVisibility.HIDDEN:
|
|
hidden.setdefault(source.name, set()).add(col.name)
|
|
|
|
if not hidden:
|
|
return
|
|
|
|
# Collect all source.column references from dimensions, measures, filters
|
|
all_exprs: list[str] = []
|
|
for d in query.dimensions:
|
|
if isinstance(d, str):
|
|
all_exprs.append(d)
|
|
elif isinstance(d, dict):
|
|
all_exprs.append(d.get("field", ""))
|
|
for m in query.measures:
|
|
if isinstance(m, str):
|
|
all_exprs.append(m)
|
|
elif isinstance(m, dict):
|
|
all_exprs.append(m.get("expr", ""))
|
|
all_exprs.extend(query.filters)
|
|
|
|
for expr in all_exprs:
|
|
parsed = self.parser.parse(expr)
|
|
for col_ref in parsed.column_refs:
|
|
source_name, col_name = col_ref.split(".", 1)
|
|
resolved = self.graph.alias_map.get(source_name, source_name)
|
|
if resolved in hidden and col_name in hidden[resolved]:
|
|
raise ValueError(
|
|
f"Column '{source_name}.{col_name}' is hidden and cannot be queried"
|
|
)
|
|
|
|
def _build_columns(
|
|
self,
|
|
measures: list[ResolvedMeasure],
|
|
dimensions: list[QueryDimension],
|
|
) -> list[ResolvedColumn]:
|
|
from collections import Counter
|
|
|
|
columns: list[ResolvedColumn] = []
|
|
|
|
leaves = [
|
|
d.field.split(".")[-1] if "." in d.field else d.field for d in dimensions
|
|
]
|
|
colliding = {leaf for leaf, count in Counter(leaves).items() if count > 1}
|
|
|
|
for d in dimensions:
|
|
leaf = d.field.split(".")[-1] if "." in d.field else d.field
|
|
col_name = d.field.replace(".", "_") if leaf in colliding else leaf
|
|
columns.append(
|
|
ResolvedColumn(
|
|
name=col_name,
|
|
provenance=Provenance.DIMENSION,
|
|
expr=d.field,
|
|
granularity=d.granularity,
|
|
)
|
|
)
|
|
|
|
for m in measures:
|
|
columns.append(
|
|
ResolvedColumn(
|
|
name=m.name,
|
|
provenance=m.provenance,
|
|
expr=m.expr,
|
|
description=getattr(m, "description", None),
|
|
)
|
|
)
|
|
|
|
return columns
|