Initial open-source release

This commit is contained in:
Andrey Avtomonov 2026-05-10 23:12:26 +02:00
commit 1a42152e6f
1199 changed files with 257054 additions and 0 deletions

View file

@ -0,0 +1,4 @@
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import QueryResult, SemanticQuery
__all__ = ["SemanticEngine", "SemanticQuery", "QueryResult"]

View file

@ -0,0 +1,3 @@
from semantic_layer.cli import main
main()

View file

@ -0,0 +1,268 @@
"""CLI for the semantic layer engine.
Usage:
# Simple query
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}'
# Pre-defined measure with filter
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["orders.revenue"], "dimensions": ["orders.status"]}'
# Cross-source with time granularity
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["regions.name", {"field": "orders.created_at", "granularity": "month"}], "filters": ["regions.name = '"'"'LATAM'"'"'"]}'
# Multiple dialects
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}' \
--dialect bigquery
# Plan only (no SQL generation)
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}' \
--plan-only
# JSON input from stdin
echo '{"measures":["sum(orders.amount)"],"dimensions":["orders.status"]}' | \
uv run python -m semantic_layer.cli --sources sources/ecommerce --json
# Custom ORDER BY
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"], "order_by": [{"field": "sum(orders.amount)", "direction": "desc"}]}'
# Validate query (suggest fixes on failure)
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}' \
--suggest
"""
from __future__ import annotations
import argparse
import json
import sys
import yaml
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import SourceDefinition
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
prog="semantic-layer",
description="Query the semantic layer engine and generate SQL",
)
p.add_argument(
"--sources",
"-s",
help="Path to the sources directory (e.g. sources/ecommerce)",
)
p.add_argument(
"--model",
help="Path to a single YAML file containing all source definitions as a list",
)
p.add_argument(
"--dialect",
"-d",
default="postgres",
help="SQL dialect (postgres, bigquery, snowflake, etc.)",
)
# Query input
p.add_argument(
"--query",
"-q",
help='Raw JSON query string (e.g. \'{"measures": ["orders.revenue"], "dimensions": ["orders.status"]}\')',
)
# Output modes
p.add_argument(
"--json",
action="store_true",
dest="json_input",
help="Read query as JSON from stdin",
)
p.add_argument(
"--plan-only",
action="store_true",
help="Show the resolved plan instead of SQL",
)
p.add_argument(
"--plan",
action="store_true",
help="Show the resolved plan alongside SQL",
)
p.add_argument(
"--compact",
action="store_true",
help="Output SQL without formatting",
)
# Info commands
p.add_argument(
"--list-sources",
action="store_true",
help="List all available sources and exit",
)
p.add_argument(
"--suggest",
action="store_true",
help="Validate the query and suggest fixes if it fails",
)
return p
def list_sources(engine: SemanticEngine) -> None:
for name, src in sorted(engine.sources.items()):
print(f"\n{'' * 40}")
print(f" {name}")
src_type = "sql" if src.is_sql_source else "table"
print(f" type: {src_type}", end="")
if src.table:
print(f" table: {src.table}", end="")
print(f" grain: {src.grain}")
if src.description:
print(f" {src.description.strip()}")
if src.columns:
print(" columns:")
for col in src.columns:
role_tag = f" [{col.role.value}]" if col.role.value != "default" else ""
print(f" {col.name}: {col.type}{role_tag}")
if src.measures:
print(" measures:")
for m in src.measures:
filt = f" (filter: {m.filter})" if m.filter else ""
print(f" {m.name}: {m.expr}{filt}")
if src.joins:
print(" joins:")
for j in src.joins:
print(f"{j.to} ({j.relationship}) on {j.on}")
def print_plan(plan) -> None:
print("\n── Resolved Plan ──")
print(f" Sources: {', '.join(plan.sources_used)}")
print(f" Anchor: {plan.anchor_source}")
if plan.join_paths:
print(" Joins:")
for jp in plan.join_paths:
print(f" {jp}")
print(f" Fan-out: {plan.fan_out_description}")
if plan.aggregate_locality:
print(" Locality:")
for al in plan.aggregate_locality:
print(f" {al}")
if plan.where_filters:
print(f" WHERE: {' AND '.join(plan.where_filters)}")
if plan.having_filters:
print(f" HAVING: {' AND '.join(plan.having_filters)}")
print(" Columns:")
for col in plan.columns:
prov = col.provenance.value
gran = f" ({col.granularity})" if col.granularity else ""
print(f" {col.name} [{prov}]{gran}")
def _load_model_file(path: str) -> dict[str, SourceDefinition]:
"""Load a YAML file containing a list of source definitions."""
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, list):
raise ValueError("Model file must contain a YAML list of source definitions")
sources: dict[str, SourceDefinition] = {}
for item in data:
src = SourceDefinition(**item)
if src.name in sources:
raise ValueError(f"Duplicate source name: '{src.name}'")
sources[src.name] = src
return sources
def main(argv: list[str] | None = None) -> None:
parser = build_parser()
args = parser.parse_args(argv)
if args.model:
sources = _load_model_file(args.model)
engine = SemanticEngine.from_sources(sources, dialect=args.dialect)
elif args.sources:
engine = SemanticEngine(args.sources, dialect=args.dialect)
else:
parser.error("Provide --sources or --model")
# List sources mode
if args.list_sources:
list_sources(engine)
return
# Build query
if args.query:
query_dict = json.loads(args.query)
elif args.json_input:
raw = sys.stdin.read()
query_dict = json.loads(raw)
else:
parser.error("Provide --query or --json")
return
# Suggest mode
if args.suggest:
result = engine.suggest(query_dict)
if result["success"]:
print("Query is valid.")
print_plan(result["plan"])
else:
print(f"Query failed: {result['error']}")
if result.get("graph_errors"):
for err in result["graph_errors"]:
print(f" Graph error: {err}")
for s in result.get("suggestions", []):
if isinstance(s, dict):
print(f" Suggestion: {s.get('description', '')}")
for src in s.get("required_sources", []):
print(f" - Define source: {src}")
for j in s.get("required_joins", []):
print(
f" - Add join: {j['source']}.{j['on']} ({j['relationship']})"
)
for note in s.get("notes", []):
print(f" Note: {note}")
else:
print(f" Suggestion: {s}")
return
# Plan-only mode
if args.plan_only:
plan = engine.plan_only(query_dict)
print_plan(plan)
return
# Full query
result = engine.query(query_dict)
if args.plan:
print_plan(result.resolved_plan)
print()
if args.compact:
print(result.sql)
else:
print(f"-- dialect: {result.dialect}")
print(result.sql)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,99 @@
"""Detect semantically-redundant measure definitions on the same source."""
from __future__ import annotations
import sqlglot
from sqlglot import exp
from semantic_layer.models import SourceDefinition
from semantic_layer.parser import quote_reserved_identifiers
# DIALECT CONVENTION:
# Measure `expr` values are compared structurally. They must be parsed with
# the connection's native dialect (per sl_capture); parsing as postgres
# would drop dialect-specific tokens and miss duplicates across BigQuery
# `SAFE_DIVIDE` / Snowflake `DIV0` etc.
def validate_measure_duplicates(
sources: dict[str, SourceDefinition],
*,
dialect: str = "postgres",
) -> list[str]:
"""
Flag pairs of measures on the same source whose `expr` is structurally
equivalent. Intended to prevent capture-time churn like:
- name: active_subscription_count
expr: count(*)
filter: is_active = true
- name: new_subscription_count
expr: count(*) # same base aggregation — should be query-time filter
Returns a list of human-readable error strings (empty list = no duplicates).
Compares every pair of measures within a single source; does not compare
across sources (measures on different sources are never redundant).
"""
errors: list[str] = []
for source_name, source in sources.items():
if len(source.measures) < 2:
continue
parsed: list[tuple[str, exp.Expression | None, str | None, frozenset[str]]] = []
for m in source.measures:
try:
quoted = quote_reserved_identifiers(m.expr)
tree = sqlglot.parse_one(f"SELECT {quoted}", read=dialect)
expr_node = tree.expressions[0] if tree.expressions else None
except Exception:
# Unparseable expressions are left for the caller's normal
# validation to surface; don't block on parse failure here.
expr_node = None
parsed.append((m.name, expr_node, m.filter, frozenset(m.segments)))
for i, (name_a, expr_a, filter_a, segments_a) in enumerate(parsed):
if expr_a is None:
continue
for name_b, expr_b, filter_b, segments_b in parsed[i + 1 :]:
if expr_b is None:
continue
if not _expressions_equivalent(expr_a, expr_b):
continue
# Segments are named, reusable filter predicates; two measures
# sharing an expr but applying different segments are by design
# distinct and must not be flagged.
if segments_a != segments_b:
continue
fa = (filter_a or "").strip()
fb = (filter_b or "").strip()
if fa == fb:
errors.append(
f"{source_name}: measures '{name_a}' and '{name_b}' have the same "
f"expression and filter — remove one or differentiate them."
)
else:
errors.append(
f"{source_name}: measure '{name_b}' has the same expression as "
f"'{name_a}' — differs only by `filter`. Use query-time filtering "
f"on '{name_a}' (via semantic_query filters), or, if the filter "
f"encodes a named business segment, add a segments[] entry on this "
f"source and reference it instead."
)
return errors
def _expressions_equivalent(a: exp.Expression, b: exp.Expression) -> bool:
"""
Structural equality on sqlglot ASTs.
Normalizes via sqlglot's .sql() canonical form (handles whitespace, case,
aliasing). Does NOT reorder operands `safe_divide(a, b)` is NOT equal to
`safe_divide(b, a)`, nor is `a - b` equal to `b - a`. This is deliberate:
the check's purpose is catching accidental redundancy, not proving
mathematical equivalence.
"""
if type(a) is not type(b):
return False
return a.sql(dialect="postgres") == b.sql(dialect="postgres")

View file

@ -0,0 +1,360 @@
from __future__ import annotations
from semantic_layer.generator import SqlGenerator
from semantic_layer.graph import JoinGraph
from semantic_layer.loader import SourceLoader
from semantic_layer.models import (
QueryResult,
ResolvedPlan,
SemanticQuery,
SourceDefinition,
ValidationReport,
)
from semantic_layer.planner import QueryPlanner
from semantic_layer.sql_table_extractor import (
extract_table_refs,
ref_matches_source_table,
)
class SemanticEngine:
def __init__(self, sources_dir: str, dialect: str = "postgres"):
self.loader = SourceLoader(sources_dir)
self.sources = self.loader.load_all()
self._init_engine(dialect)
@classmethod
def from_sources(
cls, sources: dict[str, SourceDefinition], dialect: str = "postgres"
) -> SemanticEngine:
"""Create engine from pre-loaded source definitions."""
obj = object.__new__(cls)
obj.loader = None
obj.sources = sources
obj._init_engine(dialect)
return obj
def _init_engine(self, dialect: str) -> None:
# Validate the dialect up-front with the user-facing "Unknown SQL
# dialect" error, before JoinGraph.build() hits sqlglot's parser.
SqlGenerator(dialect)
self.graph = JoinGraph(self.sources, dialect=dialect)
self.graph.build()
self.planner = QueryPlanner(self.sources, self.graph, dialect=dialect)
self.generator = SqlGenerator(dialect, alias_map=self.graph.alias_map)
def query(self, query: dict | SemanticQuery) -> QueryResult:
if isinstance(query, dict):
query = SemanticQuery(**query)
orphan_errors = self._collect_orphan_join_target_errors()
if orphan_errors:
raise ValueError("Cannot query semantic layer: " + "; ".join(orphan_errors))
plan = self.planner.plan(query)
sql = self.generator.generate(plan, self.sources)
return QueryResult(
resolved_plan=plan,
sql=sql,
dialect=self.generator.dialect,
columns=plan.columns,
)
def validate(self, recently_touched: set[str] | None = None) -> ValidationReport:
report = ValidationReport()
self._check_orphan_join_targets(report)
self._check_invalid_grain(report)
self._check_sql_join_coverage(report, recently_touched=recently_touched)
self._check_disconnected_components(report, recently_touched=recently_touched)
return report
def _collect_orphan_join_target_errors(self) -> list[str]:
known = set(self.sources.keys())
errors: list[str] = []
for source in self.sources.values():
for join in source.joins:
if join.to not in known:
errors.append(
f"Source '{source.name}' joins to '{join.to}', "
f"but '{join.to}' is not defined"
)
return errors
def _check_orphan_join_targets(self, report: ValidationReport) -> None:
report.errors.extend(self._collect_orphan_join_target_errors())
def _check_invalid_grain(self, report: ValidationReport) -> None:
for source in self.sources.values():
column_names = {c.name for c in source.columns}
for grain_col in source.grain:
if grain_col not in column_names:
report.errors.append(
f"Source '{source.name}' has grain column '{grain_col}' "
f"that is not in its columns list"
)
def _check_sql_join_coverage(
self,
report: ValidationReport,
recently_touched: set[str] | None = None,
) -> None:
"""Block writes whose SQL references a known source's base table
without declaring a join to that source.
Scoped to `recently_touched` so existing fragmentation isn't flagged
on every write. Only sources with `sql:` are checked. CTE
self-references are filtered by the extractor.
"""
if not recently_touched:
return
table_index: list[tuple[SourceDefinition, str]] = [
(src, src.table) for src in self.sources.values() if src.table is not None
]
if not table_index:
return
dialect = getattr(self.generator, "dialect", "postgres")
for source_name in sorted(recently_touched):
source = self.sources.get(source_name)
if source is None or not source.is_sql_source or not source.sql:
continue
declared = {j.to.lower() for j in source.joins}
refs = extract_table_refs(source.sql, dialect=dialect)
missing: list[str] = []
for ref in refs:
hit_name: str | None = None
for candidate, table_value in table_index:
if candidate.name == source.name:
continue
if ref_matches_source_table(ref, table_value):
hit_name = candidate.name
break
if hit_name is None:
continue
if hit_name.lower() in declared:
continue
if hit_name not in missing:
missing.append(hit_name)
if not missing:
continue
ref_list = ", ".join(missing)
example = missing[0]
grain_col = (
self.sources[example].grain[0] if self.sources[example].grain else "id"
)
msg = (
f"Source '{source.name}' SQL joins manifest table(s) [{ref_list}] "
f"that are not declared in joins[]. Add a join entry for each, "
f"e.g. {{to: {example}, on: '{source.name}.<your_fk> = "
f"{example}.{grain_col}', relationship: many_to_one}}. If a "
f"reference is intentionally absent, document it with a "
f"`unmapped-table-*` wiki note and remove the SQL reference."
)
report.errors.append(msg)
def _check_disconnected_components(
self,
report: ValidationReport,
recently_touched: set[str] | None = None,
) -> None:
components = self.graph.find_components()
if len(components) <= 1:
return
sorted_components = sorted(
components, key=lambda c: (-len(c), sorted(c)[0] if c else "")
)
lines = [
f"Model has {len(components)} disconnected components. "
f"Queries that span components will fail with 'No join path' errors:"
]
for i, component in enumerate(sorted_components, start=1):
names = sorted(component)
if len(names) > 3:
sample = ", ".join(names[:2])
lines.append(
f" - Component {i} ({len(names)} sources): {sample}, ... (+{len(names) - 2} more)"
)
else:
lines.append(
f" - Component {i} ({len(names)} sources): {', '.join(names)}"
)
report.warnings.append("\n".join(lines))
if recently_touched:
singleton_components = {next(iter(c)) for c in components if len(c) == 1}
for source_name in sorted(recently_touched & singleton_components):
report.per_source_warnings.setdefault(source_name, []).append(
f"Source '{source_name}' is now a singleton component (no joins to any "
f"other source). Queries that combine '{source_name}' with anything else "
f"will fail with 'No join path' errors. Run sl_discover for each table "
f"named in this source's SQL and add joins via sl_edit_source."
)
def plan_only(self, query: dict | SemanticQuery) -> ResolvedPlan:
if isinstance(query, dict):
query = SemanticQuery(**query)
return self.planner.plan(query)
def suggest(self, query: dict | SemanticQuery) -> dict:
"""Try to plan. If it fails, suggest config extensions with structured info."""
if isinstance(query, dict):
query = SemanticQuery(**query)
try:
plan = self.planner.plan(query)
# Also validate that SQL generation succeeds
try:
self.generator.generate(plan, self.sources)
except Exception as gen_err:
return {
"success": False,
"error": f"SQL generation failed: {gen_err}",
"plan": plan,
"referenced_sources": sorted(set(plan.sources_used)),
"missing_sources": [],
"graph_errors": [],
"suggestions": [
{
"description": f"SQL generation error: {gen_err}",
"required_sources": [],
"required_joins": [],
"notes": [
"The query plan was valid but the SQL generator encountered an error.",
"This may indicate a limitation in the aggregate locality system.",
],
}
],
}
return {
"success": True,
"plan": plan,
"suggestions": [],
}
except Exception as e:
from semantic_layer.parser import ExpressionParser
parser = ExpressionParser()
# Collect all source references from the query
referenced_sources: set[str] = set()
all_exprs: list[str] = []
for m in query.measures:
if isinstance(m, str):
all_exprs.append(m)
elif isinstance(m, dict):
all_exprs.append(m.get("expr", ""))
for d in query.dimensions:
if isinstance(d, str):
all_exprs.append(d)
elif isinstance(d, dict):
all_exprs.append(d.get("field", ""))
all_exprs.extend(query.filters)
for expr in all_exprs:
referenced_sources.update(parser.extract_source_refs(expr))
# Identify missing sources
known_sources = set(self.sources.keys())
missing_sources = sorted(referenced_sources - known_sources)
graph_errors = _format_component_errors(self.graph.find_components())
suggestions = []
if missing_sources:
# Suggest source definitions for missing sources
required_joins = []
for ms in missing_sources:
# Infer potential join targets from column naming (e.g. orders → orders.id)
for known_name, known_src in self.sources.items():
candidate_fk = f"{known_name}_id"
# Check if the missing source might join to this known source
if any(c.name == candidate_fk for c in known_src.columns):
required_joins.append(
{
"source": known_name,
"to": ms,
"on": f"{candidate_fk} = {ms}.id",
"relationship": "many_to_one",
}
)
suggestions.append(
{
"description": f"Define missing source(s): {', '.join(missing_sources)}",
"required_sources": missing_sources,
"required_joins": required_joins,
"notes": [
f"Create YAML definition(s) for: {', '.join(missing_sources)}",
"Each source needs at minimum: name, table (or sql), grain, and columns",
],
}
)
if not missing_sources and len(referenced_sources) > 1:
# Identify which specific pairs are disconnected
present_sources = sorted(referenced_sources & known_sources)
disconnected_pairs = []
for i, src_a in enumerate(present_sources):
for src_b in present_sources[i + 1 :]:
path = self.graph.find_path(src_a, src_b)
if path is None:
disconnected_pairs.append((src_a, src_b))
required_joins = []
for src_a, src_b in disconnected_pairs:
required_joins.append(
{
"source": src_a,
"to": src_b,
"on": f"{src_b}_id = {src_b}.id",
"relationship": "many_to_one",
}
)
suggestions.append(
{
"description": f"Add join path(s) connecting: {', '.join(present_sources)}",
"required_sources": [],
"required_joins": required_joins,
"notes": [
f"Disconnected pairs: {[f'{a}{b}' for a, b in disconnected_pairs]}"
if disconnected_pairs
else "Sources are connected but query failed for another reason",
]
if disconnected_pairs
else [
"All sources are connected; check the error message for details",
],
}
)
return {
"success": False,
"error": str(e),
"referenced_sources": sorted(referenced_sources),
"missing_sources": missing_sources,
"graph_errors": graph_errors,
"suggestions": suggestions,
}
def _format_component_errors(components: list[set[str]]) -> list[str]:
"""Render multi-component topology as graph_error strings for `suggest()` / CLI."""
if len(components) <= 1:
return []
sorted_components = sorted(
components, key=lambda c: (-len(c), sorted(c)[0] if c else "")
)
lines = []
for i, component in enumerate(sorted_components, start=1):
names = sorted(component)
if len(names) > 3:
sample = ", ".join(names[:2])
lines.append(
f"Component {i} ({len(names)} sources): {sample}, ... (+{len(names) - 2} more)"
)
else:
lines.append(f"Component {i} ({len(names)} sources): {', '.join(names)}")
return [f"Disconnected components: {len(components)}"] + lines

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,285 @@
from __future__ import annotations
import heapq
import logging
from dataclasses import dataclass, field
from semantic_layer.models import SourceDefinition
# DIALECT CONVENTION:
# YAML-authored join `on:` clauses may contain dialect-specific casts
# (e.g. BigQuery `SAFE_CAST(x AS INT64)`). `_parse_on` parses them with
# `read=self.dialect` so the AST reflects the author's intent.
logger = logging.getLogger(__name__)
RELATIONSHIP_INVERSE = {
"many_to_one": "one_to_many",
"one_to_many": "many_to_one",
"one_to_one": "one_to_one",
}
@dataclass
class JoinEdge:
from_source: str
to_source: str
from_column: str
to_column: str
relationship: str
alias: str | None = None
@dataclass
class JoinPath:
edges: list[JoinEdge]
has_one_to_many: bool = False
is_ambiguous: bool = False
@property
def source_names(self) -> list[str]:
if not self.edges:
return []
names = [self.edges[0].from_source]
for e in self.edges:
names.append(e.to_source)
return names
@dataclass
class JoinTree:
edges: list[JoinEdge] = field(default_factory=list)
sources: set[str] = field(default_factory=set)
has_one_to_many: bool = False
class JoinGraph:
def __init__(
self,
sources: dict[str, SourceDefinition],
*,
dialect: str = "postgres",
):
self.sources = sources
self.dialect = dialect
self.adjacency: dict[str, list[JoinEdge]] = {}
def build(self) -> None:
# alias_name → actual source name
self.alias_map: dict[str, str] = {}
for name in self.sources:
self.adjacency.setdefault(name, [])
for source in self.sources.values():
for join in source.joins:
from_col, to_col = self._parse_on(join.on, join.to)
target_name = join.alias if join.alias else join.to
if join.alias:
self.alias_map[join.alias] = join.to
# Forward edge: source → alias (or target)
fwd = JoinEdge(
from_source=source.name,
to_source=target_name,
from_column=from_col,
to_column=to_col,
relationship=join.relationship,
alias=join.alias,
)
self.adjacency.setdefault(target_name, [])
self.adjacency[source.name].append(fwd)
# Reverse edge: alias (or target) → source
rev = JoinEdge(
from_source=target_name,
to_source=source.name,
from_column=to_col,
to_column=from_col,
relationship=RELATIONSHIP_INVERSE[join.relationship],
alias=join.alias,
)
self.adjacency[target_name].append(rev)
def find_path(self, from_source: str, to_source: str) -> JoinPath | None:
"""Dijkstra shortest path between two sources.
Also detects ambiguity: if multiple equal-cost paths exist to the
destination, the returned ``JoinPath`` has ``is_ambiguous=True``.
"""
if from_source == to_source:
return JoinPath(edges=[], has_one_to_many=False)
if from_source not in self.adjacency or to_source not in self.adjacency:
return None
# (cost, counter, current_node, path_edges)
counter = 0
heap: list[tuple[int, int, str, list[JoinEdge]]] = [
(0, counter, from_source, [])
]
visited: set[str] = set()
first_path: JoinPath | None = None
first_cost: int | None = None
while heap:
cost, _, current, path = heapq.heappop(heap)
# All equal-cost alternatives exhausted — stop.
if first_cost is not None and cost > first_cost:
break
if current == to_source:
has_o2m = any(e.relationship == "one_to_many" for e in path)
if first_path is None:
first_path = JoinPath(edges=path, has_one_to_many=has_o2m)
first_cost = cost
continue # don't visit dest — keep looking for alternatives
else:
first_path.is_ambiguous = True
return first_path
if current in visited:
continue
visited.add(current)
for edge in self.adjacency.get(current, []):
if edge.to_source not in visited:
counter += 1
# Prefer safe (many_to_one / one_to_one) paths over one_to_many
edge_cost = (
1 if edge.relationship in ("many_to_one", "one_to_one") else 10
)
heapq.heappush(
heap, (cost + edge_cost, counter, edge.to_source, path + [edge])
)
return first_path
def resolve_join_tree(
self, source_names: set[str], root: str | None = None
) -> JoinTree:
"""
Steiner tree approximation: pick root source,
find shortest path to each other source, merge paths.
"""
if len(source_names) <= 1:
return JoinTree(sources=source_names)
if root is not None and root in source_names:
names = [root] + sorted(source_names - {root})
else:
names = sorted(source_names)
root = names[0]
tree = JoinTree(sources={root})
for target in names[1:]:
if target in tree.sources:
continue
path = self.find_path(root, target)
if path is not None and path.is_ambiguous:
logger.warning(
"Ambiguous join path from '%s' to '%s': multiple equal-cost "
"paths exist. The engine picked one arbitrarily. Use join "
"aliases to disambiguate.",
root,
target,
)
if path is None:
raise ValueError(
f"No join path from '{root}' to '{target}'. "
f"These sources are not connected in the join graph."
)
for edge in path.edges:
if not any(
e.from_source == edge.from_source and e.to_source == edge.to_source
for e in tree.edges
):
tree.edges.append(edge)
if edge.relationship == "one_to_many":
tree.has_one_to_many = True
tree.sources.add(edge.from_source)
tree.sources.add(edge.to_source)
return tree
def find_components(self) -> list[set[str]]:
"""Partition the graph into connected components.
Returns one set per component. For an empty graph, returns []. For a
fully connected graph, returns a single-element list. Used both for
validation (multi-component warning) and for suggest().
Aliases and their base source are treated as belonging to the same
component, since alias-scoped queries resolve back to the base table.
"""
# Bidirectional alias↔base adjacency so BFS treats them as one node
alias_neighbors: dict[str, list[str]] = {}
for alias, base in self.alias_map.items():
alias_neighbors.setdefault(alias, []).append(base)
alias_neighbors.setdefault(base, []).append(alias)
components: list[set[str]] = []
unvisited = set(self.adjacency)
while unvisited:
start = next(iter(unvisited))
component: set[str] = set()
queue = [start]
while queue:
node = queue.pop()
if node in component:
continue
component.add(node)
for edge in self.adjacency.get(node, []):
if edge.to_source not in component:
queue.append(edge.to_source)
for neighbor in alias_neighbors.get(node, []):
if neighbor not in component:
queue.append(neighbor)
components.append(component)
unvisited -= component
return components
def _parse_on(self, on_clause: str, target_source: str) -> tuple[str, str]:
"""
Parse join conditions into (from_columns, to_columns) using sqlglot AST.
Single key: "customer_id = customers.id" ("customer_id", "id")
Composite: "a = t.x AND b = t.y" ("a,b", "x,y")
Composite keys are stored as comma-separated strings.
"""
import sqlglot
from sqlglot import exp as _exp
from semantic_layer.parser import quote_reserved_identifiers
quoted = quote_reserved_identifiers(on_clause)
tree = sqlglot.parse_one(
f"SELECT 1 FROM _a JOIN _b ON {quoted}", read=self.dialect
)
from_cols: list[str] = []
to_cols: list[str] = []
for eq_node in tree.find_all(_exp.EQ):
left = eq_node.left
right = eq_node.right
# Reject nested equality (e.g., "a = b = c")
if isinstance(left, _exp.EQ) or isinstance(right, _exp.EQ):
raise ValueError(f"Invalid join condition: '{on_clause}'")
# Extract column name, stripping any source qualifier
def _col_name(node: _exp.Expression) -> str:
if isinstance(node, _exp.Column):
return node.name
return node.sql(dialect="postgres")
from_cols.append(_col_name(left))
to_cols.append(_col_name(right))
if not from_cols:
raise ValueError(f"Invalid join condition: '{on_clause}'")
return ",".join(from_cols), ",".join(to_cols)

