Harden semantic layer source validation

This commit is contained in:
Luca Martial 2026-05-11 20:39:07 -07:00
parent 66109caa1d
commit 6536c5da26
14 changed files with 235 additions and 17 deletions

View file

@ -12,6 +12,7 @@ from semantic_layer.models import (
)
from semantic_layer.planner import QueryPlanner
from semantic_layer.sql_table_extractor import (
extract_projected_columns,
extract_table_refs,
ref_matches_source_table,
)
@ -83,15 +84,48 @@ class SemanticEngine:
report.errors.extend(self._collect_orphan_join_target_errors())
def _check_invalid_grain(self, report: ValidationReport) -> None:
    """Validate grain and column naming for every source.

    Errors are appended to ``report.errors`` per source, in this order:

    1. grain entries containing ``.`` (qualified names);
    2. column names containing ``.``;
    3. unqualified grain entries missing from the source's columns list;
    4. for SQL sources whose projection could be determined, unqualified
       grain entries missing from the SELECT projection.

    A qualified grain entry only produces the first error; it is left out
    of checks 3 and 4.
    """
    sql_dialect = getattr(self.generator, "dialect", "postgres")

    for source in self.sources.values():
        # Split grain once: dotted entries are reported right away, the
        # remaining ones feed the membership checks further down.
        plain_grain: list[str] = []
        for entry in source.grain:
            if "." not in entry:
                plain_grain.append(entry)
                continue
            report.errors.append(
                f"Source '{source.name}' grain entry '{entry}' is a "
                f"qualified name. Grain must use unqualified output column "
                f"names (e.g. 'account_id', not 'activity.account_id')."
            )

        report.errors.extend(
            f"Source '{source.name}' column name '{column.name}' contains "
            f"'.'. Column names must be unqualified."
            for column in source.columns
            if "." in column.name
        )

        declared = {column.name for column in source.columns}
        report.errors.extend(
            f"Source '{source.name}' has grain column '{entry}' "
            f"that is not in its columns list"
            for entry in plain_grain
            if entry not in declared
        )

        if not (source.is_sql_source and source.sql):
            continue

        projected = extract_projected_columns(source.sql, dialect=sql_dialect)
        if projected is None:
            # Projection could not be determined; skip the check rather
            # than report an error that may be wrong.
            continue

        report.errors.extend(
            f"Source '{source.name}' grain column '{entry}' "
            f"is not in the SQL SELECT projection. Add it to the "
            f"SELECT list (or remove it from grain)."
            for entry in plain_grain
            if entry not in projected
        )
def _check_join_columns(self, report: ValidationReport) -> None:
for source in self.sources.values():
source_columns = {c.name for c in source.columns}
@ -108,7 +142,9 @@ class SemanticEngine:
)
continue
local_cols = [col.strip() for col in local_raw.split(",") if col.strip()]
local_cols = [
col.strip() for col in local_raw.split(",") if col.strip()
]
target_cols = [
col.strip() for col in target_raw.split(",") if col.strip()
]

View file

@ -70,3 +70,30 @@ def ref_matches_source_table(ref: tuple[str, ...], source_table: str) -> bool:
if len(ref) > len(src):
return False
return src[-len(ref) :] == ref
def extract_projected_columns(sql: str, dialect: str = "postgres") -> set[str] | None:
    """Return the output column names of the top-level SELECT in `sql`.

    The result is None ("unknown projection") in three situations:

    * parsing `sql` raises;
    * the parsed statement is not an `exp.Select`;
    * the select list contains a bare `*` or a qualified `t.*`.

    Callers should treat None as a signal to skip projection-dependent
    checks instead of reporting a false-positive error. Otherwise the
    non-empty entries of the statement's `named_selects` are returned.
    """
    try:
        parsed = sqlglot.parse_one(sql, read=dialect)
    except Exception as e:
        logger.debug("extract_projected_columns: parse failed (%s); skipping", e)
        return None

    if not isinstance(parsed, exp.Select):
        return None

    # A wildcard anywhere in the select list makes the projection opaque.
    has_wildcard = any(
        isinstance(item, exp.Star)
        or (isinstance(item, exp.Column) and isinstance(item.this, exp.Star))
        for item in parsed.expressions
    )
    if has_wildcard:
        return None

    return set(filter(None, parsed.named_selects))

View file

@ -119,6 +119,116 @@ class TestInvalidGrain:
assert not report.valid
assert any("bad" in e and "nonexistent_col" in e for e in report.errors)
def test_qualified_grain_name_is_rejected(self):
    """A grain entry written as 'source.column' must fail validation."""
    source = _src(
        "activity",
        columns=["account_id"],
        grain=["activity.account_id"],
    )
    report = SemanticEngine.from_sources({"activity": source}).validate()

    assert not report.valid
    expected_fragments = ("activity", "activity.account_id", "qualified")
    assert any(
        all(fragment in err for fragment in expected_fragments)
        for err in report.errors
    )
def test_qualified_column_name_is_rejected(self):
    """A column whose name contains '.' must fail validation."""
    columns = [
        SourceColumn(name="account_id", type="number"),
        SourceColumn(name="activity.user_id", type="number"),
    ]
    source = SourceDefinition(
        name="activity",
        table="public.activity",
        grain=["account_id"],
        columns=columns,
    )
    report = SemanticEngine.from_sources({"activity": source}).validate()

    assert not report.valid
    expected_fragments = ("activity", "activity.user_id", "unqualified")
    assert any(
        all(fragment in err for fragment in expected_fragments)
        for err in report.errors
    )
def test_sql_source_grain_missing_from_projection(self):
    """Grain columns absent from the SQL SELECT list must be reported."""
    # The query projects only account_name and requester_email; neither
    # grain column appears in the select list.
    query = (
        "select account.account_name, requester.email as requester_email "
        "from orbit_raw.actions activity "
        "join orbit_raw.accounts account "
        "  on account.account_id = activity.account_id "
        "join orbit_raw.users requester "
        "  on requester.user_id = activity.user_id"
    )
    declared_columns = [
        SourceColumn(name="account_id", type="number"),
        SourceColumn(name="user_id", type="number"),
        SourceColumn(name="account_name", type="string"),
        SourceColumn(name="requester_email", type="string"),
    ]
    source = SourceDefinition(
        name="large_contract_requesters",
        sql=query,
        grain=["account_id", "user_id"],
        columns=declared_columns,
    )
    engine = SemanticEngine.from_sources({"large_contract_requesters": source})
    report = engine.validate()

    assert not report.valid
    expected_fragments = (
        "large_contract_requesters",
        "account_id",
        "SELECT projection",
    )
    assert any(
        all(fragment in err for fragment in expected_fragments)
        for err in report.errors
    )
def test_sql_source_grain_in_projection_passes(self):
    """No grain error is raised when the SELECT list projects every grain column."""
    query = (
        "select activity.account_id, activity.user_id, "
        "account.account_name, requester.email as requester_email "
        "from orbit_raw.actions activity "
        "join orbit_raw.accounts account "
        "  on account.account_id = activity.account_id "
        "join orbit_raw.users requester "
        "  on requester.user_id = activity.user_id"
    )
    declared_columns = [
        SourceColumn(name="account_id", type="number"),
        SourceColumn(name="user_id", type="number"),
        SourceColumn(name="account_name", type="string"),
        SourceColumn(name="requester_email", type="string"),
    ]
    source = SourceDefinition(
        name="contract_requesters",
        sql=query,
        grain=["account_id", "user_id"],
        columns=declared_columns,
    )
    report = SemanticEngine.from_sources({"contract_requesters": source}).validate()

    # Only grain-related messages matter here; other validators may still
    # emit unrelated errors for this fixture.
    grain_related = [
        err for err in report.errors if "grain" in err or "SELECT projection" in err
    ]
    assert not grain_related
def test_sql_source_with_select_star_skips_projection_check(self):
    """`SELECT *` hides the projected columns, so the projection check must not fire."""
    source = SourceDefinition(
        name="opaque",
        sql="select * from public.events",
        grain=["event_id"],
        columns=[SourceColumn(name="event_id", type="number")],
    )
    report = SemanticEngine.from_sources({"opaque": source}).validate()

    # The check has to be skipped, not turned into a false failure.
    projection_errors = [err for err in report.errors if "SELECT projection" in err]
    assert not projection_errors
class TestJoinValidation:
def test_join_local_column_must_exist(self):
@ -246,7 +356,9 @@ class TestJoinValidation:
report = engine.validate(recently_touched={"large_contract_requesters"})
assert not report.valid
assert any("mart_account_segments" in e and "joins[]" in e for e in report.errors)
assert any(
"mart_account_segments" in e and "joins[]" in e for e in report.errors
)
class TestDisconnectedComponents: