fix(sl): parse user filter expressions as predicates, not projections (#307)

* fix(sl): parse user filter expressions as predicates, not projections

User-authored filters and segments were parsed in a projection context
(`SELECT {expr}`). On T-SQL a top-level `col = 'value'` projection is the
`alias = expression` aliasing syntax, so an equality filter parsed this way
became `'value' AS col` — dropping the comparison entirely and silently
skipping computed-column expansion (the column hid behind the alias).

Parse user fragments as predicates (`SELECT * WHERE {expr}`) at every parse
site — the parser cache, measure-filter CASE WHEN generation, computed-column
expansion, and measure-filter/segment column qualification. For plain
non-condition expressions the column set is identical, so this is a no-op
everywhere except the T-SQL alias case it fixes.

Add cross-dialect regression tests (tsql, postgres, snowflake, bigquery)
locking equality filters/segments to comparison shape and confirming `= 'x'`
now matches `IN ('x')` on T-SQL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* Shorten T-SQL predicate comments

* docs(sl): tighten T-SQL predicate docstrings and AGENTS docstring rule

Trim the parser and regression-test docstrings to the 1-3 line bar and
extend the AGENTS.md comment guidance to cover docstrings explicitly.

* refactor(sl): route all filter parsing through parse_predicate

Consolidate the predicate-context parse into a single parse_predicate
helper and route every filter-parsing call site through it: measure
CASE-WHEN filters, segments, computed-column-in-filter, the
aggregate-locality HAVING rewrite, and the planner OR-mixing /
top-level-AND split. The locality and split paths still parsed user
filters in projection context, so a named-measure equality filter
compiled to `0 AS measure` on T-SQL. Add a locality regression test
covering the HAVING rewrite path.

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Co-authored-by: Andrey Avtomonov <andreybavt@gmail.com>
This commit is contained in:
Luca Martial 2026-06-19 01:47:44 -07:00 committed by GitHub
parent 4dae8c34dd
commit fb50c11d16
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 238 additions and 63 deletions

View file

@ -15,7 +15,11 @@ from semantic_layer.models import (
ResolvedPlan,
SourceDefinition,
)
from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers
from semantic_layer.parser import (
ExpressionParser,
parse_predicate,
quote_reserved_identifiers,
)
# DIALECT CONVENTION:
# User-authored SQL fragments (measure `expr`, segment `expr`, filter,
@ -673,9 +677,7 @@ class SqlGenerator:
if isinstance(select_expr, exp.Alias):
select_expr = select_expr.this
filter_cond = sqlglot.parse_one(
f"SELECT {filter_sql}", read=self.dialect
).expressions[0]
filter_cond = parse_predicate(filter_sql, self.dialect)
def _make_case(inner_node):
return exp.Case(
@ -1073,10 +1075,7 @@ class SqlGenerator:
# AST-based rewriting for robustness
try:
tree = sqlglot.parse_one(
f"SELECT {quote_reserved_identifiers(filter_expr)}",
read=self.dialect,
)
condition = parse_predicate(filter_expr, self.dialect)
def _rewrite(node):
if isinstance(node, (exp.AggFunc, exp.Anonymous)):
@ -1099,8 +1098,8 @@ class SqlGenerator:
).expressions[0]
return node
transformed = tree.transform(_rewrite)
return transformed.expressions[0].sql(dialect=self.dialect)
transformed = condition.transform(_rewrite)
return transformed.sql(dialect=self.dialect)
except Exception:
logger.debug(
"AST-based HAVING rewrite failed for locality filter, falling back to regex: %s",
@ -1292,9 +1291,7 @@ class SqlGenerator:
return expr
try:
tree = sqlglot.parse_one(
f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
)
condition = parse_predicate(expr, self.dialect)
changed = False
@ -1309,9 +1306,9 @@ class SqlGenerator:
).expressions[0]
return node
transformed = tree.transform(_replace)
transformed = condition.transform(_replace)
if changed:
return transformed.expressions[0].sql(dialect=self.dialect)
return transformed.sql(dialect=self.dialect)
except Exception:
logger.debug("AST-based computed column expansion failed for: %s", expr)
@ -1357,13 +1354,7 @@ class SqlGenerator:
# Use AST to find and replace column references matching measure names
try:
tree = sqlglot.parse_one(
f"SELECT * WHERE {quote_reserved_identifiers(f)}",
dialect=self.dialect,
)
where = tree.find(exp.Where)
if not where:
return f
condition = parse_predicate(f, self.dialect)
changed = False
@ -1388,7 +1379,7 @@ class SqlGenerator:
).expressions[0]
return node
new_where = where.this.transform(_replace)
new_where = condition.transform(_replace)
if changed:
return new_where.sql(dialect=self.dialect)
except Exception:

View file

@ -196,6 +196,15 @@ def quote_reserved_identifiers(expr: str) -> str:
return result
def _predicate_select(expr: str) -> str:
"""Wrap a user expression as `SELECT * WHERE …`, quoting reserved identifiers.
Predicate, not projection: T-SQL reads a top-level `col = 'value'` projection
as the `alias = expression` form and would compile the filter to `'value' AS col`.
"""
return f"SELECT * WHERE {quote_reserved_identifiers(expr)}"
@functools.lru_cache(maxsize=256)
def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
"""Cache parsed SELECT wrapper trees keyed by (sql, dialect).
@ -206,6 +215,14 @@ def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
return sqlglot.parse_one(sql, read=dialect)
def parse_predicate(expr: str, dialect: str) -> exp.Expression:
"""Parse a user filter into a fresh, mutable WHERE-condition node.
Uncached, so the result is safe to `.transform()`; raises on unparseable input.
"""
return sqlglot.parse_one(_predicate_select(expr), read=dialect).find(exp.Where).this
class ExpressionParser:
"""Parses user-authored SQL expressions for AST walks.
@ -218,12 +235,9 @@ class ExpressionParser:
def __init__(self, dialect: str = "postgres") -> None:
self.dialect = dialect
def _quote_reserved_identifiers(self, expr: str) -> str:
return quote_reserved_identifiers(expr)
def _parse_as_select(self, quoted_expr: str) -> exp.Expression:
"""Parse expression wrapped in SELECT, using cache for repeated expressions."""
return _cached_parse_select(f"SELECT {quoted_expr}", self.dialect)
def _parse_as_select(self, expr: str) -> exp.Expression:
"""Parse a user fragment for read-only AST walks, via the parse cache."""
return _cached_parse_select(_predicate_select(expr), self.dialect)
def parse(
self,
@ -236,8 +250,7 @@ class ExpressionParser:
if not expr or not expr.strip():
return result
quoted_expr = self._quote_reserved_identifiers(expr)
tree = self._parse_as_select(quoted_expr)
tree = self._parse_as_select(expr)
# Extract source.column references
for col in tree.find_all(exp.Column):
@ -296,8 +309,7 @@ class ExpressionParser:
"""Quick extraction of source names from an expression."""
if not expr or not expr.strip():
return set()
quoted_expr = self._quote_reserved_identifiers(expr)
tree = self._parse_as_select(quoted_expr)
tree = self._parse_as_select(expr)
return {
_strip_quotes(col.table) for col in tree.find_all(exp.Column) if col.table
}

View file

@ -22,7 +22,11 @@ from semantic_layer.models import (
SemanticQuery,
SourceDefinition,
)
from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers
from semantic_layer.parser import (
ExpressionParser,
parse_predicate,
quote_reserved_identifiers,
)
# DIALECT CONVENTION:
# User-authored measure `expr`, `filter`, and computed-column fragments must
@ -910,9 +914,7 @@ class QueryPlanner:
for c in source.columns:
col_to_source[c.name] = source_name
tree = sqlglot.parse_one(
f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
)
condition = parse_predicate(expr, self.dialect)
def _qualify_column(node):
if (
@ -926,8 +928,8 @@ class QueryPlanner:
)
return node
transformed = tree.transform(_qualify_column)
return transformed.expressions[0].sql(dialect=self.dialect)
transformed = condition.transform(_qualify_column)
return transformed.sql(dialect=self.dialect)
def _detect_fan_out(
self,
@ -1254,14 +1256,7 @@ class QueryPlanner:
) -> None:
"""Raise an error if an OR expression mixes WHERE and HAVING conditions."""
try:
tree = sqlglot.parse_one(
f"SELECT * WHERE {quote_reserved_identifiers(clause)}",
dialect=self.dialect,
)
where = tree.find(exp.Where)
if not where:
return
inner = where.this
inner = parse_predicate(clause, self.dialect)
# Only check if the top level contains OR
or_parts: list[str] = []
@ -1295,14 +1290,7 @@ class QueryPlanner:
def _split_top_level_and(self, expr: str) -> list[str]:
"""Split a filter expression on top-level AND (not inside parentheses or strings)."""
try:
tree = sqlglot.parse_one(
f"SELECT * WHERE {quote_reserved_identifiers(expr)}",
dialect=self.dialect,
)
where = tree.find(exp.Where)
if not where:
return [expr]
inner = where.this
inner = parse_predicate(expr, self.dialect)
parts: list[str] = []
def _collect_and(node):