mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-10 08:05:14 +02:00
Initial open-source release
This commit is contained in:
commit
1a42152e6f
1199 changed files with 257054 additions and 0 deletions
303
python/klo-sl/semantic_layer/parser.py
Normal file
303
python/klo-sl/semantic_layer/parser.py
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import exp
|
||||
|
||||
# DIALECT CONVENTION:
|
||||
# `ExpressionParser` wraps read-only AST walks over user-authored
|
||||
# expressions. Callers must construct it with the connection's native
|
||||
# dialect (per sl_capture). The parse cache is keyed on (sql, dialect)
|
||||
# so engines with different dialects do not share AST collisions.
|
||||
|
||||
AGGREGATE_FUNCTIONS = frozenset(
|
||||
{
|
||||
"sum",
|
||||
"avg",
|
||||
"count",
|
||||
"count_distinct",
|
||||
"min",
|
||||
"max",
|
||||
"median",
|
||||
"percentile",
|
||||
}
|
||||
)
|
||||
|
||||
# Maps sqlglot AggFunc subclasses to our canonical names
|
||||
_AGG_NODE_MAP: dict[type, str] = {
|
||||
exp.Sum: "sum",
|
||||
exp.Avg: "avg",
|
||||
exp.Count: "count",
|
||||
exp.Min: "min",
|
||||
exp.Max: "max",
|
||||
}
|
||||
|
||||
# Custom aggregates that sqlglot parses as Anonymous (not standard SQL)
|
||||
_CUSTOM_AGG_NAMES = frozenset({"count_distinct", "percentile", "median"})
|
||||
|
||||
# SQL reserved words that cause parse failures when used as identifiers
|
||||
_SQL_RESERVED = frozenset(
|
||||
{
|
||||
"select",
|
||||
"from",
|
||||
"where",
|
||||
"group",
|
||||
"order",
|
||||
"by",
|
||||
"having",
|
||||
"limit",
|
||||
"join",
|
||||
"on",
|
||||
"as",
|
||||
"and",
|
||||
"or",
|
||||
"not",
|
||||
"in",
|
||||
"is",
|
||||
"null",
|
||||
"true",
|
||||
"false",
|
||||
"between",
|
||||
"like",
|
||||
"case",
|
||||
"when",
|
||||
"then",
|
||||
"else",
|
||||
"end",
|
||||
"insert",
|
||||
"update",
|
||||
"delete",
|
||||
"create",
|
||||
"drop",
|
||||
"alter",
|
||||
"table",
|
||||
"index",
|
||||
"view",
|
||||
"union",
|
||||
"all",
|
||||
"distinct",
|
||||
"into",
|
||||
"values",
|
||||
"set",
|
||||
"with",
|
||||
"exists",
|
||||
"any",
|
||||
"some",
|
||||
"offset",
|
||||
"fetch",
|
||||
"for",
|
||||
"grant",
|
||||
"revoke",
|
||||
"primary",
|
||||
"key",
|
||||
"foreign",
|
||||
"references",
|
||||
"check",
|
||||
"constraint",
|
||||
"default",
|
||||
"column",
|
||||
"cross",
|
||||
"full",
|
||||
"inner",
|
||||
"left",
|
||||
"right",
|
||||
"outer",
|
||||
"natural",
|
||||
"using",
|
||||
"except",
|
||||
"intersect",
|
||||
# Snowflake / cross-dialect reserved words
|
||||
"glob",
|
||||
"ilike",
|
||||
"lateral",
|
||||
"match_recognize",
|
||||
"notnull",
|
||||
"out",
|
||||
"qualify",
|
||||
"regexp",
|
||||
"returning",
|
||||
"rlike",
|
||||
"rollback",
|
||||
"sample",
|
||||
"tablesample",
|
||||
"top",
|
||||
"uncache",
|
||||
"xor",
|
||||
}
|
||||
)
|
||||
|
||||
# Regex pattern for source.column references (word.word)
|
||||
_DOTTED_IDENT_RE = re.compile(r"\b(\w+)\.(\w+)\b")
|
||||
|
||||
# Matches single-quoted SQL string literals (including escaped quotes '')
|
||||
_STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedExpression:
|
||||
original: str
|
||||
source_refs: set[str] = field(default_factory=set)
|
||||
column_refs: set[str] = field(default_factory=set) # "source.column" format
|
||||
is_aggregate: bool = False
|
||||
aggregate_function: str | None = None
|
||||
has_window_function: bool = False
|
||||
depends_on_measures: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
def _strip_quotes(name: str) -> str:
|
||||
"""Strip surrounding double quotes from an identifier."""
|
||||
if name.startswith('"') and name.endswith('"'):
|
||||
return name[1:-1]
|
||||
return name
|
||||
|
||||
|
||||
def quote_reserved_identifiers(expr: str) -> str:
|
||||
"""Quote source.column references where either part is a SQL reserved word.
|
||||
|
||||
String literals are masked before processing to prevent matching
|
||||
dotted identifiers inside quoted strings like 'group.value'.
|
||||
"""
|
||||
# Mask string literals to avoid matching inside them
|
||||
literals: list[str] = []
|
||||
|
||||
def _mask_literal(m: re.Match) -> str:
|
||||
literals.append(m.group(0))
|
||||
return f"__SL_LIT_{len(literals) - 1}__"
|
||||
|
||||
masked = _STRING_LITERAL_RE.sub(_mask_literal, expr)
|
||||
|
||||
def _quote_match(m: re.Match) -> str:
|
||||
source, col = m.group(1), m.group(2)
|
||||
start = m.start()
|
||||
if start > 0 and masked[start - 1] == '"':
|
||||
return m.group(0)
|
||||
needs_quote = False
|
||||
source_q = source
|
||||
col_q = col
|
||||
if source.lower() in _SQL_RESERVED:
|
||||
source_q = f'"{source}"'
|
||||
needs_quote = True
|
||||
if col.lower() in _SQL_RESERVED:
|
||||
col_q = f'"{col}"'
|
||||
needs_quote = True
|
||||
if needs_quote:
|
||||
return f"{source_q}.{col_q}"
|
||||
return m.group(0)
|
||||
|
||||
result = _DOTTED_IDENT_RE.sub(_quote_match, masked)
|
||||
|
||||
# Restore string literals
|
||||
for i, lit in enumerate(literals):
|
||||
result = result.replace(f"__SL_LIT_{i}__", lit)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=256)
|
||||
def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
|
||||
"""Cache parsed SELECT wrapper trees keyed by (sql, dialect).
|
||||
|
||||
Each (sql, dialect) pair gets its own entry, so engines using different
|
||||
dialects don't share AST cache collisions.
|
||||
"""
|
||||
return sqlglot.parse_one(sql, read=dialect)
|
||||
|
||||
|
||||
class ExpressionParser:
|
||||
"""Parses user-authored SQL expressions for AST walks.
|
||||
|
||||
Must be constructed with the connection's native dialect. User-authored
|
||||
`expr:`, `filter:`, and segment predicates from YAML are written in that
|
||||
dialect (per the sl_capture skill contract) and parsing them as postgres
|
||||
silently drops dialect-specific tokens (e.g. BigQuery `INTERVAL 30 DAY`).
|
||||
"""
|
||||
|
||||
def __init__(self, dialect: str = "postgres") -> None:
|
||||
self.dialect = dialect
|
||||
|
||||
def _quote_reserved_identifiers(self, expr: str) -> str:
|
||||
return quote_reserved_identifiers(expr)
|
||||
|
||||
def _parse_as_select(self, quoted_expr: str) -> exp.Expression:
|
||||
"""Parse expression wrapped in SELECT, using cache for repeated expressions."""
|
||||
return _cached_parse_select(f"SELECT {quoted_expr}", self.dialect)
|
||||
|
||||
def parse(
|
||||
self,
|
||||
expr: str,
|
||||
known_measure_names: set[str] | None = None,
|
||||
) -> ParsedExpression:
|
||||
known_measure_names = known_measure_names or set()
|
||||
result = ParsedExpression(original=expr)
|
||||
|
||||
if not expr or not expr.strip():
|
||||
return result
|
||||
|
||||
quoted_expr = self._quote_reserved_identifiers(expr)
|
||||
tree = self._parse_as_select(quoted_expr)
|
||||
|
||||
# Extract source.column references
|
||||
for col in tree.find_all(exp.Column):
|
||||
if col.table:
|
||||
source_name = _strip_quotes(col.table)
|
||||
col_name = _strip_quotes(col.name)
|
||||
result.source_refs.add(source_name)
|
||||
result.column_refs.add(f"{source_name}.{col_name}")
|
||||
|
||||
# Detect aggregate functions (built-in AggFunc subclasses).
|
||||
# Aggregates nested inside scalar/correlated subqueries do NOT make the
|
||||
# outer expression aggregate — e.g. `col = (SELECT MAX(col) FROM t)` is a
|
||||
# plain column predicate, not a HAVING candidate.
|
||||
def _inside_subquery(node: exp.Expression) -> bool:
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if isinstance(parent, exp.Subquery):
|
||||
return True
|
||||
parent = parent.parent
|
||||
return False
|
||||
|
||||
agg_names: list[str] = []
|
||||
for node in tree.find_all(exp.AggFunc):
|
||||
if _inside_subquery(node):
|
||||
continue
|
||||
name = _AGG_NODE_MAP.get(type(node))
|
||||
if name:
|
||||
agg_names.append(name)
|
||||
else:
|
||||
agg_names.append(node.key.lower())
|
||||
|
||||
# Detect custom aggregates parsed as Anonymous (count_distinct, percentile, median)
|
||||
for node in tree.find_all(exp.Anonymous):
|
||||
if _inside_subquery(node):
|
||||
continue
|
||||
if node.name.lower() in _CUSTOM_AGG_NAMES:
|
||||
agg_names.append(node.name.lower())
|
||||
|
||||
if agg_names:
|
||||
result.is_aggregate = True
|
||||
result.aggregate_function = agg_names[0]
|
||||
|
||||
# Detect window functions (OVER clause)
|
||||
if tree.find(exp.Window):
|
||||
result.has_window_function = True
|
||||
|
||||
# Detect dependencies on named measures (bare identifiers without table qualifier)
|
||||
if known_measure_names:
|
||||
for col in tree.find_all(exp.Column):
|
||||
if not col.table and col.name in known_measure_names:
|
||||
result.depends_on_measures.add(col.name)
|
||||
|
||||
return result
|
||||
|
||||
def extract_source_refs(self, expr: str) -> set[str]:
|
||||
"""Quick extraction of source names from an expression."""
|
||||
if not expr or not expr.strip():
|
||||
return set()
|
||||
quoted_expr = self._quote_reserved_identifiers(expr)
|
||||
tree = self._parse_as_select(quoted_expr)
|
||||
return {
|
||||
_strip_quotes(col.table) for col in tree.find_all(exp.Column) if col.table
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue