mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
* feat(context): add warehouse dialect dispatch * feat(context): read warehouse scan catalog * feat(context): add entity details verification tool * feat(context): add ingest SQL verification tool * feat(context): add raw warehouse discovery tool * feat(context): expose warehouse verification tools to ingest * docs(context): add ingest identifier verification protocol * test(context): guard ingest identifier verification prompts * chore(context): verify warehouse verification tools * docs: add warehouse verification tools plan and spec * fix(context): expose target warehouses to Notion ingest * fix(context): update ingest prompts for warehouse verification tools * fix(context): scope raw schema discovery to allowed connections * fix(context): verify warehouse column display targets * docs: add notion warehouse verification gap closure plan * fix(context): include raw discovery connection names * fix(context): expose warehouse targets for LookML and MetricFlow * fix(context): pass connection config to ingest query executors * fix(cli): enable read-only SQL probes for local ingest * docs: add warehouse verification final v1 closure plan * fix(context): align warehouse sql probe prompt shape * docs: add warehouse verification prompt shape closure plan * test(context): catch connectionless sql execution prompt examples * fix(context): include connection name in sl capture sql example * docs: add warehouse verification sql example closure plan * fix(context): report structured entity detail misses * docs: add warehouse verification structured target miss closure plan * fix: report untracked squash merge conflicts * feat: require ingest verification ledger * fix: stabilize ingest wiki references
1425 lines
58 KiB
Python
1425 lines
58 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections import Counter
|
|
|
|
import sqlglot
|
|
from sqlglot import exp
|
|
|
|
from semantic_layer.graph import RELATIONSHIP_INVERSE
|
|
from semantic_layer.models import (
|
|
MeasureGroup,
|
|
QueryDimension,
|
|
ResolvedJoin,
|
|
ResolvedMeasure,
|
|
ResolvedPlan,
|
|
SourceDefinition,
|
|
)
|
|
from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers
|
|
|
|
# DIALECT CONVENTION:
|
|
# User-authored SQL fragments (measure `expr`, segment `expr`, filter,
|
|
# computed-column `expr`, `sql:` source bodies, join `on:` clauses) must
|
|
# be parsed with `read=self.dialect`. The `sl_capture` skill instructs
|
|
# authors to write in the connection's native dialect; parsing as postgres
|
|
# silently drops dialect-specific tokens (e.g. BigQuery `INTERVAL 30 DAY`).
|
|
# Source CTE bodies stay verbatim; the outer scaffold is written in
|
|
# postgres-compatible form with dialect-specific helpers where needed
|
|
# (see `_time_trunc`), and `_transpile()` round-trips it through
|
|
# `self.dialect` so embedded user exprs survive intact.
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _qi(name: str) -> str:
    """Return ``name`` double-quoted when it is a reserved SQL word.

    The lookup against the parser's reserved-word set is done on the
    lower-cased name; any other identifier is returned unchanged.
    """
    from semantic_layer.parser import _SQL_RESERVED

    is_reserved = name.lower() in _SQL_RESERVED
    return f'"{name}"' if is_reserved else name
|
|
|
|
|
|
def _build_on_clause(
|
|
from_source: str, from_column: str, to_source: str, to_column: str
|
|
) -> str:
|
|
"""Build ON clause supporting composite keys (comma-separated columns)."""
|
|
from_cols = [c.strip() for c in from_column.split(",")]
|
|
to_cols = [c.strip() for c in to_column.split(",")]
|
|
conditions = [
|
|
f"{_qi(from_source)}.{_qi(fc)} = {_qi(to_source)}.{_qi(tc)}"
|
|
for fc, tc in zip(from_cols, to_cols)
|
|
]
|
|
return " AND ".join(conditions)
|
|
|
|
|
|
class SqlGenerator:
|
|
def __init__(
    self, dialect: str = "postgres", alias_map: dict[str, str] | None = None
):
    """Create a generator for ``dialect``.

    Args:
        dialect: Target SQL dialect name. Anything other than
            ``"postgres"`` is validated against sqlglot's dialect registry.
        alias_map: Optional mapping consulted by ``_build_join`` to resolve
            a join target alias to its underlying source name. ``None`` (or
            an empty mapping) results in a fresh empty dict.

    Raises:
        ValueError: if sqlglot does not know ``dialect``.
    """
    if dialect != "postgres":
        from sqlglot import Dialect

        try:
            Dialect.get_or_raise(dialect)
        except ValueError as exc:
            # Chain explicitly so the traceback shows sqlglot's error as the
            # cause rather than as an unrelated failure during handling.
            raise ValueError(
                f"Unknown SQL dialect '{dialect}'. Use a dialect supported by sqlglot "
                "(e.g., postgres, bigquery, snowflake, mysql, duckdb)."
            ) from exc
    self.dialect = dialect
    self._parser = ExpressionParser(dialect=dialect)
    self._alias_map: dict[str, str] = alias_map or {}
|
|
|
|
def generate(self, plan: ResolvedPlan, sources: dict[str, SourceDefinition]) -> str:
    """Render the final SQL string for ``plan``.

    CTEs for SQL-based sources are built first and kept out of the
    transpile step. The outer query comes from the locality path when the
    plan has fan-out and measure groups, otherwise from the simple path,
    and is passed through ``_transpile``. The source CTEs are then
    prepended, merging into the outer query's own WITH clause if it has one.
    """
    source_ctes = self._build_source_ctes(plan, sources)

    needs_locality = plan.has_fan_out and plan.measure_groups
    if needs_locality:
        scaffold = self._generate_with_locality(plan, sources)
    else:
        scaffold = self._generate_simple(plan, sources)
    transpiled = self._transpile(scaffold)

    if not source_ctes:
        return transpiled

    header = "WITH " + ",\n".join(source_ctes)
    body = transpiled.lstrip()
    if body[:5].upper() != "WITH ":
        # Outer query has no WITH clause of its own: just put ours in front.
        return header + "\n" + transpiled
    # Outer query already opens with WITH (e.g. locality CTEs): splice the
    # source CTEs into that same clause instead of emitting a second WITH.
    return header + ",\n" + body[5:].lstrip()
|
|
|
|
# ── Path A: Simple (no fan-out) ────────────────────────────────────
|
|
|
|
def _generate_simple(
    self, plan: ResolvedPlan, sources: dict[str, SourceDefinition]
) -> str:
    """Build the single-SELECT query used when no fan-out handling is needed.

    Clauses are emitted in the order SELECT, FROM, JOIN..., WHERE, GROUP BY,
    HAVING, ORDER BY, LIMIT and joined with newlines. A plan with dimensions
    but no non-derived measure is rendered as ``SELECT DISTINCT`` without a
    GROUP BY. When the plan has no explicit ordering, rows are ordered by
    the dimension columns' ordinal positions.
    """
    # Only non-derived measures count as "real" aggregates here.
    aggregating = any(not m.is_derived for m in plan.measures)
    column_list = ",\n ".join(self._build_select_columns(plan, sources))
    select_kw = (
        "SELECT DISTINCT" if (not aggregating and plan.dimensions) else "SELECT"
    )
    clauses: list[str] = [f"{select_kw}\n {column_list}"]

    # FROM
    if plan.anchor_source:
        clauses.append(f"FROM {self._source_ref(plan.anchor_source, sources)}")

    # JOINs
    clauses.extend(self._build_join(j, sources, plan) for j in plan.joins)

    # WHERE
    if plan.where_filters:
        predicates = [
            self._qualify_filter(flt, sources, plan) for flt in plan.where_filters
        ]
        clauses.append("WHERE " + " AND ".join(predicates))

    # GROUP BY — skipped for dimension-only queries, where DISTINCT dedups.
    group_exprs = self._build_group_by_exprs(plan, sources)
    if aggregating and group_exprs:
        clauses.append("GROUP BY " + ", ".join(group_exprs))

    # HAVING — measure references are expanded to aggregate expressions.
    if plan.having_filters:
        having = [
            self._expand_having_filter(flt, plan, sources)
            for flt in plan.having_filters
        ]
        clauses.append("HAVING " + " AND ".join(having))

    # ORDER BY
    if plan.order_by:
        ordering = []
        for item in plan.order_by:
            column = self._resolve_order_field(item.field, plan)
            # Ascending is the SQL default, so it is left implicit.
            suffix = "" if item.direction.lower() == "asc" else item.direction.upper()
            ordering.append(f"{column} {suffix}".strip())
        clauses.append("ORDER BY " + ", ".join(ordering))
    elif group_exprs:
        positions = map(str, range(1, len(group_exprs) + 1))
        clauses.append("ORDER BY " + ", ".join(positions))

    # LIMIT
    if plan.limit is not None:
        clauses.append(f"LIMIT {plan.limit}")

    return "\n".join(clauses)
