fix(sl): parse user filter expressions as predicates, not projections (#307)

* fix(sl): parse user filter expressions as predicates, not projections

User-authored filters and segments were parsed in a projection context
(`SELECT {expr}`). On T-SQL a top-level `col = 'value'` projection is the
`alias = expression` aliasing syntax, so an equality filter parsed this way
became `'value' AS col`. That dropped the comparison entirely and silently
skipped computed-column expansion, because the column was hidden behind the
alias.

Parse user fragments as predicates (`SELECT * WHERE {expr}`) at every parse
site: the parser cache, measure-filter CASE WHEN generation, computed-column
expansion, and measure-filter/segment column qualification. For plain
non-condition expressions the extracted column set is identical in both
contexts, so this is a no-op everywhere except the T-SQL alias case it fixes.

Add cross-dialect regression tests (tsql, postgres, snowflake, bigquery) that
lock equality filters and segments to comparison shape and confirm that on
T-SQL `= 'x'` compiles to a predicate, just as `IN ('x')` does.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* Shorten T-SQL predicate comments

* docs(sl): tighten T-SQL predicate docstrings and AGENTS docstring rule

Trim the parser and regression-test docstrings to the 1-3 line bar and
extend the AGENTS.md comment guidance to cover docstrings explicitly.

* refactor(sl): route all filter parsing through parse_predicate

Consolidate the predicate-context parse into a single parse_predicate
helper and route every filter-parsing call site through it: measure
CASE-WHEN filters, segments, computed-column-in-filter, the
aggregate-locality HAVING rewrite, and the planner OR-mixing /
top-level-AND split. The locality and split paths still parsed user
filters in projection context, so a named-measure equality filter
compiled to `0 AS measure` on T-SQL. Add a locality regression test
covering the HAVING rewrite path.

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Co-authored-by: Andrey Avtomonov <andreybavt@gmail.com>
Luca Martial 2026-06-19 01:47:44 -07:00 committed by GitHub
parent 4dae8c34dd
commit fb50c11d16
GPG key ID: B5690EEEBB952194
5 changed files with 238 additions and 63 deletions


@@ -263,22 +263,31 @@ and route ingest, setup, memory, indexing, and docs through it. Do not add an
 `auto_commit`-style switch unless the user explicitly asks for staged-only runs
 and accepts the extra runtime path.

-## Code Comments
+## Code Comments and Docstrings

-Code must be self-explanatory. A comment exists only to state a constraint the
-code cannot show; everything else belongs in the PR description or nowhere.
+Code must be self-explanatory. Clear names, types, and signatures do the
+documenting; a comment or docstring exists only to state what the code cannot
+show. Everything else belongs in the PR description or nowhere.

 - **MUST**: Keep each comment to 1-3 lines stating only what the code cannot
 show: a cross-file invariant ("error-severity issues never reach here — the
 doctor exits on them first"), a required ordering ("ktx.yaml is written
 before git init, so a crash cannot leave a bare `.git`"), or a library quirk
 ("zod reports unknown record keys as `invalid_key`").
+- **MUST**: Hold docstrings (Python `"""..."""`, JSDoc/TSDoc) to the same bar.
+A docstring states a function's purpose or contract in 1-3 lines; when a real
+quirk or invariant motivates the code, note it once and briefly. Let
+self-explanatory code carry the rest — a well-named, well-typed function
+often needs no docstring at all.
 - **MUST**: State each invariant once, at the public entry point. Do not repeat
-the same guarantee across a helper, its wrapper, and the call site.
-- **MUST NOT**: Write prose comment blocks — design rationale, alternatives
-considered, change narration ("is now written before…"), caller enumerations
-("shared by X, Y, and Z"), or restatements of what the code already shows.
-That is the author addressing the reviewer, and it rots once merged.
+the same guarantee across a module docstring, a helper, its wrapper, and the
+call site.
+- **MUST NOT**: Write multi-paragraph docstrings or prose comment blocks —
+design rationale, alternatives considered, change narration ("is now written
+before…"), caller enumerations ("shared by X, Y, and Z"), worked examples
+that restate the code, or the same explanation repeated in a module docstring
+and the function it describes. That is the author addressing the reviewer; it
+belongs in the PR description and rots once merged.
 - **MAY**: Open a regression test with a 1-3 line comment stating the scenario
 it guards when the test name cannot carry it. Omit design history and
 references to removed designs.


@@ -15,7 +15,11 @@ from semantic_layer.models import (
     ResolvedPlan,
     SourceDefinition,
 )
-from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers
+from semantic_layer.parser import (
+    ExpressionParser,
+    parse_predicate,
+    quote_reserved_identifiers,
+)

 # DIALECT CONVENTION:
 # User-authored SQL fragments (measure `expr`, segment `expr`, filter,
@@ -673,9 +677,7 @@ class SqlGenerator:
         if isinstance(select_expr, exp.Alias):
             select_expr = select_expr.this

-        filter_cond = sqlglot.parse_one(
-            f"SELECT {filter_sql}", read=self.dialect
-        ).expressions[0]
+        filter_cond = parse_predicate(filter_sql, self.dialect)

         def _make_case(inner_node):
             return exp.Case(
@@ -1073,10 +1075,7 @@ class SqlGenerator:

         # AST-based rewriting for robustness
         try:
-            tree = sqlglot.parse_one(
-                f"SELECT {quote_reserved_identifiers(filter_expr)}",
-                read=self.dialect,
-            )
+            condition = parse_predicate(filter_expr, self.dialect)

             def _rewrite(node):
                 if isinstance(node, (exp.AggFunc, exp.Anonymous)):
@@ -1099,8 +1098,8 @@ class SqlGenerator:
                     ).expressions[0]
                 return node

-            transformed = tree.transform(_rewrite)
-            return transformed.expressions[0].sql(dialect=self.dialect)
+            transformed = condition.transform(_rewrite)
+            return transformed.sql(dialect=self.dialect)
         except Exception:
             logger.debug(
                 "AST-based HAVING rewrite failed for locality filter, falling back to regex: %s",
@@ -1292,9 +1291,7 @@ class SqlGenerator:
             return expr

         try:
-            tree = sqlglot.parse_one(
-                f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
-            )
+            condition = parse_predicate(expr, self.dialect)

             changed = False
@@ -1309,9 +1306,9 @@ class SqlGenerator:
                     ).expressions[0]
                 return node

-            transformed = tree.transform(_replace)
+            transformed = condition.transform(_replace)
             if changed:
-                return transformed.expressions[0].sql(dialect=self.dialect)
+                return transformed.sql(dialect=self.dialect)
         except Exception:
             logger.debug("AST-based computed column expansion failed for: %s", expr)
@@ -1357,13 +1354,7 @@ class SqlGenerator:

         # Use AST to find and replace column references matching measure names
         try:
-            tree = sqlglot.parse_one(
-                f"SELECT * WHERE {quote_reserved_identifiers(f)}",
-                dialect=self.dialect,
-            )
-            where = tree.find(exp.Where)
-            if not where:
-                return f
+            condition = parse_predicate(f, self.dialect)

             changed = False
@@ -1388,7 +1379,7 @@ class SqlGenerator:
                     ).expressions[0]
                 return node

-            new_where = where.this.transform(_replace)
+            new_where = condition.transform(_replace)
             if changed:
                 return new_where.sql(dialect=self.dialect)
         except Exception:


@@ -196,6 +196,15 @@ def quote_reserved_identifiers(expr: str) -> str:
     return result


+def _predicate_select(expr: str) -> str:
+    """Wrap a user expression as `SELECT * WHERE …`, quoting reserved identifiers.
+
+    Predicate, not projection: T-SQL reads a top-level `col = 'value'` projection
+    as the `alias = expression` form and would compile the filter to `'value' AS col`.
+    """
+    return f"SELECT * WHERE {quote_reserved_identifiers(expr)}"
+
+
 @functools.lru_cache(maxsize=256)
 def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
     """Cache parsed SELECT wrapper trees keyed by (sql, dialect).
@@ -206,6 +215,14 @@ def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
     return sqlglot.parse_one(sql, read=dialect)


+def parse_predicate(expr: str, dialect: str) -> exp.Expression:
+    """Parse a user filter into a fresh, mutable WHERE-condition node.
+
+    Uncached, so the result is safe to `.transform()`; raises on unparseable input.
+    """
+    return sqlglot.parse_one(_predicate_select(expr), read=dialect).find(exp.Where).this
+
+
 class ExpressionParser:
     """Parses user-authored SQL expressions for AST walks.
@@ -218,12 +235,9 @@ class ExpressionParser:
     def __init__(self, dialect: str = "postgres") -> None:
         self.dialect = dialect

-    def _quote_reserved_identifiers(self, expr: str) -> str:
-        return quote_reserved_identifiers(expr)
-
-    def _parse_as_select(self, quoted_expr: str) -> exp.Expression:
-        """Parse expression wrapped in SELECT, using cache for repeated expressions."""
-        return _cached_parse_select(f"SELECT {quoted_expr}", self.dialect)
+    def _parse_as_select(self, expr: str) -> exp.Expression:
+        """Parse a user fragment for read-only AST walks, via the parse cache."""
+        return _cached_parse_select(_predicate_select(expr), self.dialect)

     def parse(
         self,
@@ -236,8 +250,7 @@ class ExpressionParser:
         if not expr or not expr.strip():
             return result

-        quoted_expr = self._quote_reserved_identifiers(expr)
-        tree = self._parse_as_select(quoted_expr)
+        tree = self._parse_as_select(expr)

         # Extract source.column references
         for col in tree.find_all(exp.Column):
@@ -296,8 +309,7 @@ class ExpressionParser:
         """Quick extraction of source names from an expression."""
         if not expr or not expr.strip():
             return set()
-        quoted_expr = self._quote_reserved_identifiers(expr)
-        tree = self._parse_as_select(quoted_expr)
+        tree = self._parse_as_select(expr)
         return {
             _strip_quotes(col.table) for col in tree.find_all(exp.Column) if col.table
         }


@@ -22,7 +22,11 @@ from semantic_layer.models import (
     SemanticQuery,
     SourceDefinition,
 )
-from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers
+from semantic_layer.parser import (
+    ExpressionParser,
+    parse_predicate,
+    quote_reserved_identifiers,
+)

 # DIALECT CONVENTION:
 # User-authored measure `expr`, `filter`, and computed-column fragments must
@@ -910,9 +914,7 @@ class QueryPlanner:
             for c in source.columns:
                 col_to_source[c.name] = source_name

-        tree = sqlglot.parse_one(
-            f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect
-        )
+        condition = parse_predicate(expr, self.dialect)

         def _qualify_column(node):
             if (
@@ -926,8 +928,8 @@ class QueryPlanner:
                 )
             return node

-        transformed = tree.transform(_qualify_column)
-        return transformed.expressions[0].sql(dialect=self.dialect)
+        transformed = condition.transform(_qualify_column)
+        return transformed.sql(dialect=self.dialect)

     def _detect_fan_out(
         self,
@@ -1254,14 +1256,7 @@ class QueryPlanner:
     ) -> None:
         """Raise an error if an OR expression mixes WHERE and HAVING conditions."""
         try:
-            tree = sqlglot.parse_one(
-                f"SELECT * WHERE {quote_reserved_identifiers(clause)}",
-                dialect=self.dialect,
-            )
-            where = tree.find(exp.Where)
-            if not where:
-                return
-            inner = where.this
+            inner = parse_predicate(clause, self.dialect)

             # Only check if the top level contains OR
             or_parts: list[str] = []
@@ -1295,14 +1290,7 @@ class QueryPlanner:
     def _split_top_level_and(self, expr: str) -> list[str]:
         """Split a filter expression on top-level AND (not inside parentheses or strings)."""
         try:
-            tree = sqlglot.parse_one(
-                f"SELECT * WHERE {quote_reserved_identifiers(expr)}",
-                dialect=self.dialect,
-            )
-            where = tree.find(exp.Where)
-            if not where:
-                return [expr]
-            inner = where.this
+            inner = parse_predicate(expr, self.dialect)
             parts: list[str] = []

             def _collect_and(node):


@@ -0,0 +1,175 @@
+"""Regression tests for T-SQL `=` filters mis-parsed as column aliases.
+
+A top-level `col = 'value'` is T-SQL's `alias = expression` projection form, so
+filters and segments must compile as predicates, not projections, on any dialect.
+"""
+
+from __future__ import annotations
+
+import pytest
+import sqlglot
+
+from .conftest import make_engine
+
+
+def _jobs_source(**overrides):
+    base = {
+        "name": "jobs",
+        "table": "dbo.jobs",
+        "grain": ["id"],
+        "columns": [
+            {"name": "id", "type": "number"},
+            {"name": "amount", "type": "number"},
+            {"name": "trade_name", "type": "string"},
+            {"name": "base", "type": "number"},
+            {"name": "doubled", "type": "number", "expr": "base * 2"},
+        ],
+        "segments": [
+            {"name": "is_roofing", "expr": "trade_name = 'Roofing'"},
+        ],
+        "measures": [
+            {
+                "name": "roofing_rev",
+                "expr": "sum(amount)",
+                "filter": "trade_name = 'Roofing'",
+            },
+        ],
+    }
+    base.update(overrides)
+    return base
+
+
+def _chasm_sources():
+    """Two facts fanning out from a shared hub — triggers aggregate locality, the
+    path that rewrites HAVING filters against measure references."""
+    fact = lambda name: {  # noqa: E731
+        "name": name,
+        "table": f"dbo.{name}",
+        "grain": ["id"],
+        "columns": [
+            {"name": "id", "type": "number"},
+            {"name": "hub_id", "type": "number"},
+            {"name": "val", "type": "number"},
+        ],
+        "joins": [
+            {"to": "hub", "on": "hub_id = hub.id", "relationship": "many_to_one"}
+        ],
+        "measures": [{"name": f"{name}_total", "expr": "sum(val)"}],
+    }
+    return {
+        "hub": {
+            "name": "hub",
+            "table": "dbo.hub",
+            "grain": ["id"],
+            "columns": [
+                {"name": "id", "type": "number"},
+                {"name": "segment", "type": "string"},
+            ],
+        },
+        "fact_a": fact("fact_a"),
+        "fact_b": fact("fact_b"),
+    }
+
+
+def _assert_valid(sql: str, dialect: str) -> None:
+    parsed = sqlglot.parse(sql, read=dialect)
+    assert parsed and all(stmt is not None for stmt in parsed), sql
+
+
+# Dialects whose grammar contains the `alias = expression` projection form
+# alongside ones that do not, to guard the cross-dialect contract.
+DIALECTS = ["tsql", "postgres", "snowflake", "bigquery"]
+
+
+@pytest.mark.parametrize("dialect", DIALECTS)
+def test_measure_equality_filter_compiles_as_comparison(dialect):
+    engine = make_engine({"jobs": _jobs_source()}, dialect=dialect)
+    sql = engine.query({"measures": ["jobs.roofing_rev"]}).sql
+
+    _assert_valid(sql, dialect)
+    assert "CASE WHEN" in sql.upper()
+    assert "'Roofing'" in sql
+    # The filter must remain an equality comparison, never an aliased literal.
+    assert "AS trade_name" not in sql
+    assert "'Roofing' AS" not in sql
+
+
+@pytest.mark.parametrize("dialect", DIALECTS)
+def test_segment_equality_filter_compiles_as_comparison(dialect):
+    engine = make_engine({"jobs": _jobs_source()}, dialect=dialect)
+    sql = engine.query(
+        {"measures": ["sum(jobs.amount)"], "segments": ["jobs.is_roofing"]}
+    ).sql
+
+    _assert_valid(sql, dialect)
+    assert "CASE WHEN" in sql.upper()
+    assert "'Roofing' AS" not in sql
+
+
+@pytest.mark.parametrize("dialect", DIALECTS)
+def test_where_equality_filter_compiles_as_comparison(dialect):
+    engine = make_engine({"jobs": _jobs_source()}, dialect=dialect)
+    sql = engine.query(
+        {"measures": ["sum(jobs.amount)"], "filters": ["jobs.trade_name = 'Roofing'"]}
+    ).sql
+
+    _assert_valid(sql, dialect)
+    assert "WHERE" in sql.upper()
+    assert "'Roofing' AS" not in sql
+
+
+@pytest.mark.parametrize("dialect", DIALECTS)
+def test_computed_column_expands_in_equality_where_filter(dialect):
+    engine = make_engine({"jobs": _jobs_source()}, dialect=dialect)
+    sql = engine.query(
+        {"measures": ["sum(jobs.amount)"], "filters": ["jobs.doubled = 10"]}
+    ).sql
+
+    _assert_valid(sql, dialect)
+    # `doubled` is computed (base * 2); the filter must reference the underlying
+    # expression, not the (non-existent) computed column name.
+    assert "base" in sql
+    assert "doubled" not in sql
+
+
+@pytest.mark.parametrize("dialect", DIALECTS)
+def test_locality_named_measure_equality_filter_compiles_as_comparison(dialect):
+    """Aggregate-locality HAVING filter on a named measure must compile to a
+    comparison, not the `0 AS measure` alias form T-SQL emits."""
+    engine = make_engine(_chasm_sources(), dialect=dialect)
+    sql = engine.query(
+        {
+            "measures": ["fact_a.fact_a_total", "fact_b.fact_b_total"],
+            "dimensions": ["hub.segment"],
+            "filters": ["fact_a.fact_a_total = 0"],
+        }
+    ).sql
+
+    _assert_valid(sql, dialect)
+    assert "0 AS fact_a_total" not in sql
+
+
+def test_tsql_equality_and_in_filters_are_equivalent_shape():
+    """On T-SQL, `= 'x'` and `IN ('x')` filters both compile to predicates."""
+    eq_engine = make_engine({"jobs": _jobs_source()}, dialect="tsql")
+    in_engine = make_engine(
+        {
+            "jobs": _jobs_source(
+                measures=[
+                    {
+                        "name": "roofing_rev",
+                        "expr": "sum(amount)",
+                        "filter": "trade_name IN ('Roofing')",
+                    }
+                ]
+            )
+        },
+        dialect="tsql",
+    )
+
+    eq_sql = eq_engine.query({"measures": ["jobs.roofing_rev"]}).sql
+    in_sql = in_engine.query({"measures": ["jobs.roofing_rev"]}).sql
+
+    _assert_valid(eq_sql, "tsql")
+    _assert_valid(in_sql, "tsql")
+    assert "jobs.trade_name = 'Roofing'" in eq_sql
+    assert "jobs.trade_name IN ('Roofing')" in in_sql