From fb50c11d1678dbacc99b7f528c80392c981fb9f9 Mon Sep 17 00:00:00 2001 From: Luca Martial <48870843+luca-martial@users.noreply.github.com> Date: Fri, 19 Jun 2026 01:47:44 -0700 Subject: [PATCH] fix(sl): parse user filter expressions as predicates, not projections (#307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(sl): parse user filter expressions as predicates, not projections User-authored filters and segments were parsed in a projection context (`SELECT {expr}`). On T-SQL a top-level `col = 'value'` projection is the `alias = expression` aliasing syntax, so an equality filter parsed this way became `'value' AS col` — dropping the comparison entirely and silently skipping computed-column expansion (the column hid behind the alias). Parse user fragments as predicates (`SELECT * WHERE {expr}`) at every parse site — the parser cache, measure-filter CASE WHEN generation, computed-column expansion, and measure-filter/segment column qualification. For plain non-condition expressions the column set is identical, so this is a no-op everywhere except the T-SQL alias case it fixes. Add cross-dialect regression tests (tsql, postgres, snowflake, bigquery) locking equality filters/segments to comparison shape and confirming `= 'x'` now matches `IN ('x')` on T-SQL. Co-Authored-By: Claude Opus 4.8 * Shorten T-SQL predicate comments * docs(sl): tighten T-SQL predicate docstrings and AGENTS docstring rule Trim the parser and regression-test docstrings to the 1-3 line bar and extend the AGENTS.md comment guidance to cover docstrings explicitly. * refactor(sl): route all filter parsing through parse_predicate Consolidate the predicate-context parse into a single parse_predicate helper and route every filter-parsing call site through it: measure CASE-WHEN filters, segments, computed-column-in-filter, the aggregate-locality HAVING rewrite, and the planner OR-mixing / top-level-AND split. The locality and split paths still parsed user filters in projection context, so a named-measure equality filter compiled to `0 AS measure` on T-SQL. Add a locality regression test covering the HAVING rewrite path. --------- Co-authored-by: Claude Opus 4.8 Co-authored-by: Andrey Avtomonov --- AGENTS.md | 25 ++- python/ktx-sl/semantic_layer/generator.py | 37 ++-- python/ktx-sl/semantic_layer/parser.py | 32 +++- python/ktx-sl/semantic_layer/planner.py | 32 +--- .../test_tsql_filter_alias_regression.py | 175 ++++++++++++++++++ 5 files changed, 238 insertions(+), 63 deletions(-) create mode 100644 python/ktx-sl/tests/test_tsql_filter_alias_regression.py diff --git a/AGENTS.md b/AGENTS.md index 6f6dec86..4bdecf89 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -263,22 +263,31 @@ and route ingest, setup, memory, indexing, and docs through it. Do not add an `auto_commit`-style switch unless the user explicitly asks for staged-only runs and accepts the extra runtime path. -## Code Comments +## Code Comments and Docstrings -Code must be self-explanatory. A comment exists only to state a constraint the -code cannot show; everything else belongs in the PR description or nowhere. +Code must be self-explanatory. Clear names, types, and signatures do the +documenting; a comment or docstring exists only to state what the code cannot +show. Everything else belongs in the PR description or nowhere. - **MUST**: Keep each comment to 1-3 lines stating only what the code cannot show: a cross-file invariant ("error-severity issues never reach here — the doctor exits on them first"), a required ordering ("ktx.yaml is written before git init, so a crash cannot leave a bare `.git`"), or a library quirk ("zod reports unknown record keys as `invalid_key`"). +- **MUST**: Hold docstrings (Python `"""..."""`, JSDoc/TSDoc) to the same bar. + A docstring states a function's purpose or contract in 1-3 lines; when a real + quirk or invariant motivates the code, note it once and briefly. Let + self-explanatory code carry the rest — a well-named, well-typed function + often needs no docstring at all. - **MUST**: State each invariant once, at the public entry point. Do not repeat - the same guarantee across a helper, its wrapper, and the call site. -- **MUST NOT**: Write prose comment blocks — design rationale, alternatives - considered, change narration ("is now written before…"), caller enumerations - ("shared by X, Y, and Z"), or restatements of what the code already shows. - That is the author addressing the reviewer, and it rots once merged. + the same guarantee across a module docstring, a helper, its wrapper, and the + call site. +- **MUST NOT**: Write multi-paragraph docstrings or prose comment blocks — + design rationale, alternatives considered, change narration ("is now written + before…"), caller enumerations ("shared by X, Y, and Z"), worked examples + that restate the code, or the same explanation repeated in a module docstring + and the function it describes. That is the author addressing the reviewer; it + belongs in the PR description and rots once merged. - **MAY**: Open a regression test with a 1-3 line comment stating the scenario it guards when the test name cannot carry it. Omit design history and references to removed designs. diff --git a/python/ktx-sl/semantic_layer/generator.py b/python/ktx-sl/semantic_layer/generator.py index 5309018f..8207bac1 100644 --- a/python/ktx-sl/semantic_layer/generator.py +++ b/python/ktx-sl/semantic_layer/generator.py @@ -15,7 +15,11 @@ from semantic_layer.models import ( ResolvedPlan, SourceDefinition, ) -from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers +from semantic_layer.parser import ( + ExpressionParser, + parse_predicate, + quote_reserved_identifiers, +) # DIALECT CONVENTION: # User-authored SQL fragments (measure `expr`, segment `expr`, filter, @@ -673,9 +677,7 @@ class SqlGenerator: if isinstance(select_expr, exp.Alias): select_expr = select_expr.this - filter_cond = sqlglot.parse_one( - f"SELECT {filter_sql}", read=self.dialect - ).expressions[0] + filter_cond = parse_predicate(filter_sql, self.dialect) def _make_case(inner_node): return exp.Case( @@ -1073,10 +1075,7 @@ class SqlGenerator: # AST-based rewriting for robustness try: - tree = sqlglot.parse_one( - f"SELECT {quote_reserved_identifiers(filter_expr)}", - read=self.dialect, - ) + condition = parse_predicate(filter_expr, self.dialect) def _rewrite(node): if isinstance(node, (exp.AggFunc, exp.Anonymous)): @@ -1099,8 +1098,8 @@ class SqlGenerator: ).expressions[0] return node - transformed = tree.transform(_rewrite) - return transformed.expressions[0].sql(dialect=self.dialect) + transformed = condition.transform(_rewrite) + return transformed.sql(dialect=self.dialect) except Exception: logger.debug( "AST-based HAVING rewrite failed for locality filter, falling back to regex: %s", @@ -1292,9 +1291,7 @@ class SqlGenerator: return expr try: - tree = sqlglot.parse_one( - f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect - ) + condition = parse_predicate(expr, self.dialect) changed = False @@ -1309,9 +1306,9 @@ class SqlGenerator: ).expressions[0] return node - transformed = tree.transform(_replace) + transformed = condition.transform(_replace) if changed: - return transformed.expressions[0].sql(dialect=self.dialect) + return transformed.sql(dialect=self.dialect) except Exception: logger.debug("AST-based computed column expansion failed for: %s", expr) @@ -1357,13 +1354,7 @@ class SqlGenerator: # Use AST to find and replace column references matching measure names try: - tree = sqlglot.parse_one( - f"SELECT * WHERE {quote_reserved_identifiers(f)}", - dialect=self.dialect, - ) - where = tree.find(exp.Where) - if not where: - return f + condition = parse_predicate(f, self.dialect) changed = False @@ -1388,7 +1379,7 @@ class SqlGenerator: ).expressions[0] return node - new_where = where.this.transform(_replace) + new_where = condition.transform(_replace) if changed: return new_where.sql(dialect=self.dialect) except Exception: diff --git a/python/ktx-sl/semantic_layer/parser.py b/python/ktx-sl/semantic_layer/parser.py index 39da6813..b6f2e3d7 100644 --- a/python/ktx-sl/semantic_layer/parser.py +++ b/python/ktx-sl/semantic_layer/parser.py @@ -196,6 +196,15 @@ def quote_reserved_identifiers(expr: str) -> str: return result +def _predicate_select(expr: str) -> str: + """Wrap a user expression as `SELECT * WHERE …`, quoting reserved identifiers. + + Predicate, not projection: T-SQL reads a top-level `col = 'value'` projection + as the `alias = expression` form and would compile the filter to `'value' AS col`. + """ + return f"SELECT * WHERE {quote_reserved_identifiers(expr)}" + + @functools.lru_cache(maxsize=256) def _cached_parse_select(sql: str, dialect: str) -> exp.Expression: """Cache parsed SELECT wrapper trees keyed by (sql, dialect). @@ -206,6 +215,14 @@ def _cached_parse_select(sql: str, dialect: str) -> exp.Expression: return sqlglot.parse_one(sql, read=dialect) +def parse_predicate(expr: str, dialect: str) -> exp.Expression: + """Parse a user filter into a fresh, mutable WHERE-condition node. + + Uncached, so the result is safe to `.transform()`; raises on unparseable input. + """ + return sqlglot.parse_one(_predicate_select(expr), read=dialect).find(exp.Where).this + + class ExpressionParser: """Parses user-authored SQL expressions for AST walks. @@ -218,12 +235,9 @@ class ExpressionParser: def __init__(self, dialect: str = "postgres") -> None: self.dialect = dialect - def _quote_reserved_identifiers(self, expr: str) -> str: - return quote_reserved_identifiers(expr) - - def _parse_as_select(self, quoted_expr: str) -> exp.Expression: - """Parse expression wrapped in SELECT, using cache for repeated expressions.""" - return _cached_parse_select(f"SELECT {quoted_expr}", self.dialect) + def _parse_as_select(self, expr: str) -> exp.Expression: + """Parse a user fragment for read-only AST walks, via the parse cache.""" + return _cached_parse_select(_predicate_select(expr), self.dialect) def parse( self, @@ -236,8 +250,7 @@ class ExpressionParser: if not expr or not expr.strip(): return result - quoted_expr = self._quote_reserved_identifiers(expr) - tree = self._parse_as_select(quoted_expr) + tree = self._parse_as_select(expr) # Extract source.column references for col in tree.find_all(exp.Column): @@ -296,8 +309,7 @@ class ExpressionParser: """Quick extraction of source names from an expression.""" if not expr or not expr.strip(): return set() - quoted_expr = self._quote_reserved_identifiers(expr) - tree = self._parse_as_select(quoted_expr) + tree = self._parse_as_select(expr) return { _strip_quotes(col.table) for col in tree.find_all(exp.Column) if col.table } diff --git a/python/ktx-sl/semantic_layer/planner.py b/python/ktx-sl/semantic_layer/planner.py index e5ebf02e..8f0c9cca 100644 --- a/python/ktx-sl/semantic_layer/planner.py +++ b/python/ktx-sl/semantic_layer/planner.py @@ -22,7 +22,11 @@ from semantic_layer.models import ( SemanticQuery, SourceDefinition, ) -from semantic_layer.parser import ExpressionParser, quote_reserved_identifiers +from semantic_layer.parser import ( + ExpressionParser, + parse_predicate, + quote_reserved_identifiers, +) # DIALECT CONVENTION: # User-authored measure `expr`, `filter`, and computed-column fragments must @@ -910,9 +914,7 @@ class QueryPlanner: for c in source.columns: col_to_source[c.name] = source_name - tree = sqlglot.parse_one( - f"SELECT {quote_reserved_identifiers(expr)}", read=self.dialect - ) + condition = parse_predicate(expr, self.dialect) def _qualify_column(node): if ( @@ -926,8 +928,8 @@ class QueryPlanner: ) return node - transformed = tree.transform(_qualify_column) - return transformed.expressions[0].sql(dialect=self.dialect) + transformed = condition.transform(_qualify_column) + return transformed.sql(dialect=self.dialect) def _detect_fan_out( self, @@ -1254,14 +1256,7 @@ class QueryPlanner: ) -> None: """Raise an error if an OR expression mixes WHERE and HAVING conditions.""" try: - tree = sqlglot.parse_one( - f"SELECT * WHERE {quote_reserved_identifiers(clause)}", - dialect=self.dialect, - ) - where = tree.find(exp.Where) - if not where: - return - inner = where.this + inner = parse_predicate(clause, self.dialect) # Only check if the top level contains OR or_parts: list[str] = [] @@ -1295,14 +1290,7 @@ class QueryPlanner: def _split_top_level_and(self, expr: str) -> list[str]: """Split a filter expression on top-level AND (not inside parentheses or strings).""" try: - tree = sqlglot.parse_one( - f"SELECT * WHERE {quote_reserved_identifiers(expr)}", - dialect=self.dialect, - ) - where = tree.find(exp.Where) - if not where: - return [expr] - inner = where.this + inner = parse_predicate(expr, self.dialect) parts: list[str] = [] def _collect_and(node): diff --git a/python/ktx-sl/tests/test_tsql_filter_alias_regression.py b/python/ktx-sl/tests/test_tsql_filter_alias_regression.py new file mode 100644 index 00000000..c219e514 --- /dev/null +++ b/python/ktx-sl/tests/test_tsql_filter_alias_regression.py @@ -0,0 +1,175 @@ +"""Regression tests for T-SQL `=` filters mis-parsed as column aliases. + +A top-level `col = 'value'` is T-SQL's `alias = expression` projection form, so +filters and segments must compile as predicates, not projections, on any dialect. +""" + +from __future__ import annotations + +import pytest +import sqlglot + +from .conftest import make_engine + + +def _jobs_source(**overrides): + base = { + "name": "jobs", + "table": "dbo.jobs", + "grain": ["id"], + "columns": [ + {"name": "id", "type": "number"}, + {"name": "amount", "type": "number"}, + {"name": "trade_name", "type": "string"}, + {"name": "base", "type": "number"}, + {"name": "doubled", "type": "number", "expr": "base * 2"}, + ], + "segments": [ + {"name": "is_roofing", "expr": "trade_name = 'Roofing'"}, + ], + "measures": [ + { + "name": "roofing_rev", + "expr": "sum(amount)", + "filter": "trade_name = 'Roofing'", + }, + ], + } + base.update(overrides) + return base + + +def _chasm_sources(): + """Two facts fanning out from a shared hub — triggers aggregate locality, the + path that rewrites HAVING filters against measure references.""" + fact = lambda name: { # noqa: E731 + "name": name, + "table": f"dbo.{name}", + "grain": ["id"], + "columns": [ + {"name": "id", "type": "number"}, + {"name": "hub_id", "type": "number"}, + {"name": "val", "type": "number"}, + ], + "joins": [ + {"to": "hub", "on": "hub_id = hub.id", "relationship": "many_to_one"} + ], + "measures": [{"name": f"{name}_total", "expr": "sum(val)"}], + } + return { + "hub": { + "name": "hub", + "table": "dbo.hub", + "grain": ["id"], + "columns": [ + {"name": "id", "type": "number"}, + {"name": "segment", "type": "string"}, + ], + }, + "fact_a": fact("fact_a"), + "fact_b": fact("fact_b"), + } + + +def _assert_valid(sql: str, dialect: str) -> None: + parsed = sqlglot.parse(sql, read=dialect) + assert parsed and all(stmt is not None for stmt in parsed), sql + + +# Dialects whose grammar contains the `alias = expression` projection form +# alongside ones that do not, to guard the cross-dialect contract. +DIALECTS = ["tsql", "postgres", "snowflake", "bigquery"] + + +@pytest.mark.parametrize("dialect", DIALECTS) +def test_measure_equality_filter_compiles_as_comparison(dialect): + engine = make_engine({"jobs": _jobs_source()}, dialect=dialect) + sql = engine.query({"measures": ["jobs.roofing_rev"]}).sql + + _assert_valid(sql, dialect) + assert "CASE WHEN" in sql.upper() + assert "'Roofing'" in sql + # The filter must remain an equality comparison, never an aliased literal. + assert "AS trade_name" not in sql + assert "'Roofing' AS" not in sql + + +@pytest.mark.parametrize("dialect", DIALECTS) +def test_segment_equality_filter_compiles_as_comparison(dialect): + engine = make_engine({"jobs": _jobs_source()}, dialect=dialect) + sql = engine.query( + {"measures": ["sum(jobs.amount)"], "segments": ["jobs.is_roofing"]} + ).sql + + _assert_valid(sql, dialect) + assert "CASE WHEN" in sql.upper() + assert "'Roofing' AS" not in sql + + +@pytest.mark.parametrize("dialect", DIALECTS) +def test_where_equality_filter_compiles_as_comparison(dialect): + engine = make_engine({"jobs": _jobs_source()}, dialect=dialect) + sql = engine.query( + {"measures": ["sum(jobs.amount)"], "filters": ["jobs.trade_name = 'Roofing'"]} + ).sql + + _assert_valid(sql, dialect) + assert "WHERE" in sql.upper() + assert "'Roofing' AS" not in sql + + +@pytest.mark.parametrize("dialect", DIALECTS) +def test_computed_column_expands_in_equality_where_filter(dialect): + engine = make_engine({"jobs": _jobs_source()}, dialect=dialect) + sql = engine.query( + {"measures": ["sum(jobs.amount)"], "filters": ["jobs.doubled = 10"]} + ).sql + + _assert_valid(sql, dialect) + # `doubled` is computed (base * 2); the filter must reference the underlying + # expression, not the (non-existent) computed column name. + assert "base" in sql + assert "doubled" not in sql + + +@pytest.mark.parametrize("dialect", DIALECTS) +def test_locality_named_measure_equality_filter_compiles_as_comparison(dialect): + """Aggregate-locality HAVING filter on a named measure must compile to a + comparison, not the `0 AS measure` alias form T-SQL emits.""" + engine = make_engine(_chasm_sources(), dialect=dialect) + sql = engine.query( + { + "measures": ["fact_a.fact_a_total", "fact_b.fact_b_total"], + "dimensions": ["hub.segment"], + "filters": ["fact_a.fact_a_total = 0"], + } + ).sql + + _assert_valid(sql, dialect) + assert "0 AS fact_a_total" not in sql + + +def test_tsql_equality_and_in_filters_are_equivalent_shape(): + """On T-SQL, `= 'x'` and `IN ('x')` filters both compile to predicates.""" + eq_engine = make_engine({"jobs": _jobs_source()}, dialect="tsql") + in_engine = make_engine( + { + "jobs": _jobs_source( + measures=[ + { + "name": "roofing_rev", + "expr": "sum(amount)", + "filter": "trade_name IN ('Roofing')", + } + ] + ) + }, + dialect="tsql", + ) + eq_sql = eq_engine.query({"measures": ["jobs.roofing_rev"]}).sql + in_sql = in_engine.query({"measures": ["jobs.roofing_rev"]}).sql + + _assert_valid(eq_sql, "tsql") + _assert_valid(in_sql, "tsql") + assert "jobs.trade_name = 'Roofing'" in eq_sql + assert "jobs.trade_name IN ('Roofing')" in in_sql