|
|
|
|
# ── Path B: Aggregate locality ─────────────────────────────────────
|
|
|
|
def _generate_with_locality(
    self, plan: ResolvedPlan, sources: dict[str, SourceDefinition]
) -> str:
    """Build the outer query for the aggregate-locality (fan-out) path.

    One pre-aggregation CTE is produced per measure group (via
    ``_build_agg_cte``), grouped by the dimensions that group can reach
    through the safe adjacency map. The CTEs are then joined on the
    dimension aliases shared by all of them (``FULL JOIN`` when
    ``plan.include_empty`` is set, otherwise ``JOIN``; ``CROSS JOIN`` when
    no dimension is shared). HAVING filters are rendered as a WHERE on the
    outer query.

    Returns:
        The SQL text: locality CTEs, final SELECT, joins, filters, ordering
        and limit, separated by newlines.

    Raises:
        ValueError: if some dimension cannot be reached from any measure
            group, or if the measure groups end up with different
            dimension sets (asymmetric grain).
    """
    parts: list[str] = []
    # Only locality CTEs — source CTEs are concatenated by generate() so
    # the native-dialect source body never reaches the postgres transpile.
    cte_parts: list[str] = []

    # All dimension key expressions
    all_dim_keys = self._build_dim_key_exprs(plan, sources)

    # Compute per-CTE reachable dimensions
    safe_adj = self._build_safe_adjacency(plan)
    # NOTE(review): keyed by source name, so two measure groups sharing a
    # source_name would overwrite each other here — confirm the resolver
    # guarantees one group per source.
    per_cte_dims: dict[str, list[dict]] = {}
    for group in plan.measure_groups:
        reachable = []
        for dk in all_dim_keys:
            dim_sources = self._parser.extract_source_refs(dk["expr"])
            can_reach = True
            for ds in dim_sources:
                # The group's own source is trivially reachable.
                if ds == group.source_name:
                    continue
                path = self._find_join_path_steps(group.source_name, ds, safe_adj)
                if not path and ds != group.source_name:
                    can_reach = False
                    break
            if can_reach:
                reachable.append(dk)
        per_cte_dims[group.source_name] = reachable

    # Validate: every dimension must be reachable from at least one CTE
    for dk in all_dim_keys:
        reachable_from_any = any(
            any(d["alias"] == dk["alias"] for d in per_cte_dims[g.source_name])
            for g in plan.measure_groups
        )
        if not reachable_from_any:
            dim_sources = self._parser.extract_source_refs(dk["expr"])
            source_names = [g.source_name for g in plan.measure_groups]
            raise ValueError(
                f"Aggregate locality cannot safely reach '{', '.join(dim_sources)}' from "
                f"any measure source ({', '.join(source_names)}) without traversing one_to_many edges"
            )

    # Shared dimensions: reachable from ALL CTEs (used for JOIN condition)
    shared_dim_aliases = None
    for group in plan.measure_groups:
        aliases = {dk["alias"] for dk in per_cte_dims[group.source_name]}
        if shared_dim_aliases is None:
            shared_dim_aliases = aliases
        else:
            shared_dim_aliases &= aliases
    # Normalise "no groups" (None) and "nothing shared" to an empty set.
    shared_dim_aliases = shared_dim_aliases or set()
    shared_dims = [dk for dk in all_dim_keys if dk["alias"] in shared_dim_aliases]

    # Validate grain consistency: asymmetric dims cause FULL JOIN fan-out
    if len(plan.measure_groups) > 1:
        for group in plan.measure_groups:
            cte_dim_aliases = {
                dk["alias"] for dk in per_cte_dims[group.source_name]
            }
            # Dimensions this CTE groups by that are not common to all CTEs.
            non_shared = cte_dim_aliases - shared_dim_aliases
            if non_shared:
                for other_group in plan.measure_groups:
                    if other_group.source_name == group.source_name:
                        continue
                    other_aliases = {
                        dk["alias"] for dk in per_cte_dims[other_group.source_name]
                    }
                    missing_from_other = non_shared - other_aliases
                    if missing_from_other:
                        raise ValueError(
                            f"Asymmetric dimension grain in chasm trap: "
                            f"'{group.source_name}' groups by {sorted(cte_dim_aliases)} "
                            f"but '{other_group.source_name}' cannot reach "
                            f"{sorted(missing_from_other)}. "
                            f"FULL JOIN on shared dimensions ({sorted(shared_dim_aliases)}) "
                            f"would fan out '{other_group.source_name}' measures across "
                            f"the extra dimensions, producing incorrect results. "
                            f"Remove the asymmetric dimensions or query each measure "
                            f"source separately."
                        )

    # Collect all names that could collide with CTE aliases
    reserved_names: set[str] = set(sources.keys())
    # NOTE(review): every name added below was looked up in ``sources`` and
    # is therefore already in ``reserved_names``; this loop appears to be a
    # no-op as written.
    for name in plan.sources_used:
        src = sources.get(name)
        if src and src.is_sql_source:
            reserved_names.add(name)
    assigned_aliases: set[str] = set()

    # Pre-aggregation CTEs for each measure group
    cte_aliases: list[str] = []
    for group in plan.measure_groups:
        alias = f"{group.source_name}_agg"
        # Resolve collisions with existing source/CTE names
        if alias in reserved_names or alias in assigned_aliases:
            # Try <source>_agg_1, _agg_2, ... until a free name is found.
            suffix = 1
            while (
                f"{group.source_name}_agg_{suffix}" in reserved_names
                or f"{group.source_name}_agg_{suffix}" in assigned_aliases
            ):
                suffix += 1
            alias = f"{group.source_name}_agg_{suffix}"
        cte_aliases.append(alias)
        assigned_aliases.add(alias)
        cte_dim_keys = per_cte_dims[group.source_name]
        cte_sql = self._build_agg_cte(group, plan, sources, cte_dim_keys)
        cte_parts.append(f"{alias} AS (\n{cte_sql}\n)")

    if cte_parts:
        parts.append("WITH " + ",\n".join(cte_parts))

    # Final SELECT combining CTEs
    select_cols, derived_inline_map = self._build_locality_select(
        plan, cte_aliases, all_dim_keys, per_cte_dims
    )
    parts.append("SELECT\n " + ",\n ".join(select_cols))

    # FROM + JOINs between CTEs
    cte_join_type = "FULL JOIN" if plan.include_empty else "JOIN"
    if cte_aliases:
        parts.append(f"FROM {cte_aliases[0]}")
        for i, alias in enumerate(cte_aliases[1:], 1):
            join_conditions = []
            for dk in shared_dims:
                if i == 1:
                    # Second CTE joins directly against the first.
                    lhs = f"{cte_aliases[0]}.{dk['alias']}"
                else:
                    # Later CTEs join against the first non-NULL key among
                    # all CTEs joined so far.
                    coalesce_args = ", ".join(
                        f"{cte_aliases[j]}.{dk['alias']}" for j in range(i)
                    )
                    lhs = f"COALESCE({coalesce_args})"
                join_conditions.append(f"{lhs} = {alias}.{dk['alias']}")
            if join_conditions:
                parts.append(
                    f"{cte_join_type} {alias} ON " + " AND ".join(join_conditions)
                )
            else:
                # No shared dimension to join on.
                parts.append(f"CROSS JOIN {alias}")

    # HAVING filters applied as WHERE on outer query (no GROUP BY at this level)
    if plan.having_filters:
        # Map each non-derived measure name to the CTE that computes it.
        measure_cte_map: dict[str, str] = {}
        for i, plan_group in enumerate(plan.measure_groups):
            for m in plan_group.measures:
                measure_cte_map[m.name] = cte_aliases[i]

        having_clauses = []
        for f in plan.having_filters:
            resolved_f = self._resolve_having_for_locality(
                f, plan, measure_cte_map, derived_inline_map
            )
            having_clauses.append(resolved_f)
        parts.append("WHERE " + " AND ".join(having_clauses))

    # ORDER BY
    if plan.order_by:
        order_parts = []
        for ob in plan.order_by:
            field = self._resolve_order_field(ob.field, plan)
            # Ascending is left implicit; any other direction is upper-cased.
            direction = (
                ob.direction.upper() if ob.direction.lower() != "asc" else ""
            )
            order_parts.append(f"{field} {direction}".strip())
        parts.append("ORDER BY " + ", ".join(order_parts))
    elif all_dim_keys:
        # Default ordering: by the dimension columns' ordinal positions.
        parts.append(
            "ORDER BY " + ", ".join(str(i) for i in range(1, len(all_dim_keys) + 1))
        )

    # LIMIT
    if plan.limit is not None:
        parts.append(f"LIMIT {plan.limit}")

    return "\n".join(parts)
|
|
|
|
def _build_agg_cte(
    self,
    group: MeasureGroup,
    plan: ResolvedPlan,
    sources: dict[str, SourceDefinition],
    dim_keys: list[dict],
) -> str:
    """Render the body of the pre-aggregation CTE for one measure group.

    The CTE selects ``dim_keys`` plus the group's aggregated measures from
    the group's source, joins whatever other sources are required, applies
    only those WHERE filters whose sources are all present in the CTE, and
    groups by the dimension expressions. HAVING filters are deliberately
    left to the outer query.
    """
    # Projection: grouping keys first, then one aliased aggregate per measure.
    projection = [f"{key['expr']} AS {key['alias']}" for key in dim_keys]
    projection.extend(
        f"{self._build_measure_expr(measure, sources)} AS {measure.name}"
        for measure in group.measures
    )
    lines: list[str] = [" SELECT\n " + ",\n ".join(projection)]

    # FROM the measure's own source.
    lines.append(f" FROM {self._source_ref(group.source_name, sources)}")

    # JOIN every other source needed, skipping ones already joined.
    present = {group.source_name}
    wanted = self._collect_cte_target_sources(group, plan, dim_keys)
    steps = self._build_group_join_steps(group.source_name, wanted, plan)
    for step, reached in steps:
        if reached in present:
            continue
        reached_ref = self._source_ref(reached, sources)
        condition = _build_on_clause(
            step.from_source, step.from_column, step.to_source, step.to_column
        )
        lines.append(f" JOIN {reached_ref} ON {condition}")
        present.add(reached)

    # WHERE — push down only filters that reference sources inside this CTE
    # (or no source at all).
    if plan.where_filters:

        def _is_local(flt) -> bool:
            refs = self._parser.extract_source_refs(flt)
            return not refs or refs <= present

        local_filters = [flt for flt in plan.where_filters if _is_local(flt)]
        if local_filters:
            qualified = [
                self._qualify_filter(flt, sources, plan) for flt in local_filters
            ]
            lines.append(" WHERE " + " AND ".join(qualified))

    # GROUP BY the dimension key expressions.
    if dim_keys:
        grouping = ", ".join(key["expr"] for key in dim_keys)
        lines.append(f" GROUP BY {grouping}")

    return "\n".join(lines)
|
|
|
|
def _build_dim_key_exprs(
    self, plan: ResolvedPlan, sources: dict[str, SourceDefinition]
) -> list[dict]:
    """Return one ``{"expr", "alias"}`` dict per plan dimension.

    ``expr`` is the dimension's SQL expression and ``alias`` its output
    column name; aliases take leaf-name collisions between dimensions into
    account.
    """
    clashing = self._colliding_dim_leaves(plan.dimensions)
    return [
        {
            "expr": self._dim_expr(dimension, sources),
            "alias": self._dimension_alias(dimension, clashing),
        }
        for dimension in plan.dimensions
    ]
|
|
|
|
def _build_locality_select(
|
|
self,
|
|
plan: ResolvedPlan,
|
|
cte_aliases: list[str],
|
|
dim_keys: list[dict],
|
|
per_cte_dims: dict[str, list[dict]] | None = None,
|
|
) -> list[str]:
|
|
"""Build SELECT columns for the final query combining pre-aggregated CTEs."""
|
|
cols: list[str] = []
|
|
|
|
# Build mapping from CTE alias to source name
|
|
cte_source_names = [g.source_name for g in plan.measure_groups]
|
|
|
|
# Dimensions: COALESCE across CTEs that have the dim
|
|
for dk in dim_keys:
|
|
if per_cte_dims:
|
|
available_ctes = [
|
|
alias
|
|
for alias, src_name in zip(cte_aliases, cte_source_names)
|
|
if any(
|
|
d["alias"] == dk["alias"]
|
|
for d in per_cte_dims.get(src_name, [])
|
|
)
|
|
]
|
|
else:
|
|
available_ctes = cte_aliases
|
|
|
|
if len(available_ctes) > 1:
|
|
coalesce_args = ", ".join(f"{a}.{dk['alias']}" for a in available_ctes)
|
|
cols.append(f"COALESCE({coalesce_args}) AS {dk['alias']}")
|
|
elif len(available_ctes) == 1:
|
|
cols.append(f"{available_ctes[0]}.{dk['alias']} AS {dk['alias']}")
|
|
else:
|
|
cols.append(f"NULL AS {dk['alias']}")
|
|
|
|
# Non-derived measures from CTEs
|
|
measure_cte_map: dict[str, str] = {}
|
|
for i, plan_group in enumerate(plan.measure_groups):
|
|
alias = cte_aliases[i]
|
|
for m in plan_group.measures:
|
|
cols.append(f"{alias}.{m.name}")
|
|
measure_cte_map[m.name] = alias
|
|
|
|
# Derived measures — wrap cross-CTE refs in COALESCE for FULL JOIN NULL safety.
|
|
# Process in order (topological) so that derived-of-derived gets inlined.
|
|
derived_inline_map: dict[str, str] = {} # measure_name → fully inlined expr
|
|
for m in plan.measures:
|
|
if m.is_derived:
|
|
# Collect all transitive CTE aliases used by this derived measure
|
|
dep_ctes: set[str | None] = set()
|
|
for d in m.depends_on:
|
|
if d in measure_cte_map:
|
|
dep_ctes.add(measure_cte_map.get(d))
|
|
elif d in derived_inline_map:
|
|
# Derived dep — inherit its CTE references
|
|
# (cross-CTE if the inlined dep already has multiple CTEs)
|
|
dep_ctes.add("__derived__")
|
|
use_coalesce = len(dep_ctes - {None}) > 1
|
|
# Detect which deps are used as divisors
|
|
divisor_deps = self._find_divisor_deps(m.expr, m.depends_on)
|
|
replacements = {}
|
|
for dep_name in m.depends_on:
|
|
if dep_name in measure_cte_map:
|
|
ref = f"{measure_cte_map[dep_name]}.{dep_name}"
|
|
if use_coalesce:
|
|
if dep_name in divisor_deps:
|
|
ref = f"NULLIF(COALESCE({measure_cte_map[dep_name]}.{dep_name}, 0), 0)"
|
|
else:
|
|
ref = f"COALESCE({ref}, 0)"
|
|
replacements[dep_name] = ref
|
|
elif dep_name in derived_inline_map:
|
|
# Derived dependency — inline its already-resolved expression
|
|
ref = derived_inline_map[dep_name]
|
|
if dep_name in divisor_deps:
|
|
ref = f"NULLIF({ref}, 0)"
|
|
replacements[dep_name] = ref
|
|
expr = self._substitute_measure_refs(m.expr, replacements)
|
|
derived_inline_map[m.name] = expr
|
|
cols.append(f"{expr} AS {m.name}")
|
|
|
|
return cols, derived_inline_map
|
|
|
|
def _find_divisor_deps(self, expr: str, depends_on: list[str]) -> set[str]:
    """Return the dependency names used directly as a divisor in ``expr``.

    Only an unqualified column sitting immediately on the right-hand side
    of a division is recognised. If the expression cannot be parsed, the
    failure is logged at debug level and whatever was found so far
    (normally nothing) is returned.
    """
    found: set[str] = set()
    try:
        parsed = sqlglot.parse_one(f"SELECT {expr}", read=self.dialect)
        for division in parsed.find_all(exp.Div):
            denominator = division.right
            if not isinstance(denominator, exp.Column) or denominator.table:
                continue
            if denominator.name in depends_on:
                found.add(denominator.name)
    except Exception:
        logger.debug("Failed to parse expression for divisor detection: %s", expr)
    return found
|
|
|
|
# ── Shared helpers ─────────────────────────────────────────────────
|
|
|
|
def _build_source_ctes(
|
|
self, plan: ResolvedPlan, sources: dict[str, SourceDefinition]
|
|
) -> list[str]:
|
|
"""Build CTEs for SQL-based sources, flattening inner WITH clauses."""
|
|
ctes = []
|
|
for name in plan.sources_used:
|
|
src = sources.get(name)
|
|
if src and src.is_sql_source and src.sql:
|
|
sql_text = src.sql.strip()
|
|
inner_ctes, final_select = self._extract_inner_ctes(sql_text)
|
|
if inner_ctes:
|
|
# Promote inner CTEs with prefixed names
|
|
renames: list[tuple[str, str]] = []
|
|
for inner_name, inner_body in inner_ctes:
|
|
prefixed = f"{name}__{inner_name}"
|
|
renames.append((inner_name, prefixed))
|
|
|
|
# Apply all renames to inner CTE bodies and final SELECT
|
|
promoted: list[tuple[str, str]] = []
|
|
for inner_name, inner_body in inner_ctes:
|
|
renamed_body = inner_body
|
|
for old, new in renames:
|
|
renamed_body = self._rename_table_ref(
|
|
renamed_body, old, new
|
|
)
|
|
prefixed = f"{name}__{inner_name}"
|
|
promoted.append((prefixed, renamed_body))
|
|
|
|
renamed_final = final_select
|
|
for old, new in renames:
|
|
renamed_final = self._rename_table_ref(renamed_final, old, new)
|
|
|
|
for prefixed, body in promoted:
|
|
ctes.append(f"{prefixed} AS (\n{body}\n)")
|
|
ctes.append(f"{name} AS (\n{renamed_final}\n)")
|
|
else:
|
|
ctes.append(f"{name} AS (\n{sql_text}\n)")
|
|
return ctes
|
|
|
|
def _extract_inner_ctes(self, sql_text: str) -> tuple[list[tuple[str, str]], str]:
    """Split source SQL into its CTEs and the remaining main query.

    Source SQL is user-provided in the target dialect, so it is parsed and
    serialized with ``self.dialect`` to avoid lossy cross-dialect conversion
    (e.g. Snowflake DATEDIFF → postgres AGE).

    Returns:
        ``(ctes, final_select)`` where ``ctes`` is a list of
        ``(name, body_sql)`` pairs and ``final_select`` is the query with
        its WITH clause removed. If there is no WITH clause, or parsing
        fails, returns ``([], sql_text)`` with the text untouched.

    NOTE(review): only CTE names and bodies are extracted; any modifier on
    the WITH clause itself (such as RECURSIVE) is not carried over by this
    code — confirm whether recursive sources need support.
    """
    try:
        parsed = sqlglot.parse_one(sql_text, read=self.dialect)
        with_node = parsed.find(exp.With)
        if not with_node:
            return [], sql_text

        extracted = [
            (cte.alias, cte.this.sql(dialect=self.dialect))
            for cte in with_node.expressions
        ]

        # Re-render the main query from a copy with the WITH clause removed.
        remainder = parsed.copy()
        remainder_with = remainder.find(exp.With)
        if remainder_with:
            remainder_with.pop()
        return extracted, remainder.sql(dialect=self.dialect)
    except Exception:
        logger.debug(
            "Failed to extract inner CTEs from SQL source, treating as raw SQL"
        )
        return [], sql_text
|
|
|
|
def _rename_table_ref(self, sql_text: str, old_name: str, new_name: str) -> str:
    """Point references to table ``old_name`` at ``new_name``.

    Parsing and serialization use ``self.dialect`` so dialect-specific
    constructs in user-provided source SQL survive. Only table references
    with no schema/db qualifier are rewritten. The renamed table keeps its
    existing alias, or is aliased back to ``old_name`` when it had none.

    If the AST-based rewrite fails for any reason, a whole-word regex
    substitution over the raw text is used instead.
    """
    try:
        parsed = sqlglot.parse_one(sql_text, read=self.dialect)

        def _swap(node):
            is_target = (
                isinstance(node, exp.Table)
                and node.name == old_name
                and not node.db
            )
            if not is_target:
                return node
            kept_alias = node.args.get("alias") or exp.TableAlias(
                this=exp.to_identifier(old_name)
            )
            return exp.Table(this=exp.to_identifier(new_name), alias=kept_alias)

        return parsed.transform(_swap).sql(dialect=self.dialect)
    except Exception:
        logger.debug(
            "AST-based table rename failed for '%s' -> '%s', falling back to regex",
            old_name,
            new_name,
        )
        import re

        return re.sub(rf"\b{re.escape(old_name)}\b", new_name, sql_text)
|
|
|
|
def _build_select_columns(
    self, plan: ResolvedPlan, sources: dict[str, SourceDefinition]
) -> list[str]:
    """Build the ``expr AS alias`` column list for the simple path.

    Dimensions come first, followed by measures in plan order. A derived
    measure has the names of its already-resolved dependencies replaced by
    their SQL expressions; dependencies not resolved yet are left as
    written.
    """
    clashing = self._colliding_dim_leaves(plan.dimensions)
    select_list: list[str] = [
        f"{self._dim_expr(d, sources)} AS {self._dimension_alias(d, clashing)}"
        for d in plan.dimensions
    ]

    # Measure name -> SQL expression, filled as measures are processed so
    # later derived measures can inline earlier ones.
    resolved: dict[str, str] = {}
    for measure in plan.measures:
        if measure.is_derived:
            known = {
                dep: resolved[dep] for dep in measure.depends_on if dep in resolved
            }
            sql_expr = self._substitute_measure_refs(measure.expr, known)
        else:
            sql_expr = self._build_measure_expr(measure, sources)
        select_list.append(f"{sql_expr} AS {measure.name}")
        resolved[measure.name] = sql_expr

    return select_list
|
|
|
|
def _build_measure_expr(
    self, m: ResolvedMeasure, sources: dict[str, SourceDefinition]
) -> str:
    """Return the SQL expression for one measure.

    The measure expression is qualified, then custom functions (median,
    percentile, count_distinct) are translated. If the measure carries a
    filter, the qualified filter is applied to the result.
    """
    translated = self._translate_custom_funcs(self._qualify_expr(m.expr, sources))
    if not m.filter:
        return translated
    condition = self._qualify_expr(m.filter, sources)
    return self._apply_measure_filter(translated, condition)
|
|
|
|
def _apply_measure_filter(self, expr: str, filter_sql: str) -> str:
    """Apply a measure-level filter by injecting CASE WHEN into each aggregate.

    Every aggregate function in ``expr`` has its argument rewritten to
    ``CASE WHEN <filter> THEN <arg> END``; ``COUNT(*)`` becomes
    ``COUNT(CASE WHEN <filter> THEN 1 END)`` and ``COUNT(DISTINCT x)`` keeps
    its DISTINCT around the CASE.

    If no aggregate was rewritten, or anything fails while parsing or
    rendering, the whole expression is wrapped instead:
    ``CASE WHEN <filter> THEN <expr> END``.

    Args:
        expr: The measure's SQL expression.
        filter_sql: The filter predicate, as SQL.

    Returns:
        The filtered SQL expression.
    """
    try:
        tree = sqlglot.parse_one(
            f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
        )
        select_expr = tree.expressions[0]
        if isinstance(select_expr, exp.Alias):
            select_expr = select_expr.this

        filter_cond = sqlglot.parse_one(
            f"SELECT {filter_sql}", read=self.dialect
        ).expressions[0]

        # Tracks whether the filter actually made it into some aggregate.
        # Comparing the rendered SQL with ``expr`` is not a reliable signal:
        # sqlglot may re-format an expression that contains no aggregate,
        # and the unfiltered text would then be returned as if filtered.
        injected = False

        def _make_case(inner_node):
            return exp.Case(
                ifs=[exp.If(this=filter_cond.copy(), true=inner_node.copy())]
            )

        def _inject_filter(node):
            """Inject the CASE WHEN filter into an aggregate's argument."""
            nonlocal injected
            if not isinstance(node, exp.AggFunc):
                return node
            if isinstance(node, exp.Count):
                count_arg = node.this
                if isinstance(count_arg, exp.Star):
                    # COUNT(*) has no argument to wrap; count 1 per match.
                    node.set("this", _make_case(exp.Literal.number(1)))
                    injected = True
                    return node
                if isinstance(count_arg, exp.Distinct) and count_arg.expressions:
                    inner = count_arg.expressions[0]
                    node.set("this", exp.Distinct(expressions=[_make_case(inner)]))
                    injected = True
                    return node
            if node.this is not None:
                node.set("this", _make_case(node.this))
                injected = True
            return node

        transformed = select_expr.transform(_inject_filter)
        if injected:
            return transformed.sql(dialect=self.dialect)
    except Exception:
        logger.debug(
            "Failed to inject filter into aggregates for measure: %s", expr
        )

    return f"CASE WHEN {filter_sql} THEN {expr} END"
|
|
|
|
def _translate_custom_funcs(self, expr: str) -> str:
    """Rewrite the custom functions median(), percentile() and count_distinct().

    ``median(x)`` and ``percentile(x, p)`` become
    ``PERCENTILE_CONT(...) WITHIN GROUP (ORDER BY x)``, and
    ``count_distinct(x)`` becomes ``COUNT(DISTINCT x)``. When none of these
    functions occur, ``expr`` is returned exactly as given.
    """
    parsed = sqlglot.parse_one(
        f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
    )

    has_median = next(parsed.find_all(exp.Median), None) is not None
    needs_rewrite = has_median or any(
        call.name.lower() in ("percentile", "count_distinct")
        for call in parsed.find_all(exp.Anonymous)
    )
    if not needs_rewrite:
        return expr

    def _fragment(sql: str):
        # Parse a standalone SQL expression in the generator's dialect.
        return sqlglot.parse_one(f"SELECT {sql}", read=self.dialect).expressions[0]

    def _rewrite(node):
        if isinstance(node, exp.Median):
            target = node.this.sql(dialect=self.dialect)
            return _fragment(
                f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {target})"
            )
        if not isinstance(node, exp.Anonymous):
            return node

        fn_name = node.name.lower()
        args = node.expressions
        if fn_name == "percentile" and len(args) >= 2:
            target = args[0].sql(dialect=self.dialect)
            fraction = args[1].sql(dialect=self.dialect)
            return _fragment(
                f"PERCENTILE_CONT({fraction}) WITHIN GROUP (ORDER BY {target})"
            )
        if fn_name == "count_distinct" and args:
            target = args[0].sql(dialect=self.dialect)
            return _fragment(f"COUNT(DISTINCT {target})")
        # Custom call with too few arguments, or some other function.
        return node

    rewritten = parsed.transform(_rewrite)
    return rewritten.expressions[0].sql(dialect=self.dialect)
|
|
|
|
def _extract_outer_aggregate(self, expr: str) -> tuple[str | None, str | None]:
    """Return ``(function_name, inner_sql)`` for the outermost aggregate.

    Handles both aggregate functions known to sqlglot and anonymous
    (unrecognised) function calls with at least one argument. For anything
    else ``(None, None)`` is returned.
    """
    parsed = sqlglot.parse_one(
        f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
    )
    outer = parsed.expressions[0]
    if isinstance(outer, exp.Alias):
        outer = outer.this

    if isinstance(outer, exp.AggFunc):
        return outer.sql_name(), outer.this.sql(dialect=self.dialect)
    if isinstance(outer, exp.Anonymous) and outer.expressions:
        return outer.name, outer.expressions[0].sql(dialect=self.dialect)
    return None, None
|
|
|
|
def _substitute_measure_refs(self, expr: str, replacements: dict[str, str]) -> str:
|
|
"""Replace bare measure name references in an expression using AST transform."""
|
|
if not replacements:
|
|
return expr
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
|
|
)
|
|
|
|
def _replace(node):
|
|
if (
|
|
isinstance(node, exp.Column)
|
|
and not node.table
|
|
and node.name in replacements
|
|
):
|
|
replacement_sql = replacements[node.name]
|
|
return sqlglot.parse_one(
|
|
f"SELECT {replacement_sql}", read=self.dialect
|
|
).expressions[0]
|
|
return node
|
|
|
|
transformed = tree.transform(_replace)
|
|
return transformed.expressions[0].sql(dialect=self.dialect)
|
|
|
|
def _build_join(
    self,
    join: ResolvedJoin,
    sources: dict[str, SourceDefinition],
    plan: ResolvedPlan,
) -> str:
    """Render one JOIN clause for the simple path.

    Uses ``LEFT JOIN`` when the plan asks to include empty rows, otherwise
    ``JOIN``. If the join target is an alias registered in the alias map,
    the underlying table (or source name) is joined ``AS`` that alias.
    """
    keyword = "LEFT JOIN" if plan.include_empty else "JOIN"
    target = join.to_source

    # An alias resolves to a different underlying source name.
    underlying = self._alias_map.get(target, target)
    if underlying == target:
        table_ref = self._source_ref(target, sources)
    else:
        definition = sources.get(underlying)
        has_table = definition and definition.is_table_source and definition.table
        physical = definition.table if has_table else underlying
        table_ref = f"{physical} AS {target}"

    condition = _build_on_clause(
        join.from_source, join.from_column, target, join.to_column
    )
    return f"{keyword} {table_ref} ON {condition}"
|
|
|
|
def _build_group_by_exprs(
    self, plan: ResolvedPlan, sources: dict[str, SourceDefinition]
) -> list[str]:
    """Return the SQL expression of every plan dimension, in plan order."""
    return [self._dim_expr(dimension, sources) for dimension in plan.dimensions]
|
|
|
|
# SQLite strftime format strings for time truncation, keyed by lower-case
# granularity name. Each format pins the components below the granularity
# to a constant (e.g. day-of-month "01" for "month").
# None entries require special date arithmetic (handled in _sqlite_time_trunc).
_SQLITE_STRFTIME: dict[str, str | None] = {
    "year": "%Y-01-01",
    "month": "%Y-%m-01",
    "day": "%Y-%m-%d",
    "hour": "%Y-%m-%d %H:00:00",
    "quarter": None,  # no strftime directive; computed from the month number
    "week": None,  # computed with date modifiers instead of a format
}
|
|
|
|
def _dim_expr(
    self, dim: QueryDimension, sources: dict[str, SourceDefinition]
) -> str:
    """Return the SQL expression for a dimension.

    Computed columns in the dimension's field are expanded first; if the
    dimension has a granularity, the result is wrapped in the
    dialect-appropriate time truncation.
    """
    expanded = self._expand_computed_columns(dim.field, sources)
    if not dim.granularity:
        return expanded
    return self._time_trunc(dim.granularity, expanded)
|
|
|
|
def _time_trunc(self, granularity: str, field: str) -> str:
|
|
"""Generate dialect-appropriate time truncation expression."""
|
|
g = granularity.lower()
|
|
if self.dialect == "sqlite":
|
|
return self._sqlite_time_trunc(g, field)
|
|
if self.dialect == "bigquery":
|
|
return f"DATE_TRUNC({field}, {g.upper()})"
|
|
if self.dialect == "mysql":
|
|
return self._mysql_time_trunc(g, field)
|
|
return f"DATE_TRUNC('{g}', {field})"
|
|
|
|
def _sqlite_time_trunc(self, granularity: str, field: str) -> str:
|
|
"""SQLite time truncation using strftime / date arithmetic."""
|
|
fmt = self._SQLITE_STRFTIME.get(granularity)
|
|
if fmt is not None:
|
|
return f"DATE(STRFTIME('{fmt}', {field}))"
|
|
if granularity == "quarter":
|
|
return (
|
|
f"DATE(STRFTIME('%Y', {field}) || '-' || "
|
|
f"PRINTF('%02d', ((CAST(STRFTIME('%m', {field}) AS INTEGER) - 1) / 3) * 3 + 1) || '-01')"
|
|
)
|
|
if granularity == "week":
|
|
return f"DATE({field}, 'weekday 1', '-7 days')"
|
|
logger.warning(
|
|
"Unsupported SQLite granularity '%s', returning raw field", granularity
|
|
)
|
|
return field
|
|
|
|
# MySQL DATE_FORMAT patterns for time truncation, keyed by granularity.
# _mysql_time_trunc handles "quarter" and "week" with dedicated arithmetic
# before it looks anything up here, so those two entries are not reached
# through that method.
_MYSQL_DATE_FORMAT: dict[str, str] = {
    "year": "%Y-01-01",
    "quarter": "%Y-01-01",
    "month": "%Y-%m-01",
    "week": "%Y-%m-%d",
    "day": "%Y-%m-%d",
    "hour": "%Y-%m-%d %H:00:00",
}
|
|
|
|
def _mysql_time_trunc(self, granularity: str, field: str) -> str:
|
|
"""MySQL time truncation using DATE_FORMAT / quarter arithmetic."""
|
|
if granularity == "quarter":
|
|
return (
|
|
f"DATE(CONCAT(YEAR({field}), '-', "
|
|
f"LPAD((QUARTER({field}) - 1) * 3 + 1, 2, '0'), '-01'))"
|
|
)
|
|
if granularity == "week":
|
|
return f"DATE(DATE_SUB({field}, INTERVAL WEEKDAY({field}) DAY))"
|
|
fmt = self._MYSQL_DATE_FORMAT.get(granularity)
|
|
if fmt is not None:
|
|
return f"DATE(DATE_FORMAT({field}, '{fmt}'))"
|
|
logger.warning(
|
|
"Unsupported MySQL granularity '%s', returning raw field", granularity
|
|
)
|
|
return field
|
|
|
|
def _colliding_dim_leaves(self, dims: list[QueryDimension]) -> set[str]:
|
|
leaves = [d.field.split(".")[-1] if "." in d.field else d.field for d in dims]
|
|
return {leaf for leaf, count in Counter(leaves).items() if count > 1}
|
|
|
|
def _dimension_alias(
|
|
self, dim: QueryDimension, colliding_leaves: set[str] | None = None
|
|
) -> str:
|
|
leaf = dim.field.split(".")[-1] if "." in dim.field else dim.field
|
|
if colliding_leaves and leaf in colliding_leaves:
|
|
alias = dim.field.replace(".", "_")
|
|
else:
|
|
alias = leaf
|
|
if dim.granularity:
|
|
alias = f"{alias}_{dim.granularity}"
|
|
return alias
|
|
|
|
def _resolve_order_field(self, field: str, plan: ResolvedPlan) -> str:
|
|
colliding = self._colliding_dim_leaves(plan.dimensions)
|
|
field_lower = field.lower()
|
|
for measure in plan.measures:
|
|
if field_lower == measure.name.lower():
|
|
return measure.name
|
|
if field_lower == measure.expr.lower():
|
|
return measure.name
|
|
if measure.qualified_ref and field_lower == measure.qualified_ref.lower():
|
|
return measure.name
|
|
if (
|
|
measure.source_name not in {"__derived__", ""}
|
|
and field_lower == f"{measure.source_name}.{measure.name}".lower()
|
|
):
|
|
return measure.name
|
|
|
|
for dim in plan.dimensions:
|
|
alias = self._dimension_alias(dim, colliding)
|
|
if field_lower == dim.field.lower() or field_lower == alias.lower():
|
|
return alias
|
|
|
|
raise ValueError(
|
|
f"ORDER BY field '{field}' is not a recognized measure or dimension in this query"
|
|
)
|
|
|
|
def _collect_cte_target_sources(
|
|
self,
|
|
group: MeasureGroup,
|
|
plan: ResolvedPlan,
|
|
dim_keys: list[dict] | None = None,
|
|
) -> set[str]:
|
|
"""Collect sources needed for this CTE — only safely reachable ones."""
|
|
safe_adj = self._build_safe_adjacency(plan)
|
|
target_sources: set[str] = set()
|
|
|
|
# Only include dimension sources reachable via safe edges
|
|
dims_to_check = (
|
|
dim_keys
|
|
if dim_keys is not None
|
|
else [{"expr": dim.field} for dim in plan.dimensions]
|
|
)
|
|
for dk in dims_to_check:
|
|
dim_sources = self._parser.extract_source_refs(dk["expr"])
|
|
for ds in dim_sources:
|
|
if ds == group.source_name:
|
|
target_sources.add(ds)
|
|
continue
|
|
path = self._find_join_path_steps(group.source_name, ds, safe_adj)
|
|
if path is not None:
|
|
target_sources.add(ds)
|
|
|
|
known_sources = set(target_sources)
|
|
known_sources.add(group.source_name)
|
|
for filter_expr in plan.where_filters:
|
|
filter_sources = self._parser.extract_source_refs(filter_expr)
|
|
if not filter_sources or filter_sources & known_sources:
|
|
for fs in filter_sources:
|
|
if fs == group.source_name:
|
|
target_sources.add(fs)
|
|
continue
|
|
path = self._find_join_path_steps(group.source_name, fs, safe_adj)
|
|
if path is not None:
|
|
target_sources.add(fs)
|
|
known_sources.add(fs)
|
|
|
|
# Include sources from measure-level filters
|
|
for m in group.measures:
|
|
if m.filter:
|
|
filter_sources = self._parser.extract_source_refs(m.filter)
|
|
for fs in filter_sources:
|
|
if fs == group.source_name:
|
|
target_sources.add(fs)
|
|
continue
|
|
path = self._find_join_path_steps(group.source_name, fs, safe_adj)
|
|
if path is not None:
|
|
target_sources.add(fs)
|
|
known_sources.add(fs)
|
|
|
|
# Include sources from measure expressions themselves
|
|
for m in group.measures:
|
|
measure_sources = self._parser.extract_source_refs(m.expr)
|
|
for ms in measure_sources:
|
|
if ms == group.source_name or ms in known_sources:
|
|
target_sources.add(ms)
|
|
continue
|
|
path = self._find_join_path_steps(group.source_name, ms, safe_adj)
|
|
if path is not None:
|
|
target_sources.add(ms)
|
|
known_sources.add(ms)
|
|
|
|
return target_sources
|
|
|
|
def _build_safe_adjacency(
    self, plan: ResolvedPlan
) -> dict[str, list[tuple[str, ResolvedJoin]]]:
    """Build the adjacency graph restricted to many_to_one / one_to_one edges.

    Every join in ``plan.joins`` contributes a forward edge when its
    relationship is one of those two, and a reverse edge when the inverse
    relationship (looked up in ``RELATIONSHIP_INVERSE``) is one of them.
    """
    safe_kinds = ("many_to_one", "one_to_one")
    graph: dict[str, list[tuple[str, ResolvedJoin]]] = {}

    for join in plan.joins:
        # Forward direction: from_source -> to_source.
        if join.relationship in safe_kinds:
            neighbours = graph.setdefault(join.from_source, [])
            neighbours.append((join.to_source, join))
        # Reverse direction: to_source -> from_source.
        if RELATIONSHIP_INVERSE[join.relationship] in safe_kinds:
            neighbours = graph.setdefault(join.to_source, [])
            neighbours.append((join.from_source, join))

    return graph
