mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
304 lines
8.6 KiB
Python
304 lines
8.6 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import functools
|
||
|
|
import re
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
|
||
|
|
import sqlglot
|
||
|
|
from sqlglot import exp
|
||
|
|
|
||
|
|
# DIALECT CONVENTION:
|
||
|
|
# `ExpressionParser` wraps read-only AST walks over user-authored
|
||
|
|
# expressions. Callers must construct it with the connection's native
|
||
|
|
# dialect (per sl_capture). The parse cache is keyed on (sql, dialect)
|
||
|
|
# so engines with different dialects do not share AST collisions.
|
||
|
|
|
||
|
|
AGGREGATE_FUNCTIONS = frozenset(
|
||
|
|
{
|
||
|
|
"sum",
|
||
|
|
"avg",
|
||
|
|
"count",
|
||
|
|
"count_distinct",
|
||
|
|
"min",
|
||
|
|
"max",
|
||
|
|
"median",
|
||
|
|
"percentile",
|
||
|
|
}
|
||
|
|
)
|
||
|
|
|
||
|
|
# Maps sqlglot AggFunc subclasses to our canonical names
|
||
|
|
_AGG_NODE_MAP: dict[type, str] = {
|
||
|
|
exp.Sum: "sum",
|
||
|
|
exp.Avg: "avg",
|
||
|
|
exp.Count: "count",
|
||
|
|
exp.Min: "min",
|
||
|
|
exp.Max: "max",
|
||
|
|
}
|
||
|
|
|
||
|
|
# Custom aggregates that sqlglot parses as Anonymous (not standard SQL)
|
||
|
|
_CUSTOM_AGG_NAMES = frozenset({"count_distinct", "percentile", "median"})
|
||
|
|
|
||
|
|
# SQL reserved words that cause parse failures when used as identifiers
|
||
|
|
_SQL_RESERVED = frozenset(
|
||
|
|
{
|
||
|
|
"select",
|
||
|
|
"from",
|
||
|
|
"where",
|
||
|
|
"group",
|
||
|
|
"order",
|
||
|
|
"by",
|
||
|
|
"having",
|
||
|
|
"limit",
|
||
|
|
"join",
|
||
|
|
"on",
|
||
|
|
"as",
|
||
|
|
"and",
|
||
|
|
"or",
|
||
|
|
"not",
|
||
|
|
"in",
|
||
|
|
"is",
|
||
|
|
"null",
|
||
|
|
"true",
|
||
|
|
"false",
|
||
|
|
"between",
|
||
|
|
"like",
|
||
|
|
"case",
|
||
|
|
"when",
|
||
|
|
"then",
|
||
|
|
"else",
|
||
|
|
"end",
|
||
|
|
"insert",
|
||
|
|
"update",
|
||
|
|
"delete",
|
||
|
|
"create",
|
||
|
|
"drop",
|
||
|
|
"alter",
|
||
|
|
"table",
|
||
|
|
"index",
|
||
|
|
"view",
|
||
|
|
"union",
|
||
|
|
"all",
|
||
|
|
"distinct",
|
||
|
|
"into",
|
||
|
|
"values",
|
||
|
|
"set",
|
||
|
|
"with",
|
||
|
|
"exists",
|
||
|
|
"any",
|
||
|
|
"some",
|
||
|
|
"offset",
|
||
|
|
"fetch",
|
||
|
|
"for",
|
||
|
|
"grant",
|
||
|
|
"revoke",
|
||
|
|
"primary",
|
||
|
|
"key",
|
||
|
|
"foreign",
|
||
|
|
"references",
|
||
|
|
"check",
|
||
|
|
"constraint",
|
||
|
|
"default",
|
||
|
|
"column",
|
||
|
|
"cross",
|
||
|
|
"full",
|
||
|
|
"inner",
|
||
|
|
"left",
|
||
|
|
"right",
|
||
|
|
"outer",
|
||
|
|
"natural",
|
||
|
|
"using",
|
||
|
|
"except",
|
||
|
|
"intersect",
|
||
|
|
# Snowflake / cross-dialect reserved words
|
||
|
|
"glob",
|
||
|
|
"ilike",
|
||
|
|
"lateral",
|
||
|
|
"match_recognize",
|
||
|
|
"notnull",
|
||
|
|
"out",
|
||
|
|
"qualify",
|
||
|
|
"regexp",
|
||
|
|
"returning",
|
||
|
|
"rlike",
|
||
|
|
"rollback",
|
||
|
|
"sample",
|
||
|
|
"tablesample",
|
||
|
|
"top",
|
||
|
|
"uncache",
|
||
|
|
"xor",
|
||
|
|
}
|
||
|
|
)
|
||
|
|
|
||
|
|
# Regex pattern for source.column references (word.word)
|
||
|
|
_DOTTED_IDENT_RE = re.compile(r"\b(\w+)\.(\w+)\b")
|
||
|
|
|
||
|
|
# Matches single-quoted SQL string literals (including escaped quotes '')
|
||
|
|
_STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class ParsedExpression:
|
||
|
|
original: str
|
||
|
|
source_refs: set[str] = field(default_factory=set)
|
||
|
|
column_refs: set[str] = field(default_factory=set) # "source.column" format
|
||
|
|
is_aggregate: bool = False
|
||
|
|
aggregate_function: str | None = None
|
||
|
|
has_window_function: bool = False
|
||
|
|
depends_on_measures: set[str] = field(default_factory=set)
|
||
|
|
|
||
|
|
|
||
|
|
def _strip_quotes(name: str) -> str:
|
||
|
|
"""Strip surrounding double quotes from an identifier."""
|
||
|
|
if name.startswith('"') and name.endswith('"'):
|
||
|
|
return name[1:-1]
|
||
|
|
return name
|
||
|
|
|
||
|
|
|
||
|
|
def quote_reserved_identifiers(expr: str) -> str:
|
||
|
|
"""Quote source.column references where either part is a SQL reserved word.
|
||
|
|
|
||
|
|
String literals are masked before processing to prevent matching
|
||
|
|
dotted identifiers inside quoted strings like 'group.value'.
|
||
|
|
"""
|
||
|
|
# Mask string literals to avoid matching inside them
|
||
|
|
literals: list[str] = []
|
||
|
|
|
||
|
|
def _mask_literal(m: re.Match) -> str:
|
||
|
|
literals.append(m.group(0))
|
||
|
|
return f"__SL_LIT_{len(literals) - 1}__"
|
||
|
|
|
||
|
|
masked = _STRING_LITERAL_RE.sub(_mask_literal, expr)
|
||
|
|
|
||
|
|
def _quote_match(m: re.Match) -> str:
|
||
|
|
source, col = m.group(1), m.group(2)
|
||
|
|
start = m.start()
|
||
|
|
if start > 0 and masked[start - 1] == '"':
|
||
|
|
return m.group(0)
|
||
|
|
needs_quote = False
|
||
|
|
source_q = source
|
||
|
|
col_q = col
|
||
|
|
if source.lower() in _SQL_RESERVED:
|
||
|
|
source_q = f'"{source}"'
|
||
|
|
needs_quote = True
|
||
|
|
if col.lower() in _SQL_RESERVED:
|
||
|
|
col_q = f'"{col}"'
|
||
|
|
needs_quote = True
|
||
|
|
if needs_quote:
|
||
|
|
return f"{source_q}.{col_q}"
|
||
|
|
return m.group(0)
|
||
|
|
|
||
|
|
result = _DOTTED_IDENT_RE.sub(_quote_match, masked)
|
||
|
|
|
||
|
|
# Restore string literals
|
||
|
|
for i, lit in enumerate(literals):
|
||
|
|
result = result.replace(f"__SL_LIT_{i}__", lit)
|
||
|
|
|
||
|
|
return result
|
||
|
|
|
||
|
|
|
||
|
|
@functools.lru_cache(maxsize=256)
|
||
|
|
def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
|
||
|
|
"""Cache parsed SELECT wrapper trees keyed by (sql, dialect).
|
||
|
|
|
||
|
|
Each (sql, dialect) pair gets its own entry, so engines using different
|
||
|
|
dialects don't share AST cache collisions.
|
||
|
|
"""
|
||
|
|
return sqlglot.parse_one(sql, read=dialect)
|
||
|
|
|
||
|
|
|
||
|
|
class ExpressionParser:
|
||
|
|
"""Parses user-authored SQL expressions for AST walks.
|
||
|
|
|
||
|
|
Must be constructed with the connection's native dialect. User-authored
|
||
|
|
`expr:`, `filter:`, and segment predicates from YAML are written in that
|
||
|
|
dialect (per the sl_capture skill contract) and parsing them as postgres
|
||
|
|
silently drops dialect-specific tokens (e.g. BigQuery `INTERVAL 30 DAY`).
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(self, dialect: str = "postgres") -> None:
|
||
|
|
self.dialect = dialect
|
||
|
|
|
||
|
|
def _quote_reserved_identifiers(self, expr: str) -> str:
|
||
|
|
return quote_reserved_identifiers(expr)
|
||
|
|
|
||
|
|
def _parse_as_select(self, quoted_expr: str) -> exp.Expression:
|
||
|
|
"""Parse expression wrapped in SELECT, using cache for repeated expressions."""
|
||
|
|
return _cached_parse_select(f"SELECT {quoted_expr}", self.dialect)
|
||
|
|
|
||
|
|
def parse(
|
||
|
|
self,
|
||
|
|
expr: str,
|
||
|
|
known_measure_names: set[str] | None = None,
|
||
|
|
) -> ParsedExpression:
|
||
|
|
known_measure_names = known_measure_names or set()
|
||
|
|
result = ParsedExpression(original=expr)
|
||
|
|
|
||
|
|
if not expr or not expr.strip():
|
||
|
|
return result
|
||
|
|
|
||
|
|
quoted_expr = self._quote_reserved_identifiers(expr)
|
||
|
|
tree = self._parse_as_select(quoted_expr)
|
||
|
|
|
||
|
|
# Extract source.column references
|
||
|
|
for col in tree.find_all(exp.Column):
|
||
|
|
if col.table:
|
||
|
|
source_name = _strip_quotes(col.table)
|
||
|
|
col_name = _strip_quotes(col.name)
|
||
|
|
result.source_refs.add(source_name)
|
||
|
|
result.column_refs.add(f"{source_name}.{col_name}")
|
||
|
|
|
||
|
|
# Detect aggregate functions (built-in AggFunc subclasses).
|
||
|
|
# Aggregates nested inside scalar/correlated subqueries do NOT make the
|
||
|
|
# outer expression aggregate — e.g. `col = (SELECT MAX(col) FROM t)` is a
|
||
|
|
# plain column predicate, not a HAVING candidate.
|
||
|
|
def _inside_subquery(node: exp.Expression) -> bool:
|
||
|
|
parent = node.parent
|
||
|
|
while parent is not None:
|
||
|
|
if isinstance(parent, exp.Subquery):
|
||
|
|
return True
|
||
|
|
parent = parent.parent
|
||
|
|
return False
|
||
|
|
|
||
|
|
agg_names: list[str] = []
|
||
|
|
for node in tree.find_all(exp.AggFunc):
|
||
|
|
if _inside_subquery(node):
|
||
|
|
continue
|
||
|
|
name = _AGG_NODE_MAP.get(type(node))
|
||
|
|
if name:
|
||
|
|
agg_names.append(name)
|
||
|
|
else:
|
||
|
|
agg_names.append(node.key.lower())
|
||
|
|
|
||
|
|
# Detect custom aggregates parsed as Anonymous (count_distinct, percentile, median)
|
||
|
|
for node in tree.find_all(exp.Anonymous):
|
||
|
|
if _inside_subquery(node):
|
||
|
|
continue
|
||
|
|
if node.name.lower() in _CUSTOM_AGG_NAMES:
|
||
|
|
agg_names.append(node.name.lower())
|
||
|
|
|
||
|
|
if agg_names:
|
||
|
|
result.is_aggregate = True
|
||
|
|
result.aggregate_function = agg_names[0]
|
||
|
|
|
||
|
|
# Detect window functions (OVER clause)
|
||
|
|
if tree.find(exp.Window):
|
||
|
|
result.has_window_function = True
|
||
|
|
|
||
|
|
# Detect dependencies on named measures (bare identifiers without table qualifier)
|
||
|
|
if known_measure_names:
|
||
|
|
for col in tree.find_all(exp.Column):
|
||
|
|
if not col.table and col.name in known_measure_names:
|
||
|
|
result.depends_on_measures.add(col.name)
|
||
|
|
|
||
|
|
return result
|
||
|
|
|
||
|
|
def extract_source_refs(self, expr: str) -> set[str]:
|
||
|
|
"""Quick extraction of source names from an expression."""
|
||
|
|
if not expr or not expr.strip():
|
||
|
|
return set()
|
||
|
|
quoted_expr = self._quote_reserved_identifiers(expr)
|
||
|
|
tree = self._parse_as_select(quoted_expr)
|
||
|
|
return {
|
||
|
|
_strip_quotes(col.table) for col in tree.find_all(exp.Column) if col.table
|
||
|
|
}
|