Harden semantic layer source validation

This commit is contained in:
Luca Martial 2026-05-11 20:39:07 -07:00
parent 66109caa1d
commit 6536c5da26
14 changed files with 235 additions and 17 deletions

View file

@ -12,6 +12,7 @@ from semantic_layer.models import (
)
from semantic_layer.planner import QueryPlanner
from semantic_layer.sql_table_extractor import (
extract_projected_columns,
extract_table_refs,
ref_matches_source_table,
)
@ -83,15 +84,48 @@ class SemanticEngine:
report.errors.extend(self._collect_orphan_join_target_errors())
def _check_invalid_grain(self, report: ValidationReport) -> None:
dialect = getattr(self.generator, "dialect", "postgres")
for source in self.sources.values():
qualified_grain: set[str] = set()
for grain_col in source.grain:
if "." in grain_col:
qualified_grain.add(grain_col)
report.errors.append(
f"Source '{source.name}' grain entry '{grain_col}' is a "
f"qualified name. Grain must use unqualified output column "
f"names (e.g. 'account_id', not 'activity.account_id')."
)
for col in source.columns:
if "." in col.name:
report.errors.append(
f"Source '{source.name}' column name '{col.name}' contains "
f"'.'. Column names must be unqualified."
)
column_names = {c.name for c in source.columns}
for grain_col in source.grain:
if grain_col in qualified_grain:
continue
if grain_col not in column_names:
report.errors.append(
f"Source '{source.name}' has grain column '{grain_col}' "
f"that is not in its columns list"
)
if source.is_sql_source and source.sql:
projected = extract_projected_columns(source.sql, dialect=dialect)
if projected is not None:
for grain_col in source.grain:
if grain_col in qualified_grain:
continue
if grain_col not in projected:
report.errors.append(
f"Source '{source.name}' grain column '{grain_col}' "
f"is not in the SQL SELECT projection. Add it to the "
f"SELECT list (or remove it from grain)."
)
def _check_join_columns(self, report: ValidationReport) -> None:
for source in self.sources.values():
source_columns = {c.name for c in source.columns}
@ -108,7 +142,9 @@ class SemanticEngine:
)
continue
local_cols = [col.strip() for col in local_raw.split(",") if col.strip()]
local_cols = [
col.strip() for col in local_raw.split(",") if col.strip()
]
target_cols = [
col.strip() for col in target_raw.split(",") if col.strip()
]