View file

@ -0,0 +1,210 @@
from __future__ import annotations
import logging
import re
from copy import deepcopy
from pathlib import Path
import yaml
from semantic_layer.manifest import (
Manifest,
_description_sources,
_resolve_description,
project_manifest_entry,
validate_overlay,
)
from semantic_layer.models import (
JoinDeclaration,
MeasureDefinition,
Segment,
SourceColumn,
SourceDefinition,
)
logger = logging.getLogger(__name__)
_SCHEMA_DIR = "_schema"
def _normalize_ws(s: str) -> str:
"""Collapse whitespace for join deduplication."""
return re.sub(r"\s+", " ", s.strip())
class SourceLoader:
def __init__(self, sources_dir: str | Path):
self.sources_dir = Path(sources_dir)
def load_all(self) -> dict[str, SourceDefinition]:
"""Load all sources using two-tier architecture.
1. Load _schema/*.yaml manifest shards project to SourceDefinitions
2. Load *.yaml files outside _schema/
- Has `sql` or `table` standalone source (load directly)
- Otherwise overlay (compose with matching manifest entry)
3. Validate cross-references
"""
sources: dict[str, SourceDefinition] = {}
description_sources: dict[str, dict[str, str] | None] = {}
# 1. Load manifest shards
schema_dir = self.sources_dir / _SCHEMA_DIR
if schema_dir.is_dir():
for path in sorted(schema_dir.glob("*.yaml")):
manifest = self._load_manifest_shard(path)
for name, entry in manifest.tables.items():
if name in sources:
raise ValueError(
f"Duplicate source name '{name}' in manifest shard {path}"
)
sources[name] = project_manifest_entry(name, entry)
description_sources[name] = _description_sources(
entry.descriptions, entry.description, entry.db_description
)
# 2. Load files outside _schema/
for path in sorted(self.sources_dir.rglob("*.yaml")):
# Skip manifest shards
if _is_in_schema_dir(path, self.sources_dir):
continue
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
continue
name = data.get("name")
if not name:
continue
if data.get("sql") or data.get("table"):
# Standalone source — load directly
if name in sources:
raise ValueError(
f"Duplicate source name '{name}': standalone file {path} "
f"conflicts with manifest entry"
)
sources[name] = SourceDefinition(**data)
else:
# Overlay — validate and compose with matching manifest entry
errors = validate_overlay(data)
if errors:
raise ValueError(
f"Invalid overlay '{name}' in {path}: {'; '.join(errors)}"
)
base = sources.get(name)
if base:
(
sources[name],
description_sources[name],
) = self._compose(base, data, description_sources.get(name))
else:
logger.warning(
"Orphan overlay '%s' in %s: no matching manifest entry, skipping",
name,
path,
)
self._validate_cross_references(sources)
return sources
def load_file(self, path: str | Path) -> SourceDefinition:
"""Load and validate a single standalone YAML source definition."""
path = Path(path)
with open(path) as f:
data = yaml.safe_load(f)
source = SourceDefinition(**data)
if not source.table and not source.sql:
raise ValueError(
f"Standalone source '{source.name}' in {path} must have 'table' or 'sql'"
)
return source
def _load_manifest_shard(self, path: Path) -> Manifest:
"""Load a single manifest shard file."""
with open(path) as f:
data = yaml.safe_load(f)
return Manifest(**data)
def _compose(
self,
base: SourceDefinition,
overlay: dict,
base_description_sources: dict[str, str] | None = None,
) -> tuple[SourceDefinition, dict[str, str] | None]:
"""Compose a manifest-projected SourceDefinition with an overlay."""
source = deepcopy(base)
description_sources = dict(base_description_sources or {})
# Overlay description semantics match the server: `description` writes the
# `user` source key, and `descriptions` merges keyed sources before a single
# visible description is resolved from the full map.
if overlay.get("description"):
description_sources["user"] = overlay["description"]
if overlay.get("descriptions"):
description_sources.update(
{
source_name: text
for source_name, text in overlay["descriptions"].items()
if text
}
)
if overlay.get("description") or overlay.get("descriptions"):
source.description = _resolve_description(
description_sources or None,
)
# Filter columns
excluded = set(overlay.get("exclude_columns", []))
source.columns = [c for c in source.columns if c.name not in excluded]
# Append computed columns (overlay columns with expr)
for col in overlay.get("columns", []):
source.columns.append(SourceColumn(**col))
# Set measures
source.measures = [MeasureDefinition(**m) for m in overlay.get("measures", [])]
# Set segments
source.segments = [Segment(**s) for s in overlay.get("segments", [])]
# Override grain
if overlay.get("grain"):
source.grain = overlay["grain"]
# Union + dedupe joins, apply suppressions
disabled = {_normalize_ws(j) for j in overlay.get("disable_joins", [])}
manifest_joins = [
j for j in source.joins if _normalize_ws(j.on) not in disabled
]
overlay_joins = [JoinDeclaration(**j) for j in overlay.get("joins", [])]
existing_keys = {f"{j.to}::{_normalize_ws(j.on)}" for j in manifest_joins}
new_joins = [
j
for j in overlay_joins
if f"{j.to}::{_normalize_ws(j.on)}" not in existing_keys
]
source.joins = manifest_joins + new_joins
return source, (description_sources or None)
def _validate_cross_references(self, sources: dict[str, SourceDefinition]) -> None:
"""Validate that all join targets reference existing sources."""
for source in sources.values():
for join in source.joins:
if join.to not in sources:
raise ValueError(
f"Source '{source.name}' joins to '{join.to}', "
f"but '{join.to}' is not defined"
)
def _is_in_schema_dir(path: Path, sources_dir: Path) -> bool:
"""Check if a path is inside the _schema/ directory."""
try:
path.relative_to(sources_dir / _SCHEMA_DIR)
return True
except ValueError:
return False

View file

@ -0,0 +1,233 @@
"""Manifest models and projection for the two-tier schema architecture.
The manifest (`_schema/*.yaml`) stores physical table catalog data with DB-native
types, PK flags, and join provenance. This module handles:
- Manifest-specific data models (ManifestColumn, ManifestJoin, ManifestEntry)
- DB-native semantic type mapping
- Projection from ManifestEntry SourceDefinition
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel
from semantic_layer.models import (
ColumnRole,
DefaultTimeDimensionDbt,
FreshnessDbt,
JoinDeclaration,
SourceColumn,
SourceColumnTests,
SourceDefinition,
)
# ── Type mapping (DB-native → semantic) ─────────────────────────────
_TYPE_MAP: dict[str, str] = {
# number family
"integer": "number",
"bigint": "number",
"smallint": "number",
"numeric": "number",
"decimal": "number",
"float": "number",
"double": "number",
"real": "number",
"int": "number",
"int2": "number",
"int4": "number",
"int8": "number",
"float4": "number",
"float8": "number",
"double precision": "number",
"number": "number",
"tinyint": "number",
"mediumint": "number",
# time family
"timestamp": "time",
"timestamptz": "time",
"timestamp with time zone": "time",
"timestamp without time zone": "time",
"timestamp_ntz": "time",
"timestamp_ltz": "time",
"timestamp_tz": "time",
"datetime": "time",
"date": "time",
"time": "time",
"timetz": "time",
# boolean family
"boolean": "boolean",
"bool": "boolean",
# fallback → 'string'
}
def map_column_type(db_type: str) -> str:
"""Map a DB-native column type to a semantic type (string/number/time/boolean)."""
normalized = db_type.lower().split("(")[0].strip()
return _TYPE_MAP.get(normalized, "string")
# ── Manifest data models ────────────────────────────────────────────
_DEFAULT_PRIORITY = ["user", "ai", "dbt", "db"]
def _description_sources(
descriptions: dict[str, str] | None,
description: str | None = None,
db_description: str | None = None,
) -> dict[str, str] | None:
"""Normalize multi-source descriptions to a keyed map."""
if descriptions:
result = {source: text for source, text in descriptions.items() if text}
if result:
return result
result: dict[str, str] = {}
if description:
result["ai"] = description
if db_description:
result["db"] = db_description
return result or None
def _resolve_description(
descriptions: dict[str, str] | None,
description: str | None = None,
db_description: str | None = None,
) -> str | None:
"""Resolve a single description from a multi-source map or legacy flat fields."""
if descriptions:
for source in _DEFAULT_PRIORITY:
if text := descriptions.get(source):
return text
# Fallback: first available
for text in descriptions.values():
if text:
return text
# Legacy flat fields
if description:
return description
if db_description:
return db_description
return None
class ManifestColumn(BaseModel):
name: str
type: str # DB-native type (e.g., "integer", "varchar", "timestamp")
pk: bool = False
nullable: bool = True
descriptions: dict[str, str] | None = None
# Legacy flat fields (backwards-compatible YAML parsing)
description: str | None = None
db_description: str | None = None
constraints: dict | None = None
enum_values: dict[str, list[str]] | None = None
tests: SourceColumnTests | None = None
@property
def resolved_description(self) -> str | None:
return _resolve_description(
self.descriptions, self.description, self.db_description
)
class ManifestJoin(BaseModel):
to: str
on: str
relationship: Literal["many_to_one", "one_to_many", "one_to_one"]
source: Literal["formal", "inferred", "manual"] = "formal"
class ManifestEntry(BaseModel):
table: str
descriptions: dict[str, str] | None = None
# Legacy flat fields (backwards-compatible YAML parsing)
description: str | None = None
db_description: str | None = None
columns: list[ManifestColumn]
joins: list[ManifestJoin] = []
default_time_dimension: DefaultTimeDimensionDbt | None = None
tags: dict[str, list[str]] | None = None
freshness: dict[str, FreshnessDbt] | None = None
@property
def resolved_description(self) -> str | None:
return _resolve_description(
self.descriptions, self.description, self.db_description
)
class Manifest(BaseModel):
"""A single manifest shard file (`_schema/{schema}.yaml`)."""
tables: dict[str, ManifestEntry]
# ── Projection ──────────────────────────────────────────────────────
def validate_overlay(data: dict) -> list[str]:
"""Validate that overlay data doesn't contain structural fields.
Returns a list of error messages (empty if valid).
"""
errors: list[str] = []
if "table" in data:
errors.append("Overlay must not contain 'table' (owned by manifest)")
if "sql" in data:
errors.append(
"Overlay must not contain 'sql' (that makes it a standalone source)"
)
for col in data.get("columns", []):
if "type" in col and "expr" not in col:
errors.append(
f"Overlay column '{col.get('name', '?')}' specifies 'type' without 'expr' "
f"(structural types are inherited from manifest — only computed columns may specify a type)"
)
return errors
def project_manifest_entry(name: str, entry: ManifestEntry) -> SourceDefinition:
"""Convert a raw manifest entry into a valid SourceDefinition.
- Maps DB-native column types to semantic types
- Auto-derives grain from PK columns (or all columns if no PKs)
- Strips join provenance (source field)
"""
columns = [
SourceColumn(
name=c.name,
type=map_column_type(c.type),
role=ColumnRole.TIME
if map_column_type(c.type) == "time"
else ColumnRole.DEFAULT,
description=c.resolved_description,
constraints=c.constraints,
enum_values=c.enum_values,
tests=c.tests,
)
for c in entry.columns
]
pk_columns = [c.name for c in entry.columns if c.pk]
grain = pk_columns if pk_columns else [c.name for c in entry.columns]
return SourceDefinition(
name=name,
table=entry.table,
description=entry.resolved_description,
grain=grain,
columns=columns,
joins=[
JoinDeclaration(to=j.to, on=j.on, relationship=j.relationship)
for j in entry.joins
],
default_time_dimension=entry.default_time_dimension,
tags=entry.tags,
freshness=entry.freshness,
)

View file

@ -0,0 +1,235 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field, model_validator
# ── Source Definition Models ──────────────────────────────────────────
class ColumnVisibility(str, Enum):
PUBLIC = "public"
INTERNAL = "internal"
HIDDEN = "hidden"
class ColumnRole(str, Enum):
TIME = "time"
DEFAULT = "default"
class ColumnDbtConstraints(BaseModel):
not_null: bool | None = None
unique: bool | None = None
class DbtDataTestRef(BaseModel):
name: str
package: str
kwargs: dict[str, Any] | None = None
class SourceColumnTests(BaseModel):
dbt: list[DbtDataTestRef] | None = None
dbt_by_package: dict[str, list[str]] | None = None
class FreshnessDbt(BaseModel):
raw: Any | None = None
loaded_at_field: str | None = None
class SourceColumn(BaseModel):
name: str
type: Literal["string", "number", "time", "boolean"]
visibility: ColumnVisibility = ColumnVisibility.PUBLIC
role: ColumnRole = ColumnRole.DEFAULT
description: str | None = None
expr: str | None = None
natural_granularity: str | None = None
constraints: dict[str, ColumnDbtConstraints] | None = None
enum_values: dict[str, list[str]] | None = None
tests: SourceColumnTests | None = None
class JoinDeclaration(BaseModel):
to: str
on: str # e.g. "customer_id = customers.id"
relationship: Literal["many_to_one", "one_to_many", "one_to_one"]
alias: str | None = None
class MeasureDefinition(BaseModel):
name: str
expr: str # e.g. "sum(amount)"
filter: str | None = None # e.g. "status != 'refunded'"
segments: list[str] = [] # bare segment names defined on the measure's own source
description: str | None = None
class Segment(BaseModel):
"""A named, reusable boolean predicate scoped to a single source."""
name: str
expr: str # e.g. "is_paid = true and is_refunded = '0'"
description: str | None = None
class DefaultTimeDimensionDbt(BaseModel):
dbt: str | None = None
class SourceDefinition(BaseModel):
name: str
description: str | None = None
table: str | None = None
sql: str | None = None
grain: list[str]
columns: list[SourceColumn]
joins: list[JoinDeclaration] = []
measures: list[MeasureDefinition] = []
segments: list[Segment] = []
default_time_dimension: DefaultTimeDimensionDbt | None = None
tags: dict[str, list[str]] | None = None
freshness: dict[str, FreshnessDbt] | None = None
@model_validator(mode="after")
def validate_source(self) -> SourceDefinition:
if self.table and self.sql:
raise ValueError("'table' and 'sql' are mutually exclusive")
if not self.grain:
raise ValueError("grain must be non-empty")
return self
@property
def is_sql_source(self) -> bool:
return self.sql is not None
@property
def is_table_source(self) -> bool:
return self.table is not None
# ── Query Models ──────────────────────────────────────────────────────
class QueryMeasure(BaseModel):
"""Either a pre-defined name ('orders.revenue') or runtime expr."""
ref: str | None = None
expr: str | None = None
name: str | None = None
class QueryDimension(BaseModel):
"""Either a column ref or a time granularity."""
field: str
granularity: str | None = None
class SemanticQuery(BaseModel):
measures: list[str | dict[str, Any]]
dimensions: list[str | dict[str, Any]] = []
filters: list[str] = []
# dotted "source.segment" names; AND-ed into matching measures
segments: list[str] = []
order_by: list[str | dict[str, Any]] = []
limit: int = 1000
include_empty: bool = True
@model_validator(mode="after")
def _validate_limit(self) -> SemanticQuery:
if self.limit is not None and self.limit < 0:
raise ValueError(f"limit must be non-negative, got {self.limit}")
return self
# ── Plan & Result Models ──────────────────────────────────────────────
class Provenance(str, Enum):
VERIFIED = "verified"
COMPOSED = "composed"
DIMENSION = "dimension"
class ResolvedColumn(BaseModel):
name: str
provenance: Provenance
expr: str | None = None
description: str | None = None
granularity: str | None = None
class ResolvedMeasure(BaseModel):
name: str
expr: str # the aggregate expression, e.g. "sum(amount)"
source_name: str
original_name: str | None = None
qualified_ref: str | None = None
filter: str | None = None
provenance: Provenance = Provenance.COMPOSED
is_derived: bool = False
depends_on: list[str] = [] # names of other measures this depends on
description: str | None = None
class MeasureGroup(BaseModel):
"""A group of measures from the same source, for aggregate locality."""
source_name: str
measures: list[ResolvedMeasure]
join_path_to_dims: list[str] = []
class ResolvedJoin(BaseModel):
from_source: str
to_source: str
from_column: str
to_column: str
relationship: str
class OrderByClause(BaseModel):
field: str
direction: str = "asc"
class ResolvedPlan(BaseModel):
sources_used: list[str]
join_paths: list[str] # human-readable descriptions
joins: list[ResolvedJoin] = [] # structured join info for generator
anchor_source: str | None = None # the primary FROM source
anchor_grain: list[str]
fan_out_description: str
has_fan_out: bool = False
measure_groups: list[MeasureGroup] = []
aggregate_locality: list[str] # human-readable CTE descriptions
where_filters: list[str]
having_filters: list[str]
columns: list[ResolvedColumn]
measures: list[ResolvedMeasure] = []
dimensions: list[QueryDimension] = []
order_by: list[OrderByClause] = []
limit: int | None = None
include_empty: bool = True
class QueryResult(BaseModel):
resolved_plan: ResolvedPlan
sql: str
dialect: str
columns: list[ResolvedColumn]
class ValidationReport(BaseModel):
errors: list[str] = Field(default_factory=list)
warnings: list[str] = Field(default_factory=list)
per_source_warnings: dict[str, list[str]] = Field(default_factory=dict)
@property
def valid(self) -> bool:
return len(self.errors) == 0

View file

@ -0,0 +1,303 @@
from __future__ import annotations
import functools
import re
from dataclasses import dataclass, field
import sqlglot
from sqlglot import exp
# DIALECT CONVENTION:
# `ExpressionParser` wraps read-only AST walks over user-authored
# expressions. Callers must construct it with the connection's native
# dialect (per sl_capture). The parse cache is keyed on (sql, dialect)
# so engines with different dialects do not share AST collisions.
AGGREGATE_FUNCTIONS = frozenset(
{
"sum",
"avg",
"count",
"count_distinct",
"min",
"max",
"median",
"percentile",
}
)
# Maps sqlglot AggFunc subclasses to our canonical names
_AGG_NODE_MAP: dict[type, str] = {
exp.Sum: "sum",
exp.Avg: "avg",
exp.Count: "count",
exp.Min: "min",
exp.Max: "max",
}
# Custom aggregates that sqlglot parses as Anonymous (not standard SQL)
_CUSTOM_AGG_NAMES = frozenset({"count_distinct", "percentile", "median"})
# SQL reserved words that cause parse failures when used as identifiers
_SQL_RESERVED = frozenset(
{
"select",
"from",
"where",
"group",
"order",
"by",
"having",
"limit",
"join",
"on",
"as",
"and",
"or",
"not",
"in",
"is",
"null",
"true",
"false",
"between",
"like",
"case",
"when",
"then",
"else",
"end",
"insert",
"update",
"delete",
"create",
"drop",
"alter",
"table",
"index",
"view",
"union",
"all",
"distinct",
"into",
"values",
"set",
"with",
"exists",
"any",
"some",
"offset",
"fetch",
"for",
"grant",
"revoke",
"primary",
"key",
"foreign",
"references",
"check",
"constraint",
"default",
"column",
"cross",
"full",
"inner",
"left",
"right",
"outer",
"natural",
"using",
"except",
"intersect",
# Snowflake / cross-dialect reserved words
"glob",
"ilike",
"lateral",
"match_recognize",
"notnull",
"out",
"qualify",
"regexp",
"returning",
"rlike",
"rollback",
"sample",
"tablesample",
"top",
"uncache",
"xor",
}
)
# Regex pattern for source.column references (word.word)
_DOTTED_IDENT_RE = re.compile(r"\b(\w+)\.(\w+)\b")
# Matches single-quoted SQL string literals (including escaped quotes '')
_STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
@dataclass
class ParsedExpression:
original: str
source_refs: set[str] = field(default_factory=set)
column_refs: set[str] = field(default_factory=set) # "source.column" format
is_aggregate: bool = False
aggregate_function: str | None = None
has_window_function: bool = False
depends_on_measures: set[str] = field(default_factory=set)
def _strip_quotes(name: str) -> str:
"""Strip surrounding double quotes from an identifier."""
if name.startswith('"') and name.endswith('"'):
return name[1:-1]
return name
def quote_reserved_identifiers(expr: str) -> str:
"""Quote source.column references where either part is a SQL reserved word.
String literals are masked before processing to prevent matching
dotted identifiers inside quoted strings like 'group.value'.
"""
# Mask string literals to avoid matching inside them
literals: list[str] = []
def _mask_literal(m: re.Match) -> str:
literals.append(m.group(0))
return f"__SL_LIT_{len(literals) - 1}__"
masked = _STRING_LITERAL_RE.sub(_mask_literal, expr)
def _quote_match(m: re.Match) -> str:
source, col = m.group(1), m.group(2)
start = m.start()
if start > 0 and masked[start - 1] == '"':
return m.group(0)
needs_quote = False
source_q = source
col_q = col
if source.lower() in _SQL_RESERVED:
source_q = f'"{source}"'
needs_quote = True
if col.lower() in _SQL_RESERVED:
col_q = f'"{col}"'
needs_quote = True
if needs_quote:
return f"{source_q}.{col_q}"
return m.group(0)
result = _DOTTED_IDENT_RE.sub(_quote_match, masked)
# Restore string literals
for i, lit in enumerate(literals):
result = result.replace(f"__SL_LIT_{i}__", lit)
return result
@functools.lru_cache(maxsize=256)
def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
"""Cache parsed SELECT wrapper trees keyed by (sql, dialect).
Each (sql, dialect) pair gets its own entry, so engines using different
dialects don't share AST cache collisions.
"""
return sqlglot.parse_one(sql, read=dialect)
class ExpressionParser:
"""Parses user-authored SQL expressions for AST walks.
Must be constructed with the connection's native dialect. User-authored
`expr:`, `filter:`, and segment predicates from YAML are written in that
dialect (per the sl_capture skill contract) and parsing them as postgres
silently drops dialect-specific tokens (e.g. BigQuery `INTERVAL 30 DAY`).
"""
def __init__(self, dialect: str = "postgres") -> None:
self.dialect = dialect
def _quote_reserved_identifiers(self, expr: str) -> str:
return quote_reserved_identifiers(expr)
def _parse_as_select(self, quoted_expr: str) -> exp.Expression:
"""Parse expression wrapped in SELECT, using cache for repeated expressions."""
return _cached_parse_select(f"SELECT {quoted_expr}", self.dialect)
def parse(
self,
expr: str,
known_measure_names: set[str] | None = None,
) -> ParsedExpression:
known_measure_names = known_measure_names or set()
result = ParsedExpression(original=expr)
if not expr or not expr.strip():
return result
quoted_expr = self._quote_reserved_identifiers(expr)
tree = self._parse_as_select(quoted_expr)
# Extract source.column references
for col in tree.find_all(exp.Column):
if col.table:
source_name = _strip_quotes(col.table)
col_name = _strip_quotes(col.name)
result.source_refs.add(source_name)
result.column_refs.add(f"{source_name}.{col_name}")
# Detect aggregate functions (built-in AggFunc subclasses).
# Aggregates nested inside scalar/correlated subqueries do NOT make the
# outer expression aggregate — e.g. `col = (SELECT MAX(col) FROM t)` is a
# plain column predicate, not a HAVING candidate.
def _inside_subquery(node: exp.Expression) -> bool:
parent = node.parent
while parent is not None:
if isinstance(parent, exp.Subquery):
return True
parent = parent.parent
return False
agg_names: list[str] = []
for node in tree.find_all(exp.AggFunc):
if _inside_subquery(node):
continue
name = _AGG_NODE_MAP.get(type(node))
if name:
agg_names.append(name)
else:
agg_names.append(node.key.lower())
# Detect custom aggregates parsed as Anonymous (count_distinct, percentile, median)
for node in tree.find_all(exp.Anonymous):
if _inside_subquery(node):
continue
if node.name.lower() in _CUSTOM_AGG_NAMES:
agg_names.append(node.name.lower())
if agg_names:
result.is_aggregate = True
result.aggregate_function = agg_names[0]
# Detect window functions (OVER clause)
if tree.find(exp.Window):
result.has_window_function = True
# Detect dependencies on named measures (bare identifiers without table qualifier)
if known_measure_names:
for col in tree.find_all(exp.Column):
if not col.table and col.name in known_measure_names:
result.depends_on_measures.add(col.name)
return result
def extract_source_refs(self, expr: str) -> set[str]:
"""Quick extraction of source names from an expression."""
if not expr or not expr.strip():
return set()
quoted_expr = self._quote_reserved_identifiers(expr)
tree = self._parse_as_select(quoted_expr)
return {
_strip_quotes(col.table) for col in tree.find_all(exp.Column) if col.table
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,72 @@
from __future__ import annotations
import logging
import sqlglot
from sqlglot import exp
logger = logging.getLogger(__name__)
def extract_table_refs(sql: str, dialect: str = "postgres") -> list[tuple[str, ...]]:
"""Return a deduped list of warehouse-table refs found in `sql` as
tuples of normalized (lowercase, unquoted) name parts.
Skips CTE self-references. Returns refs in the order they first appear
so callers can present consistent error messages. Each tuple is the
fully-qualified name as written in the SQL: `("staging", "shipments")`,
`("analytics", "marts", "listings")`, or `("listings",)`.
On parse failure returns []; coverage check is best-effort and must
not break source writes when the SQL has unusual syntax.
"""
try:
tree = sqlglot.parse_one(sql, dialect=dialect)
except Exception as e:
logger.debug("sql_table_extractor: parse failed (%s); skipping coverage", e)
return []
cte_names = {cte.alias_or_name.lower() for cte in tree.find_all(exp.CTE)}
seen: set[tuple[str, ...]] = set()
out: list[tuple[str, ...]] = []
for t in tree.find_all(exp.Table):
name = (t.name or "").lower()
if not name or name in cte_names:
continue
parts: list[str] = []
catalog = t.args.get("catalog")
db = t.args.get("db")
if catalog and getattr(catalog, "name", None):
parts.append(catalog.name.lower())
if db and getattr(db, "name", None):
parts.append(db.name.lower())
parts.append(name)
ref = tuple(parts)
if ref not in seen:
seen.add(ref)
out.append(ref)
return out
def normalize_table(value: str) -> tuple[str, ...]:
"""Split a `table:` field value into normalized, lowercased parts."""
return tuple(p.strip('"').strip("`").lower() for p in value.split(".") if p)
def ref_matches_source_table(ref: tuple[str, ...], source_table: str) -> bool:
"""True iff `ref` is a suffix of `source_table` (or vice versa for the
1-part bare-name case).
Examples:
ref=(marts, listings) table=ANALYTICS.MARTS.LISTINGS True
ref=(analytics, marts, x) table=ANALYTICS.MARTS.X True
ref=(listings,) table=ANALYTICS.MARTS.LISTINGS True (bare matches last)
ref=(staging, shipments) table=ANALYTICS.MARTS.SHIPMENTS False (db differs)
"""
src = normalize_table(source_table)
if not src or not ref:
return False
if len(ref) > len(src):
return False
return src[-len(ref) :] == ref

View file

@ -0,0 +1,111 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Literal
import sqlglot
from sqlglot import exp
logger = logging.getLogger(__name__)
SUPPORTED_TABLE_IDENTIFIER_DIALECTS = {
"bigquery",
"snowflake",
"postgres",
"redshift",
"mysql",
"sqlite",
"tsql",
"clickhouse",
}
ParseTableIdentifierReason = Literal[
"looker_template_unresolved",
"derived_table_not_supported",
"no_physical_table",
"multiple_table_references",
"unsupported_dialect",
"parse_error",
]
@dataclass(frozen=True)
class ParseTableIdentifierItem:
key: str
sql_table_name: str
dialect: str
@dataclass(frozen=True)
class ParsedIdentifier:
ok: bool
catalog: str | None = None
schema_: str | None = None
name: str | None = None
canonical_table: str | None = None
reason: ParseTableIdentifierReason | None = None
detail: str | None = None
def parse_table_identifier_batch(
items: list[ParseTableIdentifierItem],
) -> dict[str, ParsedIdentifier]:
return {
item.key: parse_table_identifier_one(item.sql_table_name, item.dialect)
for item in items
}
def parse_table_identifier_one(sql_table_name: str, dialect: str) -> ParsedIdentifier:
normalized_dialect = dialect.lower()
if normalized_dialect not in SUPPORTED_TABLE_IDENTIFIER_DIALECTS:
return ParsedIdentifier(
ok=False,
reason="unsupported_dialect",
detail=f"Unsupported sqlglot dialect for table identifier parsing: {dialect}",
)
if "${" in sql_table_name or "@{" in sql_table_name:
return ParsedIdentifier(ok=False, reason="looker_template_unresolved")
try:
parsed = sqlglot.parse_one(
f"SELECT * FROM {sql_table_name}",
read=normalized_dialect,
)
from_clause = parsed.args.get("from_")
if from_clause is None or from_clause.this is None:
return ParsedIdentifier(ok=False, reason="no_physical_table")
from_expr = from_clause.this
if isinstance(from_expr, (exp.Subquery, exp.Values, exp.Lateral)):
return ParsedIdentifier(ok=False, reason="derived_table_not_supported")
if not isinstance(from_expr, exp.Table):
return ParsedIdentifier(ok=False, reason="derived_table_not_supported")
tables = list(parsed.find_all(exp.Table))
if not tables:
return ParsedIdentifier(ok=False, reason="no_physical_table")
if len(tables) > 1:
return ParsedIdentifier(ok=False, reason="multiple_table_references")
table = tables[0]
canonical_table = exp.Table(
this=exp.to_identifier(table.name),
db=exp.to_identifier(table.db) if table.db else None,
catalog=exp.to_identifier(table.catalog) if table.catalog else None,
).sql(dialect=normalized_dialect)
return ParsedIdentifier(
ok=True,
catalog=table.catalog or None,
schema_=table.db or None,
name=table.name,
canonical_table=canonical_table,
)
except sqlglot.errors.ParseError as exc:
return ParsedIdentifier(ok=False, reason="parse_error", detail=str(exc))
except Exception as exc:
logger.exception("Unexpected failure while parsing Looker sql_table_name")
return ParsedIdentifier(ok=False, reason="parse_error", detail=str(exc))