|
|
|
|
def _resolve_having_for_locality(
|
|
self,
|
|
filter_expr: str,
|
|
plan: ResolvedPlan,
|
|
measure_cte_map: dict[str, str],
|
|
derived_expr_map: dict[str, str] | None = None,
|
|
) -> str:
|
|
"""Rewrite HAVING filter to reference CTE output columns.
|
|
|
|
Handles: raw aggregates (sum(orders.amount)), predefined measure refs
|
|
(orders.revenue), bare measure names, derived measure names (inlined),
|
|
and case-insensitive matching.
|
|
"""
|
|
# Build comprehensive replacement map
|
|
replacement_map: dict[str, str] = {}
|
|
for m in plan.measures:
|
|
if m.is_derived:
|
|
# For derived measures, inline the full expression so the outer
|
|
# WHERE clause doesn't reference a SELECT alias (which is illegal).
|
|
if derived_expr_map and m.name in derived_expr_map:
|
|
replacement_map[m.name.lower()] = f"({derived_expr_map[m.name]})"
|
|
continue
|
|
cte_alias = measure_cte_map.get(m.name)
|
|
if not cte_alias:
|
|
continue
|
|
# In multi-CTE (FULL JOIN) mode, NULL from unmatched rows should
|
|
# be treated as 0 so that filters like "count(x) = 0" work.
|
|
if len(plan.measure_groups) > 1:
|
|
cte_ref = f"COALESCE({cte_alias}.{m.name}, 0)"
|
|
else:
|
|
cte_ref = f"{cte_alias}.{m.name}"
|
|
replacement_map[m.expr.lower()] = cte_ref
|
|
if m.qualified_ref:
|
|
replacement_map[m.qualified_ref.lower()] = cte_ref
|
|
elif m.source_name and m.source_name not in ("__derived__", ""):
|
|
replacement_map[f"{m.source_name}.{m.name}".lower()] = cte_ref
|
|
replacement_map[m.name.lower()] = cte_ref
|
|
|
|
# AST-based rewriting for robustness
|
|
try:
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT {quote_reserved_identifiers(filter_expr)}",
|
|
read=self.dialect,
|
|
)
|
|
|
|
def _rewrite(node):
|
|
if isinstance(node, (exp.AggFunc, exp.Anonymous)):
|
|
node_sql = node.sql(dialect=self.dialect).lower()
|
|
if node_sql in replacement_map:
|
|
return sqlglot.parse_one(
|
|
f"SELECT {replacement_map[node_sql]}", read=self.dialect
|
|
).expressions[0]
|
|
if isinstance(node, exp.Column):
|
|
if node.table:
|
|
ref = f"{node.table}.{node.name}".lower()
|
|
if ref in replacement_map:
|
|
return sqlglot.parse_one(
|
|
f"SELECT {replacement_map[ref]}", read=self.dialect
|
|
).expressions[0]
|
|
if not node.table and node.name.lower() in replacement_map:
|
|
return sqlglot.parse_one(
|
|
f"SELECT {replacement_map[node.name.lower()]}",
|
|
read=self.dialect,
|
|
).expressions[0]
|
|
return node
|
|
|
|
transformed = tree.transform(_rewrite)
|
|
return transformed.expressions[0].sql(dialect=self.dialect)
|
|
except Exception:
|
|
logger.debug(
|
|
"AST-based HAVING rewrite failed for locality filter, falling back to regex: %s",
|
|
filter_expr,
|
|
)
|
|
import re as _re
|
|
|
|
result = filter_expr
|
|
for pattern, replacement in sorted(
|
|
replacement_map.items(), key=lambda x: -len(x[0])
|
|
):
|
|
result = _re.sub(
|
|
_re.escape(pattern), replacement, result, flags=_re.IGNORECASE
|
|
)
|
|
return result
|
|
|
|
def _build_group_join_steps(
|
|
self,
|
|
source_name: str,
|
|
target_sources: set[str],
|
|
plan: ResolvedPlan,
|
|
) -> list[tuple[ResolvedJoin, str]]:
|
|
if not target_sources:
|
|
return []
|
|
|
|
adjacency: dict[str, list[tuple[str, ResolvedJoin]]] = {}
|
|
for join in plan.joins:
|
|
if join.relationship in ("many_to_one", "one_to_one"):
|
|
adjacency.setdefault(join.from_source, []).append(
|
|
(join.to_source, join)
|
|
)
|
|
if RELATIONSHIP_INVERSE[join.relationship] in ("many_to_one", "one_to_one"):
|
|
adjacency.setdefault(join.to_source, []).append(
|
|
(join.from_source, join)
|
|
)
|
|
|
|
steps: list[tuple[ResolvedJoin, str]] = []
|
|
seen_edges: set[tuple[str, str, str, str]] = set()
|
|
|
|
for target in sorted(target_sources - {source_name}):
|
|
path_steps = self._find_join_path_steps(source_name, target, adjacency)
|
|
if not path_steps:
|
|
raise ValueError(
|
|
f"Aggregate locality cannot safely reach '{target}' from "
|
|
f"'{source_name}' without traversing one_to_many edges"
|
|
)
|
|
for join, next_source in path_steps:
|
|
edge_key = (
|
|
join.from_source,
|
|
join.to_source,
|
|
join.from_column,
|
|
join.to_column,
|
|
)
|
|
if edge_key in seen_edges:
|
|
continue
|
|
seen_edges.add(edge_key)
|
|
steps.append((join, next_source))
|
|
|
|
return steps
|
|
|
|
def _find_join_path_steps(
|
|
self,
|
|
start: str,
|
|
target: str,
|
|
adjacency: dict[str, list[tuple[str, ResolvedJoin]]],
|
|
) -> list[tuple[ResolvedJoin, str]]:
|
|
if start == target:
|
|
return []
|
|
|
|
queue = [start]
|
|
parents: dict[str, tuple[str | None, ResolvedJoin | None]] = {
|
|
start: (None, None)
|
|
}
|
|
|
|
while queue:
|
|
current = queue.pop(0)
|
|
if current == target:
|
|
break
|
|
|
|
for next_source, join in adjacency.get(current, []):
|
|
if next_source in parents:
|
|
continue
|
|
parents[next_source] = (current, join)
|
|
queue.append(next_source)
|
|
|
|
if target not in parents:
|
|
return []
|
|
|
|
steps: list[tuple[ResolvedJoin, str]] = []
|
|
current = target
|
|
while current != start:
|
|
parent, join = parents[current]
|
|
if parent is None or join is None:
|
|
break
|
|
steps.append((join, current))
|
|
current = parent
|
|
|
|
steps.reverse()
|
|
return steps
|
|
|
|
def _source_ref(self, name: str, sources: dict[str, SourceDefinition]) -> str:
    """Return the FROM/JOIN reference for a source or alias.

    A directly defined SQL source is referenced by its quoted name (its CTE);
    a table source becomes ``<table> AS <quoted name>``, or just the quoted
    name when it has no table. A name that is only known through the alias
    map is rendered as ``<target> AS <quoted alias>``, where the target is
    the aliased source's table when it is a table source with a table, and
    its quoted name otherwise.

    Raises:
        ValueError: If ``name`` is neither a defined source nor an alias, or
            if the alias points at an undefined source.
    """
    quoted = _qi(name)
    definition = sources.get(name)

    if definition:
        if definition.is_sql_source:
            return quoted  # references the CTE
        return f"{definition.table} AS {quoted}" if definition.table else quoted

    if name not in self._alias_map:
        raise ValueError(f"Cannot generate SQL: source '{name}' is not defined")

    aliased_name = self._alias_map[name]
    aliased = sources.get(aliased_name)
    if aliased is None:
        raise ValueError(
            f"Cannot generate SQL: alias '{name}' refers to source "
            f"'{aliased_name}', which is not defined"
        )
    if aliased.is_table_source and aliased.table:
        return f"{aliased.table} AS {quoted}"
    # SQL sources and everything else are referenced by their quoted name.
    return f"{_qi(aliased_name)} AS {quoted}"
|
|
|
|
# ── Computed column expansion ─────────────────────────────────────
|
|
|
|
def _get_computed_col_map(
|
|
self, sources: dict[str, SourceDefinition]
|
|
) -> dict[str, str]:
|
|
"""Get or build the computed column map: {"source.col": "(qualified_expr)"}."""
|
|
cache_key = id(sources)
|
|
if getattr(self, "_computed_cache_key", None) != cache_key:
|
|
self._computed_col_map = self._build_computed_col_map(sources)
|
|
self._computed_cache_key = cache_key
|
|
return self._computed_col_map
|
|
|
|
def _build_computed_col_map(
|
|
self, sources: dict[str, SourceDefinition]
|
|
) -> dict[str, str]:
|
|
"""Build a lookup from 'source.column' to qualified expression for computed columns."""
|
|
result: dict[str, str] = {}
|
|
for src_name, src in sources.items():
|
|
col_names = {c.name for c in src.columns}
|
|
for col in src.columns:
|
|
if col.expr is None:
|
|
continue
|
|
qualified = self._qualify_bare_refs_in_expr(
|
|
col.expr, src_name, col_names
|
|
)
|
|
result[f"{src_name}.{col.name}"] = f"({qualified})"
|
|
return result
|
|
|
|
def _qualify_bare_refs_in_expr(
    self, expr: str, source_name: str, col_names: set[str]
) -> str:
    """Prefix unqualified column references in ``expr`` with ``source_name``.

    Only columns without a table qualifier whose name is in ``col_names`` are
    touched. If parsing or rewriting fails, the failure is logged at debug
    level and ``expr`` is returned unchanged.
    """
    try:
        parsed = sqlglot.parse_one(
            f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
        )

        def _attach_source(node: exp.Expression) -> exp.Expression:
            if not isinstance(node, exp.Column):
                return node
            if node.table or node.name not in col_names:
                return node
            return exp.Column(
                this=node.this.copy(),
                table=exp.to_identifier(source_name),
            )

        rewritten = parsed.transform(_attach_source)
        return rewritten.expressions[0].sql(dialect=self.dialect)
    except Exception:
        logger.debug(
            "AST-based bare ref qualification failed for expr '%s' on source '%s'",
            expr,
            source_name,
        )
        return expr
|
|
|
|
def _expand_computed_columns(
|
|
self, expr: str, sources: dict[str, SourceDefinition]
|
|
) -> str:
|
|
"""Expand computed column references to their underlying expressions."""
|
|
computed_map = self._get_computed_col_map(sources)
|
|
if not computed_map:
|
|
return expr
|
|
|
|
try:
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
|
|
)
|
|
|
|
changed = False
|
|
|
|
def _replace(node: exp.Expression) -> exp.Expression:
|
|
nonlocal changed
|
|
if isinstance(node, exp.Column) and node.table:
|
|
qualified = f"{node.table}.{node.name}"
|
|
if qualified in computed_map:
|
|
changed = True
|
|
return sqlglot.parse_one(
|
|
f"SELECT {computed_map[qualified]}", read=self.dialect
|
|
).expressions[0]
|
|
return node
|
|
|
|
transformed = tree.transform(_replace)
|
|
if changed:
|
|
return transformed.expressions[0].sql(dialect=self.dialect)
|
|
except Exception:
|
|
logger.debug("AST-based computed column expansion failed for: %s", expr)
|
|
|
|
return expr
|
|
|
|
def _qualify_expr(self, expr: str, sources: dict[str, SourceDefinition]) -> str:
|
|
"""Expand computed column references in expressions."""
|
|
return self._expand_computed_columns(expr, sources)
|
|
|
|
def _qualify_filter(
|
|
self, f: str, sources: dict[str, SourceDefinition], plan: ResolvedPlan
|
|
) -> str:
|
|
"""Expand computed column references in WHERE filters."""
|
|
return self._expand_computed_columns(f, sources)
|
|
|
|
def _expand_having_filter(
|
|
self, f: str, plan: ResolvedPlan, sources: dict[str, SourceDefinition]
|
|
) -> str:
|
|
"""Expand predefined measure references in HAVING filters to aggregate expressions.
|
|
|
|
e.g., 'orders.revenue > 1000' → 'SUM(orders.amount) > 1000'
|
|
when revenue is a predefined measure with expr='sum(amount)'.
|
|
"""
|
|
# Build a map of qualified measure ref → SQL aggregate expression
|
|
measure_expr_map: dict[str, str] = {}
|
|
for m in plan.measures:
|
|
if m.source_name and m.source_name != "__derived__":
|
|
if not m.is_derived:
|
|
if m.qualified_ref:
|
|
measure_expr_map[m.qualified_ref] = self._build_measure_expr(
|
|
m, sources
|
|
)
|
|
qualified_ref = f"{m.source_name}.{m.name}"
|
|
measure_expr_map[qualified_ref] = self._build_measure_expr(
|
|
m, sources
|
|
)
|
|
# Also map bare measure name for unqualified references
|
|
if not m.is_derived:
|
|
measure_expr_map[m.name] = self._build_measure_expr(m, sources)
|
|
|
|
if not measure_expr_map:
|
|
return f
|
|
|
|
# Use AST to find and replace column references matching measure names
|
|
try:
|
|
tree = sqlglot.parse_one(
|
|
f"SELECT * WHERE {quote_reserved_identifiers(f)}",
|
|
dialect=self.dialect,
|
|
)
|
|
where = tree.find(exp.Where)
|
|
if not where:
|
|
return f
|
|
|
|
changed = False
|
|
|
|
def _replace(node):
|
|
nonlocal changed
|
|
if isinstance(node, exp.Column):
|
|
table = node.table
|
|
col_name = node.name
|
|
if table:
|
|
qualified = f"{table}.{col_name}"
|
|
if qualified in measure_expr_map:
|
|
changed = True
|
|
return sqlglot.parse_one(
|
|
f"SELECT {measure_expr_map[qualified]}",
|
|
read=self.dialect,
|
|
).expressions[0]
|
|
elif col_name in measure_expr_map:
|
|
changed = True
|
|
return sqlglot.parse_one(
|
|
f"SELECT {measure_expr_map[col_name]}",
|
|
read=self.dialect,
|
|
).expressions[0]
|
|
return node
|
|
|
|
new_where = where.this.transform(_replace)
|
|
if changed:
|
|
return new_where.sql(dialect=self.dialect)
|
|
except Exception:
|
|
logger.debug(
|
|
"AST-based HAVING expansion failed, returning filter unchanged: %s", f
|
|
)
|
|
return f
|
|
|
|
def _transpile(self, outer_sql: str) -> str:
|
|
"""Normalize the outer scaffold for the target dialect.
|
|
|
|
Source CTEs are concatenated by generate() verbatim, so only the
|
|
engine-generated outer scaffold (which embeds user-authored expr:
|
|
fragments already in self.dialect) reaches this function. Reading and
|
|
writing in self.dialect preserves dialect-specific constructs
|
|
(TIMESTAMP_SUB, DATEADD, APPROX_COUNT_DISTINCT, etc.) that a
|
|
postgres-round-trip would mangle.
|
|
"""
|
|
if self.dialect == "postgres":
|
|
return outer_sql
|
|
try:
|
|
# Quote reserved-word identifiers so target dialect parsers do not
|
|
# confuse them with keywords (e.g. Snowflake's SAMPLE, QUALIFY).
|
|
quoted_outer = quote_reserved_identifiers(outer_sql)
|
|
results = sqlglot.transpile(
|
|
quoted_outer, read=self.dialect, write=self.dialect
|
|
)
|
|
return results[0] if results else outer_sql
|
|
except Exception:
|
|
logger.debug(
|
|
"Outer transpile in '%s' failed; returning un-normalized outer",
|
|
self.dialect,
|
|
)
|
|
return outer_sql
|