Initial open-source release

This commit is contained in:
Andrey Avtomonov 2026-05-10 23:12:26 +02:00
commit 1a42152e6f
1199 changed files with 257054 additions and 0 deletions

View file

View file

@ -0,0 +1,90 @@
from __future__ import annotations
import tempfile
from pathlib import Path
import pytest
import sqlglot
import yaml
from semantic_layer.engine import SemanticEngine
from semantic_layer.loader import SourceLoader
from semantic_layer.models import SourceDefinition
# Directories handed to ``SourceLoader`` by the fixtures below. Both live under a
# "sources" folder located one directory above the directory containing this file.
SOURCES_DIR = Path(__file__).parent.parent / "sources" / "ecommerce"
TPCH_DIR = Path(__file__).parent.parent / "sources" / "tpch"
@pytest.fixture
def ecommerce_sources() -> dict[str, SourceDefinition]:
    """Return whatever ``SourceLoader(SOURCES_DIR).load_all()`` produces."""
    return SourceLoader(SOURCES_DIR).load_all()
@pytest.fixture
def tpch_sources() -> dict[str, SourceDefinition]:
    """Return whatever ``SourceLoader(TPCH_DIR).load_all()`` produces."""
    return SourceLoader(TPCH_DIR).load_all()
# ── Shared test helpers ──────────────────────────────────────────────
def make_engine(
    sources_dict: dict[str, dict], dialect: str = "postgres"
) -> SemanticEngine:
    """Create a ``SemanticEngine`` backed by YAML files written from *sources_dict*.

    Each ``name -> data`` entry is dumped to ``<name>.yaml`` inside a fresh
    temporary directory, and that directory is handed to the engine together
    with *dialect*. The directory is not removed by this helper.
    """
    workdir = tempfile.mkdtemp()
    root = Path(workdir)
    for source_name, payload in sources_dict.items():
        target = root / f"{source_name}.yaml"
        with target.open("w") as handle:
            yaml.dump(payload, handle)
    return SemanticEngine(workdir, dialect=dialect)
def assert_valid_sql(sql: str, dialect: str | None = None) -> None:
    """Fail the current test unless *sql* parses into at least one statement.

    Args:
        sql: The generated SQL text to check.
        dialect: Optional sqlglot dialect name used for parsing. ``None``
            (the default) keeps the previous behaviour of parsing with
            sqlglot's default dialect.

    The test is failed via ``pytest.fail`` when sqlglot raises while parsing,
    or when parsing succeeds but yields no statement at all.
    """
    try:
        # ``read`` is sqlglot's parameter name for the input dialect.
        statements = sqlglot.parse(sql, read=dialect)
    except Exception as e:
        pytest.fail(f"Generated SQL is not valid: {e}\n\nSQL:\n{sql}")
    # sqlglot.parse() returns a list of Optional[Expression]; empty input does
    # not raise, so a blank string would otherwise be accepted as valid SQL.
    if not any(statement is not None for statement in statements):
        pytest.fail(f"Generated SQL contains no statement.\n\nSQL:\n{sql!r}")
@pytest.fixture
def make_bq_fct_orders_engine() -> SemanticEngine:
    """Provide a BigQuery-dialect engine over one inline ``fct_orders`` source.

    The source declares two segments (``non_cancelled`` and ``last_30_days``)
    and a single measure, ``daily_active_orders``, that references both.
    """
    columns = [
        {"name": "order_id", "type": "number"},
        {"name": "status", "type": "string"},
        {"name": "transaction_date", "type": "time"},
    ]
    segments = [
        {"name": "non_cancelled", "expr": "status != 'cancelled'"},
        {
            "name": "last_30_days",
            "expr": "transaction_date >= timestamp(date_sub(current_date(), interval 30 day))",
        },
    ]
    measures = [
        {
            "name": "daily_active_orders",
            "expr": "count(distinct order_id)",
            "segments": ["non_cancelled", "last_30_days"],
        },
    ]
    fct_orders = {
        "name": "fct_orders",
        "table": "analytics.fct_orders",
        "grain": ["order_id"],
        "columns": columns,
        "segments": segments,
        "measures": measures,
    }
    return make_engine({"fct_orders": fct_orders}, dialect="bigquery")
@pytest.fixture
def make_engine_factory():
    """Expose ``make_engine`` to tests as a fixture-provided callable.

    The returned callable takes a sources dict and an optional dialect
    (default ``"postgres"``) and returns the resulting ``SemanticEngine``.
    """

    def build(
        sources_dict: dict[str, dict], dialect: str = "postgres"
    ) -> SemanticEngine:
        engine = make_engine(sources_dict, dialect=dialect)
        return engine

    return build

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,447 @@
"""Tests for the CLI interface (semantic_layer.cli)."""
from __future__ import annotations
import json
from io import StringIO
from pathlib import Path
from unittest.mock import patch
import pytest
from semantic_layer.cli import main, print_plan
from semantic_layer.graph import JoinGraph
from semantic_layer.models import (
JoinDeclaration,
SemanticQuery,
SourceColumn,
SourceDefinition,
)
from semantic_layer.planner import QueryPlanner
# Source directory passed to the CLI after ``--sources``; converted to ``str``
# because it is placed directly into the argv lists given to ``main``.
SOURCES_DIR = str(Path(__file__).parent.parent / "sources" / "ecommerce")
# ── From test_edge_cases.py: TestCliParserArgs ───────────────────────
class TestCliParserArgs:
    """Drive ``main`` with explicit argv lists and inspect the captured stdout."""

    @staticmethod
    def _query(**overrides) -> str:
        """Return the default query (sum(orders.amount) by orders.status) as JSON.

        Keyword *overrides* replace or extend keys of the default payload.
        """
        payload = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
        }
        payload.update(overrides)
        return json.dumps(payload)

    @staticmethod
    def _stdout_of(capsys, *cli_args) -> str:
        """Run ``main`` on ``SOURCES_DIR`` with *cli_args*; return captured stdout."""
        main(["--sources", SOURCES_DIR, *cli_args])
        return capsys.readouterr().out

    @staticmethod
    def _stdout_from_stdin(capsys, query_json: str) -> str:
        """Run ``main --json`` with *query_json* served on a patched ``sys.stdin``."""
        with patch("sys.stdin", StringIO(query_json)):
            main(["--sources", SOURCES_DIR, "--json"])
        return capsys.readouterr().out

    def test_no_args_errors(self):
        with pytest.raises(SystemExit):
            main([])

    def test_sources_only_no_query(self, capsys):
        with pytest.raises(SystemExit):
            main(["--sources", SOURCES_DIR])

    def test_list_sources_no_measures_needed(self, capsys):
        out = self._stdout_of(capsys, "--list-sources")
        assert "orders" in out
        assert "customers" in out

    def test_plan_only_mode(self, capsys):
        out = self._stdout_of(capsys, "-q", self._query(), "--plan-only")
        assert "Resolved Plan" in out
        assert "Anchor" in out

    def test_plan_and_sql(self, capsys):
        out = self._stdout_of(capsys, "-q", self._query(), "--plan")
        assert "Resolved Plan" in out
        assert "SELECT" in out

    def test_compact_mode(self, capsys):
        out = self._stdout_of(capsys, "-q", self._query(), "--compact")
        assert "SELECT" in out
        assert "-- dialect:" not in out

    def test_json_input(self, capsys):
        out = self._stdout_from_stdin(capsys, self._query())
        assert "SELECT" in out

    def test_json_input_with_filters(self, capsys):
        body = self._query(filters=["orders.status = 'completed'"])
        out = self._stdout_from_stdin(capsys, body)
        assert "completed" in out

    def test_json_input_with_order_by(self, capsys):
        body = self._query(
            order_by=[{"field": "orders.status", "direction": "desc"}]
        )
        out = self._stdout_from_stdin(capsys, body)
        assert "SELECT" in out

    def test_measures_with_alias(self, capsys):
        body = self._query(
            measures=[{"expr": "sum(orders.amount)", "name": "total_rev"}]
        )
        out = self._stdout_of(capsys, "-q", body)
        assert "total_rev" in out

    def test_dimension_with_granularity_cli(self, capsys):
        body = self._query(
            dimensions=[{"field": "orders.created_at", "granularity": "month"}]
        )
        out = self._stdout_of(capsys, "-q", body)
        assert "DATE_TRUNC" in out

    def test_multiple_filters_cli(self, capsys):
        body = self._query(
            filters=["orders.status = 'completed'", "orders.amount > 100"]
        )
        out = self._stdout_of(capsys, "-q", body)
        assert "WHERE" in out

    def test_limit_cli(self, capsys):
        out = self._stdout_of(capsys, "-q", self._query(limit=50))
        assert "LIMIT 50" in out

    def test_dialect_cli(self, capsys):
        out = self._stdout_of(
            capsys, "-q", self._query(), "--dialect", "bigquery"
        )
        assert "bigquery" in out
# ── From test_edge_cases.py: TestCLISuggest ──────────────────────────
class TestCliSuggest:
    """``--suggest`` output for a plain query and for one using ``nonexistent.amount``."""

    def test_suggest_valid_query(self, capsys):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
        }
        argv = ["--sources", SOURCES_DIR, "-q", json.dumps(query), "--suggest"]
        main(argv)
        text = capsys.readouterr().out
        assert "valid" in text.lower()

    def test_suggest_invalid_query(self, capsys):
        query = {
            "measures": ["sum(nonexistent.amount)"],
            "dimensions": ["orders.status"],
        }
        argv = ["--sources", SOURCES_DIR, "-q", json.dumps(query), "--suggest"]
        main(argv)
        text = capsys.readouterr().out
        # Either wording is accepted as evidence that the failure was reported.
        assert "Suggestion" in text or "failed" in text.lower()
# ── From test_edge_cases.py: TestCLIOrderBy ──────────────────────────
class TestCliOrderBy:
    """``order_by`` entries in the ``-q`` JSON payload show up in the printed SQL."""

    def test_order_by_desc(self, capsys):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "order_by": [{"field": "sum(orders.amount)", "direction": "desc"}],
        }
        main(["--sources", SOURCES_DIR, "-q", json.dumps(query)])
        text = capsys.readouterr().out
        assert "DESC" in text

    def test_order_by_asc(self, capsys):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "order_by": [{"field": "orders.status", "direction": "asc"}],
        }
        main(["--sources", SOURCES_DIR, "-q", json.dumps(query)])
        text = capsys.readouterr().out
        assert "ORDER BY" in text
# ── From test_brainstorm_cases.py: TestBrainstormCliOutput ───────────
def _build_chasm_sources() -> dict[str, SourceDefinition]:
    """Build three in-memory sources: ``orders`` and ``tickets`` both join to ``customers``.

    Both joins are declared ``many_to_one`` on ``customer_id = customers.id``.
    """

    def number(column_name: str) -> SourceColumn:
        return SourceColumn(name=column_name, type="number")

    def join_to_customers() -> JoinDeclaration:
        return JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )

    definitions = [
        SourceDefinition(
            name="customers",
            table="public.customers",
            grain=["id"],
            columns=[number("id"), SourceColumn(name="segment", type="string")],
        ),
        SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[number("id"), number("customer_id"), number("amount")],
            joins=[join_to_customers()],
        ),
        SourceDefinition(
            name="tickets",
            table="public.tickets",
            grain=["id"],
            columns=[number("id"), number("customer_id"), number("cost")],
            joins=[join_to_customers()],
        ),
    ]
    customers, orders, tickets = definitions
    return {"customers": customers, "orders": orders, "tickets": tickets}
def _write_sources(sources_dict: dict[str, dict]) -> str:
    """Dump each ``name -> data`` entry to ``<name>.yaml`` in a new temp directory.

    Returns the directory path as a string. The directory is not cleaned up here.
    """
    import tempfile

    import yaml

    workdir = tempfile.mkdtemp()
    root = Path(workdir)
    for source_name, payload in sources_dict.items():
        target = root / f"{source_name}.yaml"
        with target.open("w") as handle:
            yaml.dump(payload, handle)
    return workdir
class TestCliPlanOutput:
    """Checks on the text printed by ``print_plan`` and by ``main``."""

    def test_print_plan_includes_join_locality_where_and_having(self, capsys):
        sources = _build_chasm_sources()
        graph = JoinGraph(sources)
        graph.build()
        query = SemanticQuery(
            measures=["sum(orders.amount)", "count(tickets.id)"],
            dimensions=["customers.segment"],
            filters=["customers.segment = 'SMB'", "sum(orders.amount) > 10000"],
        )
        print_plan(QueryPlanner(sources, graph).plan(query))
        text = capsys.readouterr().out
        expected_fragments = (
            "Resolved Plan",
            "Joins:",
            "Locality:",
            "WHERE:",
            "HAVING:",
            "customers.segment",
        )
        for fragment in expected_fragments:
            assert fragment in text

    def test_suggest_cli_surfaces_graph_errors(self, capsys):
        # Two sources with no join between them.
        source_a = {
            "name": "a",
            "table": "t",
            "grain": ["id"],
            "columns": [{"name": "id", "type": "number"}],
        }
        source_b = {
            "name": "b",
            "table": "t2",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "val", "type": "number"},
            ],
        }
        sources_path = _write_sources({"a": source_a, "b": source_b})
        query_json = json.dumps({"measures": ["sum(a.id)"], "dimensions": ["b.val"]})
        main(["--sources", sources_path, "-q", query_json, "--suggest"])
        text = capsys.readouterr().out
        expected_fragments = (
            "Query failed:",
            "Graph error:",
            "Disconnected components",
            "Suggestion:",
        )
        for fragment in expected_fragments:
            assert fragment in text

    def test_list_sources_includes_join_and_filtered_measure_details(self, capsys):
        main(["--sources", SOURCES_DIR, "--list-sources"])
        text = capsys.readouterr().out
        expected_fragments = (
            "joins:",
            "→ customers (many_to_one) on customer_id = customers.id",
            "revenue: sum(amount) (filter: status != 'refunded')",
        )
        for fragment in expected_fragments:
            assert fragment in text

View file

@ -0,0 +1,313 @@
"""Tests for computed column (expr) support on table sources."""
from __future__ import annotations
from semantic_layer.models import SourceColumn
from .conftest import assert_valid_sql, make_engine
def _lineitem_source(**overrides):
base = {
"name": "lineitem",
"table": "public.lineitem",
"grain": ["l_orderkey", "l_linenumber"],
"columns": [
{"name": "l_orderkey", "type": "number"},
{"name": "l_linenumber", "type": "number"},
{"name": "l_extendedprice", "type": "number"},
{"name": "l_discount", "type": "number"},
{"name": "l_quantity", "type": "number"},
{"name": "l_returnflag", "type": "string"},
{
"name": "net_price",
"type": "number",
"expr": "l_extendedprice * (1 - l_discount)",
},
],
}
base.update(overrides)
return base
class TestComputedColumnDimension:
    """Computed columns used as query dimensions."""

    def test_computed_column_in_select_and_group_by(self):
        engine = make_engine({"lineitem": _lineitem_source()})
        query = {
            "measures": ["sum(lineitem.l_quantity)"],
            "dimensions": ["lineitem.net_price"],
        }
        sql = engine.query(query).sql
        assert_valid_sql(sql)
        # The underlying physical columns and the alias must all be present.
        for fragment in ("l_extendedprice", "l_discount", "AS net_price"):
            assert fragment in sql

    def test_date_trunc_on_computed_column(self):
        event_columns = [
            {"name": "id", "type": "number"},
            {"name": "created_at", "type": "time", "role": "time"},
            {"name": "offset_hours", "type": "number"},
            {
                "name": "local_time",
                "type": "time",
                "role": "time",
                "expr": "created_at + offset_hours * INTERVAL '1 hour'",
            },
            {"name": "value", "type": "number"},
        ]
        events = {
            "name": "events",
            "table": "public.events",
            "grain": ["id"],
            "columns": event_columns,
        }
        engine = make_engine({"events": events})
        query = {
            "measures": ["sum(events.value)"],
            "dimensions": [{"field": "events.local_time", "granularity": "month"}],
        }
        sql = engine.query(query).sql
        assert_valid_sql(sql)
        for fragment in ("DATE_TRUNC", "created_at", "offset_hours"):
            assert fragment in sql
class TestComputedColumnInMeasure:
    """Computed columns referenced from runtime and predefined measures."""

    def test_runtime_aggregate_on_computed_column(self):
        engine = make_engine({"lineitem": _lineitem_source()})
        query = {"measures": ["sum(lineitem.net_price)"], "dimensions": []}
        sql = engine.query(query).sql
        assert_valid_sql(sql)
        assert "SUM" in sql.upper()
        assert "l_extendedprice" in sql
        assert "l_discount" in sql

    def test_predefined_measure_referencing_computed_column(self):
        total_net = {"name": "total_net", "expr": "sum(net_price)"}
        engine = make_engine({"lineitem": _lineitem_source(measures=[total_net])})
        query = {"measures": ["lineitem.total_net"], "dimensions": []}
        sql = engine.query(query).sql
        assert_valid_sql(sql)
        assert "l_extendedprice" in sql
        assert "l_discount" in sql
class TestComputedColumnInFilter:
    """Computed columns referenced from a query filter."""

    def test_computed_column_in_where_filter(self):
        engine = make_engine({"lineitem": _lineitem_source()})
        query = {
            "measures": ["sum(lineitem.l_extendedprice)"],
            "dimensions": [],
            "filters": ["lineitem.net_price > 100"],
        }
        sql = engine.query(query).sql
        assert_valid_sql(sql)
        for fragment in ("WHERE", "l_extendedprice", "l_discount"):
            assert fragment in sql
class TestComputedColumnWithJoins:
    """Computed columns combined with a join to another source."""

    def test_join_on_uses_physical_columns(self):
        orders = {
            "name": "orders",
            "table": "public.orders",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "customer_id", "type": "number"},
                {"name": "amount", "type": "number"},
                {"name": "discount", "type": "number"},
                {
                    "name": "net_amount",
                    "type": "number",
                    "expr": "amount * (1 - discount)",
                },
            ],
            "joins": [
                {
                    "to": "customers",
                    "on": "customer_id = customers.id",
                    "relationship": "many_to_one",
                }
            ],
        }
        customers = {
            "name": "customers",
            "table": "public.customers",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "segment", "type": "string"},
            ],
        }
        engine = make_engine({"orders": orders, "customers": customers})
        query = {
            "measures": ["sum(orders.net_amount)"],
            "dimensions": ["customers.segment"],
        }
        sql = engine.query(query).sql
        assert_valid_sql(sql)
        # The JOIN condition is expected to reference the physical key columns.
        assert "orders.customer_id" in sql
        assert "customers.id" in sql
        # The computed measure is expected to be expanded into its operands.
        assert "orders.amount" in sql
        assert "orders.discount" in sql
class TestComputedColumnLocality:
    """Computed columns when two fact sources are aggregated against a shared hub."""

    def test_computed_column_in_aggregate_locality(self):
        def join_to_hub() -> dict:
            # A new dict per call so the two fact sources never share objects.
            return {
                "to": "hub",
                "on": "hub_id = hub.id",
                "relationship": "many_to_one",
            }

        hub = {
            "name": "hub",
            "table": "public.hub",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "segment", "type": "string"},
            ],
        }
        fact_a = {
            "name": "fact_a",
            "table": "public.fact_a",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "hub_id", "type": "number"},
                {"name": "price", "type": "number"},
                {"name": "qty", "type": "number"},
                {"name": "total", "type": "number", "expr": "price * qty"},
            ],
            "joins": [join_to_hub()],
        }
        fact_b = {
            "name": "fact_b",
            "table": "public.fact_b",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "hub_id", "type": "number"},
                {"name": "val", "type": "number"},
            ],
            "joins": [join_to_hub()],
        }
        engine = make_engine({"hub": hub, "fact_a": fact_a, "fact_b": fact_b})
        query = {
            "measures": ["sum(fact_a.total)", "sum(fact_b.val)"],
            "dimensions": ["hub.segment"],
        }
        sql = engine.query(query).sql
        assert_valid_sql(sql)
        for fragment in ("_agg", "fact_a.price", "fact_a.qty"):
            assert fragment in sql
class TestComputedColumnModel:
    """``SourceColumn.expr`` on the model itself and after loading through an engine."""

    def test_source_column_with_expr(self):
        expression = "price * (1 - discount)"
        column = SourceColumn(name="net_price", type="number", expr=expression)
        assert column.expr == "price * (1 - discount)"

    def test_source_column_without_expr(self):
        column = SourceColumn(name="price", type="number")
        assert column.expr is None

    def test_source_column_expr_in_yaml_roundtrip(self):
        table_t = {
            "name": "t",
            "table": "public.t",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "a", "type": "number"},
                {"name": "b", "type": "number"},
                {"name": "c", "type": "number", "expr": "a + b"},
            ],
        }
        engine = make_engine({"t": table_t})
        loaded_columns = engine.sources["t"].columns
        computed = next(col for col in loaded_columns if col.name == "c")
        assert computed.expr == "a + b"
def test_bigquery_computed_column_with_timestamp_add(make_engine_factory):
    """A computed column written with ``TIMESTAMP_ADD`` keeps that call in BigQuery SQL."""
    local_hour = {
        "name": "local_hour",
        "type": "time",
        "expr": "TIMESTAMP_ADD(event_at, INTERVAL tz_offset HOUR)",
    }
    events = {
        "name": "events",
        "table": "events",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "event_at", "type": "time"},
            {"name": "tz_offset", "type": "number"},
            local_hour,
        ],
        "measures": [{"name": "cnt", "expr": "count(*)"}],
    }
    engine = make_engine_factory({"events": events}, dialect="bigquery")
    query = {
        "measures": ["events.cnt"],
        "dimensions": ["events.local_hour"],
        "filters": [],
    }
    result = engine.query(query)
    assert "TIMESTAMP_ADD" in result.sql.upper()
    assert "HOUR" in result.sql.upper()

View file

@ -0,0 +1,288 @@
from __future__ import annotations
from conftest import assert_valid_sql, make_engine
def _duplicate_predefined_sources() -> dict[str, dict]:
return {
"customers": {
"name": "customers",
"table": "public.customers",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "segment", "type": "string"},
],
},
"orders": {
"name": "orders",
"table": "public.orders",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "customer_id", "type": "number"},
{"name": "amount", "type": "number"},
],
"joins": [
{
"to": "customers",
"on": "customer_id = customers.id",
"relationship": "many_to_one",
}
],
"measures": [{"name": "revenue", "expr": "sum(amount)"}],
},
"refunds": {
"name": "refunds",
"table": "public.refunds",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "customer_id", "type": "number"},
{"name": "amount", "type": "number"},
],
"joins": [
{
"to": "customers",
"on": "customer_id = customers.id",
"relationship": "many_to_one",
}
],
"measures": [{"name": "revenue", "expr": "sum(amount)"}],
},
}
def _include_empty_sources() -> dict[str, dict]:
return {
"customers": {
"name": "customers",
"table": "public.customers",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "segment", "type": "string"},
],
},
"orders": {
"name": "orders",
"table": "public.orders",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "customer_id", "type": "number"},
{"name": "amount", "type": "number"},
],
"joins": [
{
"to": "customers",
"on": "customer_id = customers.id",
"relationship": "many_to_one",
}
],
},
}
def _alias_measure_sources() -> dict[str, dict]:
return {
"customers": {
"name": "customers",
"table": "public.customers",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "lifetime_value", "type": "number"},
],
"measures": [{"name": "total_ltv", "expr": "sum(lifetime_value)"}],
},
"orders": {
"name": "orders",
"table": "public.orders",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "billing_customer_id", "type": "number"},
{"name": "status", "type": "string"},
],
"joins": [
{
"to": "customers",
"on": "billing_customer_id = customers.id",
"relationship": "many_to_one",
"alias": "billing_customer",
}
],
},
}
def test_duplicate_predefined_names_stay_distinct_in_derived_measure():
    """Same-named measures from two sources stay source-qualified in a derived measure."""
    engine = make_engine(_duplicate_predefined_sources())
    net_measure = {"expr": "orders.revenue - refunds.revenue", "name": "net"}
    query = {
        "measures": ["orders.revenue", "refunds.revenue", net_measure],
        "dimensions": ["customers.segment"],
    }
    result = engine.query(query)
    sql = result.sql
    assert result.resolved_plan.has_fan_out
    assert "orders_agg.orders_revenue" in sql
    assert "refunds_agg.refunds_revenue" in sql
    # An unqualified "revenue - revenue" would mean the two measures collapsed.
    assert "revenue - revenue" not in sql
    assert_valid_sql(sql)
def test_duplicate_predefined_names_expand_having_filters_in_locality_mode():
    """A filter on ``orders.revenue`` is rendered against the ``orders_agg`` column."""
    engine = make_engine(_duplicate_predefined_sources())
    query = {
        "measures": ["orders.revenue", "refunds.revenue"],
        "dimensions": ["customers.segment"],
        "filters": ["orders.revenue > 100"],
    }
    sql = engine.query(query).sql
    # In multi-CTE mode the measure reference is expected inside COALESCE(..., 0).
    expected_predicate = "WHERE COALESCE(orders_agg.orders_revenue, 0) > 100"
    assert expected_predicate in sql
    assert_valid_sql(sql)
def test_include_empty_anchors_the_dimension_side():
    """With ``include_empty`` the plan anchors on ``customers`` and LEFT JOINs ``orders``."""
    engine = make_engine(_include_empty_sources())
    query = {
        "measures": ["sum(orders.amount)"],
        "dimensions": ["customers.segment"],
        "include_empty": True,
    }
    result = engine.query(query)
    sql = result.sql
    assert result.resolved_plan.anchor_source == "customers"
    assert "FROM public.customers AS customers" in sql
    assert "LEFT JOIN public.orders AS orders" in sql
    assert_valid_sql(sql)
def test_cross_grain_measures_on_same_chain_use_aggregate_locality():
    """Measures on both ``orders`` and ``customers`` each get their own ``*_agg`` relation."""
    customers = {
        "name": "customers",
        "table": "public.customers",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "segment", "type": "string"},
            {"name": "credit_limit", "type": "number"},
        ],
    }
    orders = {
        "name": "orders",
        "table": "public.orders",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "customer_id", "type": "number"},
            {"name": "amount", "type": "number"},
        ],
        "joins": [
            {
                "to": "customers",
                "on": "customer_id = customers.id",
                "relationship": "many_to_one",
            }
        ],
    }
    engine = make_engine({"customers": customers, "orders": orders})
    query = {
        "measures": ["sum(orders.amount)", "sum(customers.credit_limit)"],
        "dimensions": ["customers.segment"],
    }
    result = engine.query(query)
    sql = result.sql
    assert result.resolved_plan.has_fan_out
    assert "orders_agg" in sql
    assert "customers_agg" in sql
    assert_valid_sql(sql)
def test_filtered_count_distinct_keeps_distinct_inside_count():
    """A filtered ``count_distinct`` measure renders as ``COUNT(DISTINCT CASE WHEN ...``."""
    paid_customers = {
        "name": "paid_customers",
        "expr": "count_distinct(customer_id)",
        "filter": "status = 'paid'",
    }
    orders = {
        "name": "orders",
        "table": "public.orders",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "customer_id", "type": "number"},
            {"name": "status", "type": "string"},
        ],
        "measures": [paid_customers],
    }
    engine = make_engine({"orders": orders})
    query = {"measures": ["orders.paid_customers"], "dimensions": ["orders.status"]}
    sql = engine.query(query).sql
    assert "COUNT(DISTINCT CASE WHEN orders.status = 'paid'" in sql
    assert_valid_sql(sql)
def test_predefined_measure_via_alias_uses_real_table_and_alias_qualification():
    """A measure reached via a join alias selects from the real table under that alias."""
    engine = make_engine(_alias_measure_sources())
    query = {
        "measures": ["billing_customer.total_ltv"],
        "dimensions": ["billing_customer.id"],
    }
    sql = engine.query(query).sql
    assert "FROM public.customers AS billing_customer" in sql
    assert "SUM(billing_customer.lifetime_value)" in sql
    assert_valid_sql(sql)
def test_runtime_case_measure_gets_a_safe_auto_alias():
    """A runtime ``sum(CASE ...)`` measure gets an identifier-style alias without ``=``."""
    orders = {
        "name": "orders",
        "table": "public.orders",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "amount", "type": "number"},
            {"name": "status", "type": "string"},
        ],
    }
    engine = make_engine({"orders": orders})
    case_measure = (
        "sum(CASE WHEN orders.status = 'paid' THEN orders.amount ELSE 0 END)"
    )
    result = engine.query(
        {"measures": [case_measure], "dimensions": ["orders.status"]}
    )
    expected_alias = "sum_case_when_orders_status_paid_then_orders_amount_else_0_end"
    assert expected_alias in result.sql
    assert "=" not in result.resolved_plan.measures[0].name
    assert_valid_sql(result.sql)

View file

@ -0,0 +1,740 @@
"""Tests targeting specific coverage gaps in planner.py, generator.py, models.py, engine.py."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from semantic_layer.generator import SqlGenerator
from semantic_layer.graph import JoinGraph
from semantic_layer.models import (
JoinDeclaration,
MeasureDefinition,
SemanticQuery,
SourceColumn,
SourceDefinition,
)
from semantic_layer.planner import QueryPlanner
from conftest import assert_valid_sql, make_engine
# ── Helpers ──────────────────────────────────────────────────────────
def _make_planner(sources: dict[str, SourceDefinition]) -> QueryPlanner:
    """Build a ``JoinGraph`` over *sources* and return a ``QueryPlanner`` that uses it."""
    join_graph = JoinGraph(sources)
    join_graph.build()
    planner = QueryPlanner(sources, join_graph)
    return planner
def _plan_and_generate(sources: dict[str, SourceDefinition], query_dict: dict) -> str:
    """Plan *query_dict* over *sources*, generate Postgres SQL, validate it and return it."""
    planner = _make_planner(sources)
    sql_generator = SqlGenerator(dialect="postgres")
    resolved = planner.plan(SemanticQuery(**query_dict))
    rendered = sql_generator.generate(resolved, sources)
    assert_valid_sql(rendered)
    return rendered
# ── Source fixtures ──────────────────────────────────────────────────
def _simple_sources() -> dict[str, SourceDefinition]:
    """Return ``orders`` joined many_to_one to ``customers``.

    ``orders`` carries two predefined measures: ``revenue`` (filtered on
    ``status != 'refunded'``) and ``order_count``.
    """

    def number(column_name: str) -> SourceColumn:
        return SourceColumn(name=column_name, type="number")

    def string(column_name: str) -> SourceColumn:
        return SourceColumn(name=column_name, type="string")

    sources: dict[str, SourceDefinition] = {}
    sources["customers"] = SourceDefinition(
        name="customers",
        table="public.customers",
        grain=["id"],
        columns=[number("id"), string("segment")],
    )
    sources["orders"] = SourceDefinition(
        name="orders",
        table="public.orders",
        grain=["id"],
        columns=[
            number("id"),
            number("customer_id"),
            number("amount"),
            string("status"),
        ],
        joins=[
            JoinDeclaration(
                to="customers",
                on="customer_id = customers.id",
                relationship="many_to_one",
            )
        ],
        measures=[
            MeasureDefinition(
                name="revenue", expr="sum(amount)", filter="status != 'refunded'"
            ),
            MeasureDefinition(name="order_count", expr="count(id)"),
        ],
    )
    return sources
def _chasm_sources() -> dict[str, SourceDefinition]:
    """Return two fact sources (``orders``, ``tickets``) that both join to ``customers``.

    Each fact declares one measure: ``orders.revenue`` and ``tickets.ticket_count``.
    """

    def number(column_name: str) -> SourceColumn:
        return SourceColumn(name=column_name, type="number")

    def string(column_name: str) -> SourceColumn:
        return SourceColumn(name=column_name, type="string")

    def join_to_customers() -> JoinDeclaration:
        return JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )

    hub = SourceDefinition(
        name="customers",
        table="public.customers",
        grain=["id"],
        columns=[number("id"), string("segment")],
    )
    order_fact = SourceDefinition(
        name="orders",
        table="public.orders",
        grain=["id"],
        columns=[number("id"), number("customer_id"), number("amount")],
        joins=[join_to_customers()],
        measures=[MeasureDefinition(name="revenue", expr="sum(amount)")],
    )
    ticket_fact = SourceDefinition(
        name="tickets",
        table="public.tickets",
        grain=["id"],
        columns=[number("id"), number("customer_id"), string("priority")],
        joins=[join_to_customers()],
        measures=[MeasureDefinition(name="ticket_count", expr="count(id)")],
    )
    return {"customers": hub, "orders": order_fact, "tickets": ticket_fact}
def _chain_sources_with_derived() -> dict[str, SourceDefinition]:
    """Return the chain ``orders -> customers -> tiers`` (each hop many_to_one).

    ``orders`` defines ``revenue`` and ``order_count`` plus ``avg_order``, whose
    expression is ``revenue / order_count``.
    """

    def number(column_name: str) -> SourceColumn:
        return SourceColumn(name=column_name, type="number")

    def string(column_name: str) -> SourceColumn:
        return SourceColumn(name=column_name, type="string")

    sources: dict[str, SourceDefinition] = {}
    sources["tiers"] = SourceDefinition(
        name="tiers",
        table="public.tiers",
        grain=["id"],
        columns=[number("id"), string("level")],
    )
    sources["customers"] = SourceDefinition(
        name="customers",
        table="public.customers",
        grain=["id"],
        columns=[number("id"), number("tier_id"), string("segment")],
        joins=[
            JoinDeclaration(
                to="tiers", on="tier_id = tiers.id", relationship="many_to_one"
            )
        ],
    )
    sources["orders"] = SourceDefinition(
        name="orders",
        table="public.orders",
        grain=["id"],
        columns=[
            number("id"),
            number("customer_id"),
            number("amount"),
            string("status"),
        ],
        joins=[
            JoinDeclaration(
                to="customers",
                on="customer_id = customers.id",
                relationship="many_to_one",
            )
        ],
        measures=[
            MeasureDefinition(
                name="revenue", expr="sum(amount)", filter="status != 'refunded'"
            ),
            MeasureDefinition(name="order_count", expr="count(id)"),
            MeasureDefinition(name="avg_order", expr="revenue / order_count"),
        ],
    )
    return sources
# ── Planner: nested aggregation (lines 432-440) ─────────────────────
class TestNestedAggregation:
    """Measures that wrap one aggregate in another must be rejected by the planner."""

    def test_nested_aggregation_raises(self):
        """``avg(sum(orders.amount))`` raises ``ValueError``."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(
            measures=["avg(sum(orders.amount))"],
            dimensions=["orders.status"],
        )
        with pytest.raises(ValueError, match="Nested aggregation is not supported"):
            planner.plan(query)

    def test_nested_max_count_raises(self):
        """``max(count(orders.id))`` raises ``ValueError``."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(
            measures=["max(count(orders.id))"],
            dimensions=["orders.status"],
        )
        with pytest.raises(ValueError, match="Nested aggregation is not supported"):
            planner.plan(query)
# ── Planner: OR filter mixing (lines 810-833) ───────────────────────
class TestOrFilterMixing:
    """``OR`` filters: mixed aggregate/non-aggregate operands fail, uniform ones pass."""

    def test_or_mixing_agg_and_nonagg_raises(self):
        """One aggregate and one row-level operand under ``OR`` raises ``ValueError``."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(
            measures=["sum(orders.amount)"],
            dimensions=["orders.status"],
            filters=["orders.amount > 100 OR sum(orders.amount) > 5000"],
        )
        with pytest.raises(ValueError, match="mixes aggregate and non-aggregate"):
            planner.plan(query)

    def test_or_pure_where_ok(self):
        """Two non-aggregate operands under ``OR`` produce SQL that still contains OR."""
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": ["orders.amount > 100 OR orders.amount < 10"],
        }
        rendered = _plan_and_generate(_simple_sources(), query)
        assert "OR" in rendered.upper()

    def test_or_pure_having_ok(self):
        """Two aggregate operands under ``OR`` produce SQL that contains HAVING."""
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": ["sum(orders.amount) > 1000 OR count(orders.id) > 5"],
        }
        rendered = _plan_and_generate(_simple_sources(), query)
        assert "HAVING" in rendered.upper()
# ── Planner: empty source refs (line 62) ─────────────────────────────
class TestEmptySourceRef:
    """Queries whose expressions name no source at all."""

    def test_no_source_refs_raises(self):
        """``sum(1)`` with no dimensions raises ``ValueError``."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(measures=["sum(1)"], dimensions=[])
        with pytest.raises(ValueError, match="does not reference any source"):
            planner.plan(query)
# ── Planner: predefined measure dependency chains (lines 189-194, 237, 281-282) ──
class TestPredefinedMeasureDeps:
    """Predefined measures whose expression refers to other predefined measures."""

    def test_derived_measure_resolves_dependencies(self):
        """Planning ``orders.avg_order`` also puts ``revenue`` and ``order_count`` in the plan."""
        planner = _make_planner(_chain_sources_with_derived())
        query = SemanticQuery(
            measures=["orders.avg_order"],
            dimensions=["orders.status"],
        )
        planned_names = {measure.name for measure in planner.plan(query).measures}
        for expected in ("avg_order", "revenue", "order_count"):
            assert expected in planned_names

    def test_derived_measure_generates_valid_sql(self):
        """``orders.avg_order`` grouped by ``customers.segment`` yields SQL with GROUP BY."""
        query = {
            "measures": ["orders.avg_order"],
            "dimensions": ["customers.segment"],
        }
        rendered = _plan_and_generate(_chain_sources_with_derived(), query)
        assert "GROUP BY" in rendered.upper()
# ── Planner: fan-out with one_to_many to dimension sources (lines 595-643) ──
class TestFanOutEdgeCases:
    """Planner behaviour when a one_to_many edge sits between the sources of a query."""

    @staticmethod
    def _col(name, kind="number"):
        """Shorthand for a SourceColumn; the type defaults to ``number``."""
        return SourceColumn(name=name, type=kind)

    def test_single_source_fan_out_to_dimension(self):
        """A hub measure grouped by a one_to_many detail column is flagged as fan-out."""
        col = self._col
        catalog = {
            "hub": SourceDefinition(
                name="hub",
                table="public.hub",
                grain=["id"],
                columns=[col("id"), col("name", "string")],
                joins=[
                    JoinDeclaration(
                        to="detail",
                        on="id = detail.hub_id",
                        relationship="one_to_many",
                    )
                ],
            ),
            "detail": SourceDefinition(
                name="detail",
                table="public.detail",
                grain=["id"],
                columns=[col("id"), col("hub_id"), col("category", "string")],
            ),
        }
        planner = _make_planner(catalog)
        query = SemanticQuery(
            measures=["sum(hub.id)"],
            dimensions=["detail.category"],
        )
        assert planner.plan(query).has_fan_out

    def test_merged_groups_fan_out_to_dimension(self):
        """child -> parent is many_to_one, parent -> dim is one_to_many: plan is fan-out."""
        col = self._col
        catalog = {
            "dim": SourceDefinition(
                name="dim",
                table="public.dim",
                grain=["id"],
                columns=[col("id"), col("label", "string")],
            ),
            "parent": SourceDefinition(
                name="parent",
                table="public.parent",
                grain=["id"],
                columns=[col("id"), col("val")],
                joins=[
                    JoinDeclaration(
                        to="dim", on="id = dim.id", relationship="one_to_many"
                    )
                ],
            ),
            "child": SourceDefinition(
                name="child",
                table="public.child",
                grain=["id"],
                columns=[col("id"), col("parent_id"), col("amount")],
                joins=[
                    JoinDeclaration(
                        to="parent",
                        on="parent_id = parent.id",
                        relationship="many_to_one",
                    )
                ],
            ),
        }
        planner = _make_planner(catalog)
        query = SemanticQuery(
            measures=["sum(child.amount)"],
            dimensions=["dim.label"],
        )
        assert planner.plan(query).has_fan_out

    def test_filter_fan_out_one_to_many_raises(self):
        """Filtering on a source reachable only via one_to_many from the measure source raises."""
        col = self._col
        catalog = {
            "parent": SourceDefinition(
                name="parent",
                table="public.parent",
                grain=["id"],
                columns=[col("id"), col("val")],
                joins=[
                    JoinDeclaration(
                        to="child",
                        on="id = child.parent_id",
                        relationship="one_to_many",
                    )
                ],
            ),
            "child": SourceDefinition(
                name="child",
                table="public.child",
                grain=["id"],
                columns=[col("id"), col("parent_id"), col("category", "string")],
            ),
        }
        planner = _make_planner(catalog)
        with pytest.raises(ValueError, match="one_to_many join"):
            planner.plan(
                SemanticQuery(
                    measures=["sum(parent.val)"],
                    dimensions=[],
                    filters=["child.category = 'A'"],
                )
            )
# ── Generator: NULL dimension in multi-CTE (lines 385-388) ──────────
class TestNullDimensionInCTE:
    """Multi-CTE generation for two fact sources that share one hub."""

    def test_dimension_not_in_any_cte_gets_null(self):
        """Two facts joined many_to_one to a hub render CTE / FULL JOIN SQL.

        NOTE(review): the name refers to NULL-filled dimensions, but the only
        assertion is that the SQL contains FULL or WITH — confirm whether a
        NULL check was intended.
        """

        def fact(name, extra_columns=()):
            # Both facts share the same shape and the same join to the hub.
            return SourceDefinition(
                name=name,
                table=f"public.{name}",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="hub_id", type="number"),
                    SourceColumn(name="val", type="number"),
                    *extra_columns,
                ],
                joins=[
                    JoinDeclaration(
                        to="hub", on="hub_id = hub.id", relationship="many_to_one"
                    )
                ],
            )

        catalog = {
            "hub": SourceDefinition(
                name="hub",
                table="public.hub",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="name", type="string"),
                ],
            ),
            "fact_a": fact("fact_a", [SourceColumn(name="extra", type="string")]),
            "fact_b": fact("fact_b"),
        }
        rendered = _plan_and_generate(
            catalog,
            {
                "measures": ["sum(fact_a.val)", "sum(fact_b.val)"],
                "dimensions": ["hub.name"],
            },
        )
        # Aggregate-locality output is expected to use CTEs and/or a FULL JOIN.
        upper = rendered.upper()
        assert any(keyword in upper for keyword in ("FULL", "WITH"))
# ── Generator: CTE alias collision (lines 202-206) ──────────────────
class TestCTEAliasCollision:
    """SQL generation when a source is literally named ``orders_agg``."""

    def test_alias_collision_resolved(self):
        """A source named like a potential CTE alias must still yield parseable SQL."""

        def join_to_hub():
            # Fresh declaration per fact so the two sources share no object.
            return JoinDeclaration(
                to="orders_agg",
                on="customer_id = orders_agg.id",
                relationship="many_to_one",
            )

        number = "number"
        catalog = {
            # The hub deliberately carries the name a CTE alias for "orders" could take.
            "orders_agg": SourceDefinition(
                name="orders_agg",
                table="public.orders_agg",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type=number),
                    SourceColumn(name="segment", type="string"),
                ],
            ),
            "orders": SourceDefinition(
                name="orders",
                table="public.orders",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type=number),
                    SourceColumn(name="customer_id", type=number),
                    SourceColumn(name="amount", type=number),
                ],
                joins=[join_to_hub()],
            ),
            "tickets": SourceDefinition(
                name="tickets",
                table="public.tickets",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type=number),
                    SourceColumn(name="customer_id", type=number),
                    SourceColumn(name="priority", type="string"),
                ],
                joins=[join_to_hub()],
            ),
        }
        rendered = _plan_and_generate(
            catalog,
            {
                "measures": ["sum(orders.amount)", "count(tickets.id)"],
                "dimensions": ["orders_agg.segment"],
            },
        )
        # The collision must not corrupt the statement.
        assert_valid_sql(rendered)
# ── Models: negative limit (line 95) ────────────────────────────────
class TestNegativeLimit:
    """Validation of ``SemanticQuery.limit``."""

    def test_negative_limit_raises(self):
        """A limit of -1 is rejected with a ValidationError mentioning ``limit``."""
        with pytest.raises(ValidationError, match="limit"):
            SemanticQuery(measures=["sum(orders.amount)"], limit=-1)

    def test_zero_limit_allowed(self):
        """Zero is accepted and stored unchanged."""
        query = SemanticQuery(measures=["sum(orders.amount)"], limit=0)
        assert query.limit == 0
# ── Engine: suggest with missing sources (lines 100-106, 127) ────────
class TestEngineSuggest:
    """``SemanticEngine.suggest`` failure paths against a catalog with only ``orders``."""

    @staticmethod
    def _orders_only_engine():
        """Build an engine whose single source is ``orders`` (id, amount)."""
        orders_spec = {
            "name": "orders",
            "table": "public.orders",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "amount", "type": "number"},
            ],
        }
        return make_engine({"orders": orders_spec})

    def test_suggest_with_missing_source(self):
        """An unknown source fails and yields a suggestion that mentions it."""
        engine = self._orders_only_engine()
        outcome = engine.suggest(
            {
                "measures": ["sum(unknown_source.val)"],
                "dimensions": ["orders.id"],
            }
        )
        assert not outcome["success"]
        hints = outcome.get("suggestions", [])
        assert any(
            "missing" in hint["description"].lower()
            or "unknown_source" in hint["description"]
            for hint in hints
        )

    def test_suggest_with_dict_measure_and_dimension(self):
        """Dict-style measures and dimensions are handled on the failure path."""
        engine = self._orders_only_engine()
        # A nested aggregate over an unknown source is used to force a planner error.
        request = {
            "measures": [{"expr": "avg(sum(missing.val))", "name": "total"}],
            "dimensions": [{"field": "missing.category"}],
        }
        outcome = engine.suggest(request)
        assert not outcome["success"]
# ── Planner: order_by resolution formats (lines 113-116) ────────────
class TestOrderByResolution:
    """``order_by`` may be given as a dict or as a bare field string."""

    def test_order_by_as_dict(self):
        """The dict form yields one order_by entry with the requested direction."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(
            measures=["sum(orders.amount)"],
            dimensions=["orders.status"],
            order_by=[{"field": "orders.status", "direction": "desc"}],
        )
        ordering = planner.plan(query).order_by
        assert len(ordering) == 1
        assert ordering[0].direction == "desc"

    def test_order_by_as_string(self):
        """The bare-string form yields one order_by entry."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(
            measures=["sum(orders.amount)"],
            dimensions=["orders.status"],
            order_by=["orders.status"],
        )
        assert len(planner.plan(query).order_by) == 1
# ── Planner: measure with no source refs (line 343) ─────────────────
class TestMeasureNoSourceRef:
    """Planner must reject a measure whose expression references no source."""

    def test_bare_column_no_aggregate_raises(self):
        """A literal-only measure (``sum(1)``) must raise even with a valid dimension.

        The query sends ``sum(1)``, not a bare column reference as the test
        name suggests. Either planner message is accepted: the measure
        "does not reference any source", or it is "not a pre-defined measure".
        """
        sources = _simple_sources()
        planner = _make_planner(sources)
        with pytest.raises(
            ValueError, match="does not reference any source|not a pre-defined measure"
        ):
            planner.plan(
                SemanticQuery(
                    measures=["sum(1)"],
                    dimensions=["orders.status"],
                )
            )
# ── Generator: custom aggregate parsing (lines 614-617) ─────────────
class TestCustomAggregates:
    """Aggregates that need special parsing in the generator."""

    def test_count_distinct_generates_valid_sql(self):
        """``count(distinct ...)`` survives into the SQL, with or without a space."""
        request = {
            "measures": ["count(distinct orders.id)"],
            "dimensions": ["orders.status"],
        }
        rendered = _plan_and_generate(_simple_sources(), request).upper()
        accepted_spellings = ("COUNT(DISTINCT", "COUNT (DISTINCT")
        assert any(spelling in rendered for spelling in accepted_spellings)
# ── Generator: qualified predefined expressions via multi-hop joins (lines 925-931) ──
class TestQualifiedPredefinedExpr:
    """Predefined measure expressions qualified across joined sources."""

    def test_predefined_filter_with_joined_column(self):
        """orders.revenue grouped by tiers.level yields parseable SQL with a CASE WHEN."""
        request = {
            "measures": ["orders.revenue"],
            "dimensions": ["tiers.level"],
        }
        rendered = _plan_and_generate(_chain_sources_with_derived(), request)
        assert_valid_sql(rendered)
        assert "CASE WHEN" in rendered.upper()
# ── End-to-end: chasm trap with aggregate locality ───────────────────
class TestChasmTrapEndToEnd:
    """End-to-end generation for two fact sources sharing the customers hub."""

    _TWO_FACT_MEASURES = ("sum(orders.amount)", "count(tickets.id)")

    def test_two_fact_tables_produce_valid_sql(self):
        """Two facts grouped by a hub dimension render CTEs combined by a join."""
        rendered = _plan_and_generate(
            _chasm_sources(),
            {
                "measures": list(self._TWO_FACT_MEASURES),
                "dimensions": ["customers.segment"],
            },
        ).upper()
        assert "WITH" in rendered
        assert any(keyword in rendered for keyword in ("FULL", "JOIN"))

    def test_chasm_with_filter_on_hub(self):
        """A hub filter is carried into the SQL and the statement still parses."""
        rendered = _plan_and_generate(
            _chasm_sources(),
            {
                "measures": list(self._TWO_FACT_MEASURES),
                "dimensions": ["customers.segment"],
                "filters": ["customers.segment = 'enterprise'"],
            },
        )
        assert "enterprise" in rendered
        assert_valid_sql(rendered)

View file

@ -0,0 +1,220 @@
"""Tests for semantic_layer.duplicate_check.validate_measure_duplicates."""
from __future__ import annotations
from semantic_layer.duplicate_check import validate_measure_duplicates
from semantic_layer.models import (
MeasureDefinition,
SourceColumn,
SourceDefinition,
)
def _make_source(name: str, measures: list[MeasureDefinition]) -> SourceDefinition:
    """Return a minimal source with an ``id`` column that carries *measures*.

    The table is ``public.<name>`` and the grain is ``["id"]``.
    """
    id_column = SourceColumn(name="id", type="number")
    spec = {
        "name": name,
        "table": f"public.{name}",
        "grain": ["id"],
        "columns": [id_column],
        "measures": measures,
    }
    return SourceDefinition(**spec)
def test_same_expr_different_filter_is_flagged() -> None:
    """count(*) defined twice, once with an is_active filter, yields one filter-only error."""
    filtered = MeasureDefinition(
        name="active_subscription_count",
        expr="count(*)",
        filter="is_active = true",
    )
    unfiltered = MeasureDefinition(name="new_subscription_count", expr="count(*)")
    source = _make_source("fct_subscriptions", [filtered, unfiltered])

    errors = validate_measure_duplicates({"fct_subscriptions": source})

    assert len(errors) == 1
    message = errors[0]
    for fragment in (
        "new_subscription_count",
        "active_subscription_count",
        "differs only by `filter`",
    ):
        assert fragment in message
def test_same_expr_same_filter_is_flagged() -> None:
    """Identical expr and identical filter are reported as one duplicate pair."""
    twins = [
        MeasureDefinition(name=name, expr="count(*)", filter="is_paid = true")
        for name in ("order_count_a", "order_count_b")
    ]
    errors = validate_measure_duplicates(
        {"fct_orders": _make_source("fct_orders", twins)}
    )
    assert len(errors) == 1
    assert "same expression and filter" in errors[0]
def test_different_expr_is_not_flagged() -> None:
    """count, sum and avg on one source are distinct measures and produce no errors."""
    definitions = [
        ("order_count", "count(*)"),
        ("total_revenue", "sum(amount)"),
        ("avg_revenue", "avg(amount)"),
    ]
    measures = [MeasureDefinition(name=name, expr=expr) for name, expr in definitions]
    catalog = {"fct_orders": _make_source("fct_orders", measures)}
    assert validate_measure_duplicates(catalog) == []
def test_measures_on_different_sources_not_compared() -> None:
    """The same expression on two separate sources is not a duplicate."""
    catalog = {
        source_name: _make_source(
            source_name, [MeasureDefinition(name="total", expr="count(*)")]
        )
        for source_name in ("fct_a", "fct_b")
    }
    assert validate_measure_duplicates(catalog) == []
def test_whitespace_and_case_are_normalized() -> None:
    """Case and spacing variants of count(*) all compare equal."""
    variants = {"a": "count(*)", "b": "COUNT(*)", "c": " count( * ) "}
    measures = [
        MeasureDefinition(name=name, expr=expr) for name, expr in variants.items()
    ]
    errors = validate_measure_duplicates(
        {"fct_orders": _make_source("fct_orders", measures)}
    )
    # Three equivalent measures form three pairs: a/b, a/c and b/c.
    assert len(errors) == 3
def test_unparseable_expr_is_skipped_not_errored() -> None:
    """An unparseable expr is ignored by the duplicate check rather than reported."""
    broken = MeasureDefinition(name="bad", expr="!!! not SQL !!!")
    valid = MeasureDefinition(name="good", expr="count(*)")
    catalog = {"fct_orders": _make_source("fct_orders", [broken, valid])}
    # Must neither raise nor flag; per the original note, a separate parser
    # validator is expected to report the broken expression.
    assert validate_measure_duplicates(catalog) == []
def test_non_commutative_args_not_treated_as_equivalent() -> None:
    """Swapping the arguments of safe_divide produces a different measure."""
    numerator_first = MeasureDefinition(
        name="ratio_ab", expr="safe_divide(count(*), sum(amount))"
    )
    denominator_first = MeasureDefinition(
        name="ratio_ba", expr="safe_divide(sum(amount), count(*))"
    )
    catalog = {
        "fct_orders": _make_source("fct_orders", [numerator_first, denominator_first])
    }
    assert validate_measure_duplicates(catalog) == []
def test_single_measure_source_no_comparison() -> None:
    """A source with a single measure has nothing to compare and yields no errors."""
    only_measure = MeasureDefinition(name="total", expr="count(*)")
    catalog = {"fct_orders": _make_source("fct_orders", [only_measure])}
    assert validate_measure_duplicates(catalog) == []
def test_same_expr_different_segments_is_not_flagged() -> None:
    """The same expr under different named segments is treated as distinct."""
    by_segment = [
        MeasureDefinition(name=name, expr="count(*)", segments=[segment])
        for name, segment in (
            ("active_count", "active"),
            ("inactive_count", "inactive"),
        )
    ]
    catalog = {"fct_subscriptions": _make_source("fct_subscriptions", by_segment)}
    assert validate_measure_duplicates(catalog) == []
def test_same_expr_same_segments_is_flagged() -> None:
    """The same expr with the same segment list is a true duplicate."""
    clones = [
        MeasureDefinition(name=name, expr="count(*)", segments=["active"])
        for name in ("a_count", "b_count")
    ]
    errors = validate_measure_duplicates(
        {"fct_subscriptions": _make_source("fct_subscriptions", clones)}
    )
    assert len(errors) == 1
    assert "same expression and filter" in errors[0]
def test_segment_difference_with_filter_difference_not_flagged() -> None:
    """Differing segments keep measures distinct even when the filter differs too."""
    segmented_and_filtered = MeasureDefinition(
        name="m1",
        expr="count(*)",
        segments=["active"],
        filter="protocol = 'TRT'",
    )
    segmented_only = MeasureDefinition(
        name="m2", expr="count(*)", segments=["inactive"]
    )
    catalog = {
        "fct_subscriptions": _make_source(
            "fct_subscriptions", [segmented_and_filtered, segmented_only]
        )
    }
    assert validate_measure_duplicates(catalog) == []
def test_bigquery_native_exprs_compared_correctly() -> None:
    """Identical BigQuery-native exprs must be flagged when dialect is ``bigquery``.

    The source is built directly rather than through ``_make_source``
    because it also declares an ``amount`` column and an unqualified
    table name. All names used here come from the module-level imports.
    """
    source = SourceDefinition(
        name="fct_orders",
        table="fct_orders",
        grain=["id"],
        columns=[
            SourceColumn(name="id", type="number"),
            SourceColumn(name="amount", type="number"),
        ],
        measures=[
            MeasureDefinition(
                name="safe_ratio_a",
                expr="SAFE_DIVIDE(sum(amount), count(*))",
            ),
            MeasureDefinition(
                name="safe_ratio_b",
                expr="SAFE_DIVIDE(sum(amount), count(*))",
            ),
        ],
    )
    errors = validate_measure_duplicates({"fct_orders": source}, dialect="bigquery")
    # At least one error must name both measures.
    assert any("safe_ratio_a" in e and "safe_ratio_b" in e for e in errors), (
        f"Duplicate detection missed identical BigQuery-native exprs: {errors}"
    )

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,731 @@
import pytest
from semantic_layer.graph import JoinGraph
from semantic_layer.models import SourceDefinition, SourceColumn, JoinDeclaration
@pytest.fixture
def graph(ecommerce_sources):
    """Provide a built JoinGraph over the ecommerce fixture sources."""
    join_graph = JoinGraph(ecommerce_sources)
    join_graph.build()
    return join_graph
class TestJoinGraphBuild:
    """Adjacency structure produced by ``JoinGraph.build`` on the ecommerce sources."""

    @staticmethod
    def _edge(graph, origin, target):
        """Return the first edge from *origin* whose to_source is *target*."""
        return next(e for e in graph.adjacency[origin] if e.to_source == target)

    def test_all_sources_in_adjacency(self, graph, ecommerce_sources):
        """Every loaded source has an adjacency entry, and nothing else does."""
        adjacency_names = set(graph.adjacency.keys())
        loaded_names = set(ecommerce_sources.keys())
        assert adjacency_names == loaded_names

    def test_bidirectional_edges(self, graph):
        """The orders -> customers join is reachable from both ends."""
        from_orders = [e.to_source for e in graph.adjacency["orders"]]
        assert any(target == "customers" for target in from_orders)
        from_customers = [e.to_source for e in graph.adjacency["customers"]]
        assert any(target == "orders" for target in from_customers)

    def test_relationship_inversion(self, graph):
        """many_to_one in the forward direction is one_to_many in reverse."""
        forward = self._edge(graph, "orders", "customers")
        assert forward.relationship == "many_to_one"
        backward = self._edge(graph, "customers", "orders")
        assert backward.relationship == "one_to_many"

    def test_on_parsing(self, graph):
        """The forward edge exposes the parsed join columns."""
        forward = self._edge(graph, "orders", "customers")
        assert forward.from_column == "customer_id"
        assert forward.to_column == "id"
class TestFindPath:
    """``JoinGraph.find_path`` on the ecommerce graph."""

    def test_direct_join(self, graph):
        """orders -> customers is a single safe edge."""
        route = graph.find_path("orders", "customers")
        assert route is not None
        assert len(route.edges) == 1
        only_edge = route.edges[0]
        assert only_edge.from_source == "orders"
        assert only_edge.to_source == "customers"
        assert not route.has_one_to_many

    def test_two_hop_m2o(self, graph):
        """orders -> customers -> regions: two hops, no one_to_many."""
        route = graph.find_path("orders", "regions")
        assert route is not None
        assert len(route.edges) == 2
        assert route.source_names == ["orders", "customers", "regions"]
        assert not route.has_one_to_many

    def test_reverse_path_flagged(self, graph):
        """Walking regions -> orders crosses one_to_many edges and is flagged."""
        route = graph.find_path("regions", "orders")
        assert route is not None
        assert len(route.edges) == 2
        assert route.has_one_to_many

    def test_through_bridge(self, graph):
        """orders reaches products through the order_items bridge."""
        route = graph.find_path("orders", "products")
        assert route is not None
        assert "order_items" in route.source_names

    def test_churn_risk_to_regions(self, graph):
        """churn_risk reaches regions by way of customers."""
        route = graph.find_path("churn_risk", "regions")
        assert route is not None
        assert "customers" in route.source_names

    def test_same_source(self, graph):
        """A path from a source to itself exists and has no edges."""
        route = graph.find_path("orders", "orders")
        assert route is not None
        assert len(route.edges) == 0
        assert not route.has_one_to_many

    def test_source_names_property(self, graph):
        """source_names lists the visited sources in travel order."""
        route = graph.find_path("orders", "regions")
        assert route.source_names == ["orders", "customers", "regions"]

    def test_empty_path_source_names(self, graph):
        """The zero-edge self path reports an empty source_names list."""
        route = graph.find_path("orders", "orders")
        assert route.source_names == []
class TestResolveJoinTree:
    """``JoinGraph.resolve_join_tree`` on connected and disconnected graphs."""

    def test_single_source(self, graph):
        """One requested source gives a tree with that source and no edges."""
        tree = graph.resolve_join_tree({"orders"})
        assert tree.sources == {"orders"}
        assert tree.edges == []

    def test_two_sources(self, graph):
        """Two joined sources both appear, connected by at least one edge."""
        tree = graph.resolve_join_tree({"orders", "customers"})
        assert "orders" in tree.sources
        assert "customers" in tree.sources
        assert len(tree.edges) >= 1

    def test_three_sources_via_customers(self, graph):
        """Intermediate sources needed to connect the request are added to the tree."""
        tree = graph.resolve_join_tree({"churn_risk", "regions", "orders"})
        assert "customers" in tree.sources  # intermediate node added
        assert len(tree.sources) >= 4

    def test_disconnected_raises(self):
        """Requesting two sources with no join between them raises ValueError."""
        # SourceDefinition and SourceColumn come from the module-level import;
        # the former function-local re-import was redundant.
        src_a = SourceDefinition(
            name="a",
            table="t",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        src_b = SourceDefinition(
            name="b",
            table="t2",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        g = JoinGraph({"a": src_a, "b": src_b})
        g.build()
        with pytest.raises(ValueError, match="No join path"):
            g.resolve_join_tree({"a", "b"})
class TestOneToOneRelationship:
    """one_to_one joins are safe in both directions.

    The model classes come from the module-level import; the former
    per-test re-imports were redundant and have been removed.
    """

    def test_one_to_one_no_fan_out(self):
        """Paths over a one_to_one join never set has_one_to_many, either way."""
        users = SourceDefinition(
            name="users",
            table="t",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        profiles = SourceDefinition(
            name="profiles",
            table="t2",
            grain=["user_id"],
            columns=[SourceColumn(name="user_id", type="number")],
            joins=[
                JoinDeclaration(
                    to="users", on="user_id = users.id", relationship="one_to_one"
                )
            ],
        )
        g = JoinGraph({"users": users, "profiles": profiles})
        g.build()

        path = g.find_path("profiles", "users")
        assert path is not None
        assert not path.has_one_to_many

        # The reverse direction must be just as safe.
        rev_path = g.find_path("users", "profiles")
        assert rev_path is not None
        assert not rev_path.has_one_to_many

    def test_one_to_one_inverse(self):
        """The auto-generated reverse edge of a one_to_one join stays one_to_one."""
        a = SourceDefinition(
            name="a",
            table="t",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        b = SourceDefinition(
            name="b",
            table="t2",
            grain=["a_id"],
            columns=[SourceColumn(name="a_id", type="number")],
            joins=[
                JoinDeclaration(to="a", on="a_id = a.id", relationship="one_to_one")
            ],
        )
        g = JoinGraph({"a": a, "b": b})
        g.build()

        fwd = next(e for e in g.adjacency["b"] if e.to_source == "a")
        assert fwd.relationship == "one_to_one"
        rev = next(e for e in g.adjacency["a"] if e.to_source == "b")
        assert rev.relationship == "one_to_one"
class TestMultipleJoinsFromSource:
    """A source that declares more than one join (order_items)."""

    def test_order_items_two_joins(self, graph):
        """order_items has edges to both orders and products."""
        reachable = {edge.to_source for edge in graph.adjacency["order_items"]}
        for neighbour in ("orders", "products"):
            assert neighbour in reachable

    def test_path_through_bridge(self, graph):
        """A path from orders to products passes through order_items."""
        route = graph.find_path("orders", "products")
        assert route is not None
        assert "order_items" in route.source_names
class TestResolveJoinTreeRoot:
    """The ``root`` argument of ``resolve_join_tree``."""

    def test_root_is_respected(self, graph):
        """With root=orders, the tree spans orders, regions and the customers link."""
        tree = graph.resolve_join_tree({"orders", "regions"}, root="orders")
        # customers is not requested; it is the intermediate hop.
        for member in ("orders", "regions", "customers"):
            assert member in tree.sources

    def test_root_not_in_sources_uses_default(self, graph):
        """An unknown root does not stop the requested sources from resolving."""
        tree = graph.resolve_join_tree({"orders", "customers"}, root="nonexistent")
        for member in ("orders", "customers"):
            assert member in tree.sources
class TestFindComponents:
    """``JoinGraph.find_components`` on connected and disconnected graphs."""

    def test_connected_graph(self, graph):
        """The ecommerce graph is one component containing every source."""
        components = graph.find_components()
        assert len(components) == 1
        assert components[0] == set(graph.adjacency.keys())

    def test_disconnected_graph(self):
        """Two sources without joins form two singleton components."""
        # SourceDefinition and SourceColumn come from the module-level import;
        # the former function-local re-import was redundant.
        src_a = SourceDefinition(
            name="a",
            table="t",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        src_b = SourceDefinition(
            name="b",
            table="t2",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        g = JoinGraph({"a": src_a, "b": src_b})
        g.build()
        components = g.find_components()
        assert len(components) == 2
        # Compare as a set of frozensets so component order does not matter.
        assert {frozenset(c) for c in components} == {
            frozenset({"a"}),
            frozenset({"b"}),
        }
# ── From test_edge_cases.py ──────────────────────────────────────────
class TestGraphEdgeCases:
    """Boundary cases for JoinGraph construction, path search and ON-clause parsing."""

    @staticmethod
    def _id_only(name, table):
        """Return a source with a single numeric ``id`` column and no joins."""
        return SourceDefinition(
            name=name,
            table=table,
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )

    def test_self_referencing_join(self):
        """A source that joins to itself still has a zero-edge path to itself."""
        employees = SourceDefinition(
            name="employees",
            table="t",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="manager_id", type="number"),
                SourceColumn(name="salary", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="employees",
                    on="manager_id = employees.id",
                    relationship="many_to_one",
                )
            ],
        )
        join_graph = JoinGraph({"employees": employees})
        join_graph.build()
        route = join_graph.find_path("employees", "employees")
        assert route is not None
        assert len(route.edges) == 0

    def test_no_sources(self):
        """An empty graph builds and has no components."""
        join_graph = JoinGraph({})
        join_graph.build()
        assert join_graph.find_components() == []

    def test_single_source_no_joins(self):
        """A lone source reaches itself but not an unknown source."""
        join_graph = JoinGraph({"a": self._id_only("a", "t")})
        join_graph.build()
        assert join_graph.find_path("a", "a") is not None
        assert join_graph.find_path("a", "nonexistent") is None

    def test_two_disconnected_sources(self):
        """No path exists between two sources without joins."""
        first = self._id_only("a", "t")
        second = self._id_only("b", "t2")
        join_graph = JoinGraph({"a": first, "b": second})
        join_graph.build()
        assert join_graph.find_path("a", "b") is None

    def test_on_clause_with_spaces(self):
        """Surrounding whitespace in the ON clause is ignored."""
        parsed = JoinGraph({})._parse_on(" customer_id = customers.id ", "customers")
        assert parsed == ("customer_id", "id")

    def test_on_clause_without_prefix(self):
        """The right-hand column may omit the target source prefix."""
        parsed = JoinGraph({})._parse_on("customer_id = id", "customers")
        assert parsed == ("customer_id", "id")

    def test_on_clause_invalid(self):
        """An ON clause without an equality is rejected."""
        join_graph = JoinGraph({})
        with pytest.raises(ValueError, match="Invalid join condition"):
            join_graph._parse_on("customer_id", "customers")

    def test_on_clause_three_parts(self):
        """A chained equality (a = b = c) is rejected."""
        join_graph = JoinGraph({})
        with pytest.raises(ValueError, match="Invalid join condition"):
            join_graph._parse_on("a = b = c", "target")

    def test_composite_join_key(self):
        """AND-combined equalities parse into comma-separated column lists."""
        on_clause = (
            "product_id = inventory.product_id AND warehouse_id = inventory.warehouse_id"
        )
        left, right = JoinGraph({})._parse_on(on_clause, "inventory")
        assert left == "product_id,warehouse_id"
        assert right == "product_id,warehouse_id"

    def test_composite_join_key_with_source_prefix(self):
        """A source prefix on the left-hand columns is stripped from the result."""
        on_clause = (
            "items.product_id = inventory.product_id AND items.warehouse_id = inventory.warehouse_id"
        )
        left, right = JoinGraph({})._parse_on(on_clause, "inventory")
        assert left == "product_id,warehouse_id"
        assert right == "product_id,warehouse_id"

    def test_composite_join_generates_correct_sql(self):
        """A composite-key join declared on a source yields one edge carrying both columns.

        Only the graph edge is inspected here; no SQL is generated.
        """

        def number(column_name):
            return SourceColumn(name=column_name, type="number")

        items = SourceDefinition(
            name="items",
            table="public.items",
            grain=["order_id", "product_id"],
            columns=[
                number("order_id"),
                number("product_id"),
                number("warehouse_id"),
                number("qty"),
            ],
            joins=[
                JoinDeclaration(
                    to="inventory",
                    on="product_id = inventory.product_id AND warehouse_id = inventory.warehouse_id",
                    relationship="many_to_one",
                )
            ],
        )
        inventory = SourceDefinition(
            name="inventory",
            table="public.inventory",
            grain=["product_id", "warehouse_id"],
            columns=[number("product_id"), number("warehouse_id"), number("stock")],
        )
        join_graph = JoinGraph({"items": items, "inventory": inventory})
        join_graph.build()
        route = join_graph.find_path("items", "inventory")
        assert route is not None
        assert len(route.edges) == 1
        edge = route.edges[0]
        assert edge.from_column == "product_id,warehouse_id"
        assert edge.to_column == "product_id,warehouse_id"

    def test_resolve_join_tree_empty_set(self):
        """Resolving an empty request on an empty graph yields an empty tree."""
        join_graph = JoinGraph({})
        join_graph.build()
        tree = join_graph.resolve_join_tree(set())
        assert tree.sources == set()
        assert tree.edges == []
# ── From test_brainstorm_cases.py ────────────────────────────────────
class TestJoinTreeReusesIntermediates:
    """A requested source that is also an intermediate hop must not be joined twice."""

    def test_resolve_join_tree_reuses_intermediate_sources(self):
        """For the chain a -> z -> m with all three requested, the tree has exactly two edges."""
        id_column = lambda: SourceColumn(name="id", type="number")  # noqa: E731
        catalog = {
            "a": SourceDefinition(
                name="a",
                table="public.a",
                grain=["id"],
                columns=[id_column()],
                joins=[
                    JoinDeclaration(
                        to="z", on="z_id = z.id", relationship="many_to_one"
                    )
                ],
            ),
            "z": SourceDefinition(
                name="z",
                table="public.z",
                grain=["id"],
                columns=[id_column(), SourceColumn(name="m_id", type="number")],
                joins=[
                    JoinDeclaration(
                        to="m", on="m_id = m.id", relationship="many_to_one"
                    )
                ],
            ),
            "m": SourceDefinition(
                name="m",
                table="public.m",
                grain=["id"],
                columns=[id_column()],
            ),
        }
        join_graph = JoinGraph(catalog)
        join_graph.build()

        tree = join_graph.resolve_join_tree({"a", "m", "z"}, root="a")

        assert tree.sources == {"a", "z", "m"}
        assert len(tree.edges) == 2
        hops = {(edge.from_source, edge.to_source) for edge in tree.edges}
        assert hops == {("a", "z"), ("z", "m")}
class TestDijkstraEdgeWeightPreference:
    """Path search prefers all-many_to_one routes over routes with a one_to_many hop."""

    def test_dijkstra_prefers_safe_path(self):
        """A two-hop safe route (a -> b -> c) wins over a direct one_to_many edge (a -> c)."""
        id_col = SourceColumn(name="id", type="number")
        c_id_col = SourceColumn(name="c_id", type="number")
        catalog = {
            "a": SourceDefinition(
                name="a",
                table="t",
                grain=["id"],
                columns=[SourceColumn(name="id", type="number")],
                joins=[
                    # Direct but unsafe edge is declared first, then the safe detour.
                    JoinDeclaration(
                        to="c", on="c_id = c.id", relationship="one_to_many"
                    ),
                    JoinDeclaration(
                        to="b", on="b_id = b.id", relationship="many_to_one"
                    ),
                ],
            ),
            "b": SourceDefinition(
                name="b",
                table="t2",
                grain=["id"],
                columns=[id_col, c_id_col],
                joins=[
                    JoinDeclaration(
                        to="c", on="c_id = c.id", relationship="many_to_one"
                    ),
                ],
            ),
            "c": SourceDefinition(
                name="c",
                table="t3",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="c_id", type="number"),
                ],
            ),
        }
        join_graph = JoinGraph(catalog)
        join_graph.build()
        route = join_graph.find_path("a", "c")
        assert route is not None
        # The longer all-safe route is expected, not the one-hop one_to_many edge.
        assert len(route.edges) == 2
        assert route.source_names == ["a", "b", "c"]
        assert not route.has_one_to_many

    def test_dijkstra_uses_unsafe_when_only_option(self):
        """When the only route is one_to_many, it is returned and flagged."""
        origin = SourceDefinition(
            name="a",
            table="t",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
            joins=[
                JoinDeclaration(to="b", on="b_id = b.id", relationship="one_to_many"),
            ],
        )
        destination = SourceDefinition(
            name="b",
            table="t2",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        join_graph = JoinGraph({"a": origin, "b": destination})
        join_graph.build()
        route = join_graph.find_path("a", "b")
        assert route is not None
        assert len(route.edges) == 1
        assert route.has_one_to_many
class TestAmbiguousPathDetection:
    """Detection of equal-cost alternative join paths (diamond-shaped graphs)."""

    @staticmethod
    def _source(name, table, *joins, extra_columns=()):
        """Build a source with an ``id`` column, optional extra columns and optional joins."""
        spec = {
            "name": name,
            "table": table,
            "grain": ["id"],
            "columns": [SourceColumn(name="id", type="number"), *extra_columns],
        }
        if joins:
            # Omit ``joins`` entirely when none are declared, as the original does.
            spec["joins"] = list(joins)
        return SourceDefinition(**spec)

    @staticmethod
    def _m2o(target, on):
        """Shorthand for a many_to_one JoinDeclaration."""
        return JoinDeclaration(to=target, on=on, relationship="many_to_one")

    @staticmethod
    def _diamond_sources():
        """Diamond: a -> b -> d and a -> c -> d, every edge many_to_one."""
        make = TestAmbiguousPathDetection._source
        m2o = TestAmbiguousPathDetection._m2o
        return {
            "a": make(
                "a",
                "t_a",
                m2o("b", "b_id = b.id"),
                m2o("c", "c_id = c.id"),
            ),
            "b": make("b", "t_b", m2o("d", "d_id = d.id")),
            "c": make("c", "t_c", m2o("d", "d_id = d.id")),
            "d": make("d", "t_d"),
        }

    def test_diamond_graph_is_ambiguous(self):
        """Two equally good routes from a to d mark the path as ambiguous."""
        join_graph = JoinGraph(self._diamond_sources())
        join_graph.build()
        route = join_graph.find_path("a", "d")
        assert route is not None
        assert route.is_ambiguous is True

    def test_linear_graph_not_ambiguous(self):
        """A plain chain a -> b -> c has one route and is not ambiguous."""
        catalog = {
            "a": self._source("a", "t_a", self._m2o("b", "b_id = b.id")),
            "b": self._source("b", "t_b", self._m2o("c", "c_id = c.id")),
            "c": self._source("c", "t_c"),
        }
        join_graph = JoinGraph(catalog)
        join_graph.build()
        route = join_graph.find_path("a", "c")
        assert route is not None
        assert route.is_ambiguous is False

    def test_different_cost_paths_not_ambiguous(self):
        """A diamond whose second branch starts with a one_to_many edge is not ambiguous."""
        catalog = {
            "a": self._source(
                "a",
                "t_a",
                self._m2o("b", "b_id = b.id"),
                JoinDeclaration(to="c", on="id = c.a_id", relationship="one_to_many"),
            ),
            "b": self._source("b", "t_b", self._m2o("d", "d_id = d.id")),
            "c": self._source(
                "c",
                "t_c",
                self._m2o("d", "d_id = d.id"),
                extra_columns=[SourceColumn(name="a_id", type="number")],
            ),
            "d": self._source("d", "t_d"),
        }
        join_graph = JoinGraph(catalog)
        join_graph.build()
        route = join_graph.find_path("a", "d")
        assert route is not None
        # Per the original note the safe route costs 2 and the unsafe one 11,
        # so there is a unique cheapest route and it avoids the one_to_many edge.
        assert route.is_ambiguous is False
        assert route.has_one_to_many is False

    def test_ambiguous_path_warning_in_resolve_join_tree(self, caplog):
        """Resolving a tree across the diamond logs an 'Ambiguous join path' warning."""
        import logging

        join_graph = JoinGraph(self._diamond_sources())
        join_graph.build()
        with caplog.at_level(logging.WARNING, logger="semantic_layer.graph"):
            join_graph.resolve_join_tree({"a", "d"}, root="a")
        messages = [record.message for record in caplog.records]
        assert any("Ambiguous join path" in text for text in messages)
def test_bigquery_native_on_clause_extracts_column_pair():
    """Join on: with BigQuery-specific casts must parse and yield column pairs."""
    users = SourceDefinition(
        name="users",
        table="users",
        grain=["id"],
        columns=[SourceColumn(name="id", type="number")],
    )
    # The join condition wraps the target column in a BigQuery-only SAFE_CAST.
    cast_join = JoinDeclaration(
        to="users",
        on="user_id = SAFE_CAST(users.id AS INT64)",
        relationship="many_to_one",
    )
    orders = SourceDefinition(
        name="orders",
        table="orders",
        grain=["id"],
        columns=[
            SourceColumn(name="id", type="number"),
            SourceColumn(name="user_id", type="number"),
        ],
        joins=[cast_join],
    )

    graph = JoinGraph({"orders": orders, "users": users}, dialect="bigquery")
    graph.build()

    # The graph must have recorded the compatibility edge
    orders_edges = graph.adjacency.get("orders", [])
    reaches_users = any(edge.to_source == "users" for edge in orders_edges)
    assert reaches_users, (
        f"orders → users edge missing after BigQuery-native on: parse:\n{orders_edges}"
    )
def test_joingraph_dialect_defaults_to_postgres():
    """Omitting ``dialect`` yields "postgres", keeping existing test ergonomics unchanged."""
    default_graph = JoinGraph({})
    assert default_graph.dialect == "postgres"

View file

@ -0,0 +1,171 @@
import pytest
from pathlib import Path
import tempfile
import yaml
from semantic_layer.loader import SourceLoader
from semantic_layer.models import SourceDefinition
# Ecommerce example sources (e.g. regions.yaml), located under the parent of
# this test module's directory.
SOURCES_DIR = Path(__file__).parent.parent / "sources" / "ecommerce"
class TestSourceLoader:
    """Loader behaviour against the ecommerce example sources and small ad-hoc YAML files."""

    def test_load_all_ecommerce(self, ecommerce_sources):
        """All six ecommerce sources are discovered under their declared names."""
        expected_names = {
            "customers",
            "orders",
            "regions",
            "products",
            "order_items",
            "churn_risk",
        }
        assert len(ecommerce_sources) == 6
        assert set(ecommerce_sources.keys()) == expected_names

    def test_orders_source(self, ecommerce_sources):
        """The orders source is table-backed with the expected shape and a single join."""
        orders = ecommerce_sources["orders"]
        assert orders.is_table_source
        assert orders.table == "public.orders"
        assert orders.grain == ["id"]
        assert len(orders.columns) == 6
        assert len(orders.measures) == 5
        assert len(orders.joins) == 1
        only_join = orders.joins[0]
        assert only_join.to == "customers"
        assert only_join.relationship == "many_to_one"

    def test_churn_risk_sql_source(self, ecommerce_sources):
        """churn_risk is SQL-backed and exposes one measure."""
        churn = ecommerce_sources["churn_risk"]
        assert churn.is_sql_source
        assert churn.sql is not None
        assert "calculate_churn_score" in churn.sql
        assert churn.grain == ["customer_id"]
        assert len(churn.measures) == 1
        assert churn.measures[0].name == "avg_risk"

    def test_regions_no_joins(self, ecommerce_sources):
        """regions declares neither joins nor measures."""
        regions = ecommerce_sources["regions"]
        assert regions.joins == []
        assert regions.measures == []

    def test_order_items_bridge(self, ecommerce_sources):
        """order_items joins to both orders and products."""
        bridge = ecommerce_sources["order_items"]
        assert len(bridge.joins) == 2
        assert {join.to for join in bridge.joins} == {"orders", "products"}

    def test_revenue_measure_has_filter(self, ecommerce_sources):
        """The revenue measure carries its filter alongside the aggregate expression."""
        measures = ecommerce_sources["orders"].measures
        revenue = next(measure for measure in measures if measure.name == "revenue")
        assert revenue.filter == "status != 'refunded'"
        assert revenue.expr == "sum(amount)"

    def test_load_single_file(self):
        """load_file parses one YAML file into a SourceDefinition."""
        loader = SourceLoader(SOURCES_DIR)
        loaded = loader.load_file(SOURCES_DIR / "regions.yaml")
        assert loaded.name == "regions"
        assert isinstance(loaded, SourceDefinition)

    def test_invalid_join_target(self):
        """A join pointing at an unknown source makes load_all raise ValueError."""
        dangling = {
            "name": "bad_source",
            "table": "t",
            "grain": ["id"],
            "columns": [{"name": "id", "type": "number"}],
            "joins": [
                {
                    "to": "nonexistent",
                    "on": "id = nonexistent.id",
                    "relationship": "many_to_one",
                }
            ],
        }
        with tempfile.TemporaryDirectory() as tmpdir:
            target = Path(tmpdir) / "bad.yaml"
            with target.open("w") as handle:
                yaml.dump(dangling, handle)
            loader = SourceLoader(tmpdir)
            with pytest.raises(ValueError, match="nonexistent"):
                loader.load_all()

    def test_duplicate_source_name(self):
        """Two files declaring the same source name are rejected."""
        definition = {
            "name": "dupe",
            "table": "t",
            "grain": ["id"],
            "columns": [{"name": "id", "type": "number"}],
        }
        with tempfile.TemporaryDirectory() as tmpdir:
            for filename in ("a.yaml", "b.yaml"):
                with (Path(tmpdir) / filename).open("w") as handle:
                    yaml.dump(definition, handle)
            loader = SourceLoader(tmpdir)
            with pytest.raises(ValueError, match="Duplicate source name"):
                loader.load_all()

    def test_source_description_loads(self, ecommerce_sources):
        """churn_risk has a description that mentions churn (case-insensitive)."""
        description = ecommerce_sources["churn_risk"].description
        assert description is not None
        assert "churn" in description.lower()

    def test_column_role_loads(self, ecommerce_sources):
        """orders.created_at is loaded with the "time" role."""
        columns = ecommerce_sources["orders"].columns
        created_at = next(column for column in columns if column.name == "created_at")
        assert created_at.role == "time"

    def test_source_without_description(self, ecommerce_sources):
        """A source with no description loads with ``description`` set to None."""
        assert ecommerce_sources["regions"].description is None
# ── From test_edge_cases.py ──────────────────────────────────────────
class TestLoaderEdgeCases:
    """Edge cases for SourceLoader: empty dirs, stray files, unknown keys, nested dirs."""

    def test_empty_directory(self):
        """Loading a directory with no files yields an empty mapping."""
        with tempfile.TemporaryDirectory() as tmpdir:
            loader = SourceLoader(tmpdir)
            sources = loader.load_all()
            assert sources == {}

    def test_non_yaml_files_ignored(self):
        """Files that are not YAML sources do not produce any source."""
        with tempfile.TemporaryDirectory() as tmpdir:
            (Path(tmpdir) / "readme.txt").write_text("not a yaml file")
            loader = SourceLoader(tmpdir)
            sources = loader.load_all()
            assert sources == {}

    def test_yaml_with_extra_fields(self):
        """An unknown top-level key is either rejected outright or the source still loads.

        Both outcomes are accepted, but when loading succeeds the source must
        be present in the result.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            data = {
                "name": "test",
                "table": "t",
                "grain": ["id"],
                "columns": [{"name": "id", "type": "number"}],
                "unknown_field": "should be rejected",
            }
            with open(Path(tmpdir) / "test.yaml", "w") as f:
                yaml.dump(data, f)
            loader = SourceLoader(tmpdir)
            # Only the load itself is guarded. Previously the assertion lived
            # inside the try block, so its AssertionError was swallowed by
            # `except Exception` and the test could never fail.
            try:
                sources = loader.load_all()
            except Exception:
                # Rejecting the unknown field is an accepted outcome.
                # NOTE(review): narrow this to the loader's validation error
                # type once the intended behaviour is pinned down.
                return
            assert "test" in sources

    def test_subdirectory_sources(self):
        """Sources placed in a nested directory are discovered by load_all."""
        with tempfile.TemporaryDirectory() as tmpdir:
            subdir = Path(tmpdir) / "sub"
            subdir.mkdir()
            data = {
                "name": "nested",
                "table": "t",
                "grain": ["id"],
                "columns": [{"name": "id", "type": "number"}],
            }
            with open(subdir / "nested.yaml", "w") as f:
                yaml.dump(data, f)
            loader = SourceLoader(tmpdir)
            sources = loader.load_all()
            assert "nested" in sources

View file

@ -0,0 +1,619 @@
"""Tests for manifest models, projection, overlay validation, and two-tier loading."""
from __future__ import annotations
from pathlib import Path
import pytest
import yaml
from semantic_layer.loader import SourceLoader
from semantic_layer.manifest import (
ManifestColumn,
ManifestEntry,
ManifestJoin,
map_column_type,
project_manifest_entry,
validate_overlay,
)
from semantic_layer.models import ColumnRole
# ── Type Mapping Tests ──────────────────────────────────────────────
class TestMapColumnType:
    """map_column_type collapses database type names into number/time/boolean/string."""

    def test_map_column_type_numbers(self):
        """Integer, decimal and floating-point spellings all map to 'number'."""
        numeric_names = (
            "integer",
            "bigint",
            "smallint",
            "numeric",
            "decimal",
            "float",
            "double",
            "real",
            "int",
            "int2",
            "int4",
            "int8",
            "float4",
            "float8",
            "double precision",
            "number",
            "tinyint",
            "mediumint",
        )
        for db_type in numeric_names:
            mapped = map_column_type(db_type)
            assert mapped == "number", f"{db_type} should map to 'number'"

    def test_map_column_type_time(self):
        """Timestamp, date and time spellings (mixed case included) map to 'time'."""
        temporal_names = (
            "timestamp",
            "timestamptz",
            "timestamp with time zone",
            "timestamp without time zone",
            "TIMESTAMP_NTZ",
            "TIMESTAMP_LTZ",
            "TIMESTAMP_TZ",
            "datetime",
            "date",
            "time",
            "timetz",
        )
        for db_type in temporal_names:
            mapped = map_column_type(db_type)
            assert mapped == "time", f"{db_type} should map to 'time'"

    def test_map_column_type_boolean(self):
        """Both boolean spellings map to 'boolean'."""
        for db_type in ("boolean", "bool"):
            mapped = map_column_type(db_type)
            assert mapped == "boolean", f"{db_type} should map to 'boolean'"

    def test_map_column_type_string_fallback(self):
        """Anything not otherwise recognised falls back to 'string'."""
        fallback_names = ("varchar", "text", "char", "unknown", "jsonb", "xml")
        for db_type in fallback_names:
            mapped = map_column_type(db_type)
            assert mapped == "string", f"{db_type} should map to 'string'"

    def test_map_column_type_strips_precision(self):
        """A parenthesised precision/length suffix does not affect the mapping."""
        expectations = {
            "numeric(10,2)": "number",
            "varchar(255)": "string",
            "decimal(18,4)": "number",
            "timestamp(6)": "time",
            "char(1)": "string",
        }
        for declared, semantic in expectations.items():
            assert map_column_type(declared) == semantic
# ── Manifest Projection Tests ──────────────────────────────────────
class TestProjectManifestEntry:
    """project_manifest_entry turns a ManifestEntry into a source definition."""

    @pytest.fixture()
    def orders_entry(self) -> ManifestEntry:
        """Manifest entry for public.orders: five columns, ``id`` as pk, one formal join."""
        columns = [
            ManifestColumn(name="id", type="integer", pk=True),
            ManifestColumn(name="customer_id", type="integer"),
            ManifestColumn(name="total", type="numeric"),
            ManifestColumn(name="status", type="varchar"),
            ManifestColumn(name="created_at", type="timestamp"),
        ]
        to_customers = ManifestJoin(
            to="customers",
            on="orders.customer_id = customers.id",
            relationship="many_to_one",
            source="formal",
        )
        return ManifestEntry(
            table="public.orders",
            description="Customer orders",
            columns=columns,
            joins=[to_customers],
        )

    def test_project_manifest_entry_basic(self, orders_entry: ManifestEntry):
        """Name, table, description and column order carry over; no measures are invented."""
        projected = project_manifest_entry("orders", orders_entry)
        assert projected.name == "orders"
        assert projected.table == "public.orders"
        assert projected.description == "Customer orders"
        assert len(projected.columns) == 5
        assert projected.measures == []
        names_in_order = [column.name for column in projected.columns]
        assert names_in_order == ["id", "customer_id", "total", "status", "created_at"]

    def test_project_manifest_entry_type_mapping(self, orders_entry: ManifestEntry):
        """Database types are replaced by their semantic type."""
        projected = project_manifest_entry("orders", orders_entry)
        type_by_name = {column.name: column.type for column in projected.columns}
        assert type_by_name["id"] == "number"
        assert type_by_name["customer_id"] == "number"
        assert type_by_name["total"] == "number"
        assert type_by_name["status"] == "string"
        assert type_by_name["created_at"] == "time"

    def test_project_manifest_entry_grain_from_pk(self, orders_entry: ManifestEntry):
        """The pk column becomes the grain."""
        projected = project_manifest_entry("orders", orders_entry)
        assert projected.grain == ["id"]

    def test_project_manifest_entry_grain_all_columns_no_pk(self):
        """Without any pk column the grain is every column, in declaration order."""
        keyless = ManifestEntry(
            table="public.events",
            columns=[
                ManifestColumn(name="user_id", type="integer"),
                ManifestColumn(name="event_type", type="varchar"),
                ManifestColumn(name="ts", type="timestamp"),
            ],
        )
        projected = project_manifest_entry("events", keyless)
        assert projected.grain == ["user_id", "event_type", "ts"]

    def test_project_manifest_entry_joins_stripped(self, orders_entry: ManifestEntry):
        """Joins keep to/on/relationship; the manifest-only ``source`` tag is absent or None."""
        projected = project_manifest_entry("orders", orders_entry)
        assert len(projected.joins) == 1
        join = projected.joins[0]
        assert join.to == "customers"
        assert join.on == "orders.customer_id = customers.id"
        assert join.relationship == "many_to_one"
        assert not hasattr(join, "source") or getattr(join, "source", None) is None

    def test_project_manifest_entry_time_role(self, orders_entry: ManifestEntry):
        """Only the timestamp column receives the TIME role; the rest stay DEFAULT."""
        projected = project_manifest_entry("orders", orders_entry)
        timed = [c for c in projected.columns if c.role == ColumnRole.TIME]
        assert len(timed) == 1
        assert timed[0].name == "created_at"
        untimed = [c for c in projected.columns if c.role == ColumnRole.DEFAULT]
        assert len(untimed) == 4

    def test_project_manifest_entry_preserves_dbt_metadata(self):
        """dbt constraints, enum values, tests, tags and freshness survive projection."""
        status_column = ManifestColumn(
            name="status",
            type="varchar",
            constraints={"dbt": {"not_null": True}},
            enum_values={"dbt": ["placed", "shipped"]},
            tests={"dbt": [{"name": "accepted_values", "package": "dbt"}]},
        )
        entry = ManifestEntry(
            table="public.orders",
            columns=[status_column],
            tags={"dbt": ["mart"]},
            freshness={"dbt": {"loaded_at_field": "updated_at"}},
        )

        projected = project_manifest_entry("orders", entry)

        assert projected.columns[0].constraints is not None
        assert projected.columns[0].constraints["dbt"].not_null is True
        assert projected.columns[0].enum_values == {"dbt": ["placed", "shipped"]}
        assert projected.columns[0].tests is not None
        dumped_tests = projected.columns[0].tests.model_dump(
            mode="python", exclude_none=True
        )
        assert dumped_tests == {
            "dbt": [{"name": "accepted_values", "package": "dbt"}]
        }
        assert projected.tags == {"dbt": ["mart"]}
        assert projected.freshness is not None
        assert projected.freshness["dbt"].loaded_at_field == "updated_at"
# ── Overlay Validation Tests ───────────────────────────────────────
class TestValidateOverlay:
    """validate_overlay returns a list of error strings; an empty list means valid."""

    def test_validate_overlay_valid(self):
        """A full overlay with measures, a computed column and exclusions has no errors."""
        overlay = {
            "name": "orders",
            "description": "Revenue-bearing orders",
            "grain": ["id"],
            "measures": [{"name": "revenue", "expr": "sum(total)"}],
            "columns": [
                {"name": "is_high_value", "expr": "total > 1000", "type": "boolean"}
            ],
            "exclude_columns": ["status"],
        }
        assert validate_overlay(overlay) == []

    def test_validate_overlay_rejects_table(self):
        """Declaring ``table`` in an overlay yields exactly one error mentioning it."""
        problems = validate_overlay({"name": "orders", "table": "public.orders"})
        assert len(problems) == 1
        assert "table" in problems[0].lower()

    def test_validate_overlay_rejects_sql(self):
        """Declaring ``sql`` in an overlay yields exactly one error mentioning it."""
        problems = validate_overlay({"name": "orders", "sql": "SELECT * FROM orders"})
        assert len(problems) == 1
        assert "sql" in problems[0].lower()

    def test_validate_overlay_rejects_type_without_expr(self):
        """A column with ``type`` but no ``expr`` is reported, naming both keys."""
        overlay = {
            "name": "orders",
            "columns": [{"name": "status", "type": "string"}],
        }
        problems = validate_overlay(overlay)
        assert len(problems) == 1
        message = problems[0].lower()
        assert "type" in message
        assert "expr" in message

    def test_validate_overlay_allows_type_with_expr(self):
        """A column that supplies both ``type`` and ``expr`` is accepted."""
        overlay = {
            "name": "orders",
            "columns": [{"name": "is_big", "type": "boolean", "expr": "total > 1000"}],
        }
        assert validate_overlay(overlay) == []
# ── Two-Tier Loading Tests ─────────────────────────────────────────
def _write_yaml(path: Path, data: dict | list) -> None:
    """Serialise ``data`` as block-style YAML at ``path``, creating parent directories."""
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    with path.open("w") as stream:
        yaml.dump(data, stream, default_flow_style=False)
def _manifest_tables() -> dict:
"""Manifest shard with orders + customers tables."""
return {
"tables": {
"orders": {
"table": "public.orders",
"description": "Customer orders",
"columns": [
{"name": "id", "type": "integer", "pk": True},
{"name": "customer_id", "type": "integer"},
{"name": "total", "type": "numeric"},
{"name": "status", "type": "varchar"},
{"name": "created_at", "type": "timestamp"},
],
"joins": [
{
"to": "customers",
"on": "orders.customer_id = customers.id",
"relationship": "many_to_one",
"source": "formal",
},
],
},
"customers": {
"table": "public.customers",
"description": "Customer accounts",
"columns": [
{"name": "id", "type": "integer", "pk": True},
{"name": "name", "type": "varchar"},
],
"joins": [
{
"to": "orders",
"on": "customers.id = orders.customer_id",
"relationship": "one_to_many",
"source": "formal",
},
],
},
},
}
class TestTwoTierLoading:
    """Loading from a ``_schema`` manifest shard combined with standalone and overlay YAML."""

    @staticmethod
    def _seed_manifest(root: Path) -> None:
        """Write the shared orders/customers manifest shard to ``_schema/public.yaml``."""
        _write_yaml(root / "_schema" / "public.yaml", _manifest_tables())

    @staticmethod
    def _write_overlays(root: Path, orders_overlay: dict) -> None:
        """Write the orders overlay plus a name-only customers overlay.

        The customers overlay is empty apart from its name; per the original
        test notes it exists to avoid a cross-reference error, because the
        customers manifest entry joins back to orders.
        """
        _write_yaml(root / "orders.yaml", orders_overlay)
        _write_yaml(root / "customers.yaml", {"name": "customers"})

    @staticmethod
    def _load(root: Path):
        """Run a fresh SourceLoader over ``root`` and return its sources."""
        loader = SourceLoader(root)
        return loader.load_all()

    def _load_with_orders_overlay(self, root: Path, orders_overlay: dict):
        """Seed the manifest, write the overlays, then load everything."""
        self._seed_manifest(root)
        self._write_overlays(root, orders_overlay)
        return self._load(root)

    def test_load_manifest_shard(self, tmp_path: Path):
        """Tables from the manifest shard load with their table name and pk grain."""
        self._seed_manifest(tmp_path)
        sources = self._load(tmp_path)
        assert "orders" in sources
        assert "customers" in sources
        assert sources["orders"].table == "public.orders"
        assert sources["orders"].grain == ["id"]
        assert sources["customers"].table == "public.customers"

    def test_load_standalone_source(self, tmp_path: Path):
        """A self-contained table source loads without any manifest."""
        _write_yaml(
            tmp_path / "regions.yaml",
            {
                "name": "regions",
                "table": "public.regions",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "name", "type": "string"},
                ],
            },
        )
        sources = self._load(tmp_path)
        assert "regions" in sources
        assert sources["regions"].table == "public.regions"
        assert sources["regions"].is_table_source

    def test_overlay_descriptions_do_not_promote_base_description_to_user_source(
        self, tmp_path: Path
    ):
        """An overlay ``descriptions`` map applied to a standalone source wins over its description."""
        _write_yaml(
            tmp_path / "a_regions.yaml",
            {
                "name": "regions",
                "description": "Standalone description",
                "table": "public.regions",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                ],
            },
        )
        _write_yaml(
            tmp_path / "z_regions_overlay.yaml",
            {"name": "regions", "descriptions": {"dbt": "dbt description"}},
        )
        sources = self._load(tmp_path)
        assert sources["regions"].description == "dbt description"

    def test_load_sql_source(self, tmp_path: Path):
        """A standalone SQL source loads as a SQL source and keeps its query text."""
        _write_yaml(
            tmp_path / "active_users.yaml",
            {
                "name": "active_users",
                "sql": "SELECT id, email FROM users WHERE active = true",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "email", "type": "string"},
                ],
            },
        )
        sources = self._load(tmp_path)
        assert "active_users" in sources
        assert sources["active_users"].is_sql_source
        assert "SELECT" in sources["active_users"].sql

    def test_load_overlay_composition(self, tmp_path: Path):
        """Overlay description and measures are layered on the manifest's table."""
        sources = self._load_with_orders_overlay(
            tmp_path,
            {
                "name": "orders",
                "description": "Revenue-bearing orders",
                "grain": ["id"],
                "measures": [{"name": "revenue", "expr": "sum(total)"}],
            },
        )
        orders = sources["orders"]
        assert orders.table == "public.orders"
        assert orders.description == "Revenue-bearing orders"
        assert len(orders.measures) == 1
        assert orders.measures[0].name == "revenue"

    def test_overlay_description_override(self, tmp_path: Path):
        """A plain overlay ``description`` replaces the manifest description."""
        sources = self._load_with_orders_overlay(
            tmp_path, {"name": "orders", "description": "Overridden description"}
        )
        assert sources["orders"].description == "Overridden description"

    def test_overlay_descriptions_map_preserves_higher_priority_manifest_description(
        self, tmp_path: Path
    ):
        """An overlay ``descriptions`` map does not displace the manifest's own description."""
        sources = self._load_with_orders_overlay(
            tmp_path,
            {
                "name": "orders",
                "descriptions": {
                    "db": "DB description",
                    "dbt": "dbt description",
                },
            },
        )
        assert sources["orders"].description == "Customer orders"

    def test_overlay_descriptions_map_overrides_lower_priority_db_description(
        self, tmp_path: Path
    ):
        """A dbt description from the overlay beats a manifest description keyed as ``db``."""
        _write_yaml(
            tmp_path / "_schema" / "public.yaml",
            {
                "tables": {
                    "orders": {
                        "table": "public.orders",
                        "descriptions": {"db": "DB description"},
                        "columns": [{"name": "id", "type": "integer", "pk": True}],
                    },
                    "customers": {
                        "table": "public.customers",
                        "columns": [{"name": "id", "type": "integer", "pk": True}],
                    },
                }
            },
        )
        self._write_overlays(
            tmp_path,
            {
                "name": "orders",
                "descriptions": {
                    "dbt": "dbt description",
                },
            },
        )
        sources = self._load(tmp_path)
        assert sources["orders"].description == "dbt description"

    def test_overlay_exclude_columns(self, tmp_path: Path):
        """Columns listed in ``exclude_columns`` are dropped; the others remain."""
        sources = self._load_with_orders_overlay(
            tmp_path, {"name": "orders", "exclude_columns": ["status"]}
        )
        remaining = [column.name for column in sources["orders"].columns]
        assert "status" not in remaining
        assert "id" in remaining
        assert "total" in remaining

    def test_overlay_computed_columns_appended(self, tmp_path: Path):
        """A computed overlay column is added alongside the manifest columns."""
        sources = self._load_with_orders_overlay(
            tmp_path,
            {
                "name": "orders",
                "columns": [
                    {"name": "is_high_value", "expr": "total > 1000", "type": "boolean"},
                ],
            },
        )
        columns = sources["orders"].columns
        names = [column.name for column in columns]
        assert "is_high_value" in names
        # Original columns still present
        assert "id" in names
        assert "total" in names
        # The computed column keeps its expression and declared type
        computed = next(column for column in columns if column.name == "is_high_value")
        assert computed.expr == "total > 1000"
        assert computed.type == "boolean"

    def test_overlay_measures_set(self, tmp_path: Path):
        """Measures declared in the overlay become the source's measures."""
        sources = self._load_with_orders_overlay(
            tmp_path,
            {
                "name": "orders",
                "measures": [
                    {"name": "revenue", "expr": "sum(total)"},
                    {"name": "order_count", "expr": "count(id)"},
                ],
            },
        )
        measures = sources["orders"].measures
        assert len(measures) == 2
        assert {measure.name for measure in measures} == {"revenue", "order_count"}

    def test_overlay_grain_override(self, tmp_path: Path):
        """An overlay ``grain`` replaces the pk-derived grain."""
        sources = self._load_with_orders_overlay(
            tmp_path, {"name": "orders", "grain": ["id", "customer_id"]}
        )
        assert sources["orders"].grain == ["id", "customer_id"]

    def test_overlay_join_union_and_dedupe(self, tmp_path: Path):
        """Overlay joins are unioned with manifest joins and exact duplicates collapse."""
        # Add a "regions" standalone so the join target exists
        _write_yaml(
            tmp_path / "regions.yaml",
            {
                "name": "regions",
                "table": "public.regions",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "name", "type": "string"},
                ],
            },
        )
        sources = self._load_with_orders_overlay(
            tmp_path,
            {
                "name": "orders",
                "joins": [
                    # Duplicate of manifest join (should be deduped)
                    {
                        "to": "customers",
                        "on": "orders.customer_id = customers.id",
                        "relationship": "many_to_one",
                    },
                    # New join
                    {
                        "to": "regions",
                        "on": "orders.region_id = regions.id",
                        "relationship": "many_to_one",
                    },
                ],
            },
        )
        joins = sources["orders"].joins
        # Manifest had 1 join to customers, overlay adds 1 new (regions), duplicate deduped
        assert len(joins) == 2
        targets = [join.to for join in joins]
        assert "customers" in targets
        assert "regions" in targets

    def test_overlay_disable_joins(self, tmp_path: Path):
        """A join listed in ``disable_joins`` is removed from the source."""
        sources = self._load_with_orders_overlay(
            tmp_path,
            {
                "name": "orders",
                "disable_joins": ["orders.customer_id = customers.id"],
            },
        )
        assert len(sources["orders"].joins) == 0

    def test_overlay_rejects_invalid(self, tmp_path: Path):
        """An invalid overlay makes load_all raise ValueError."""
        self._seed_manifest(tmp_path)
        # An overlay with a column that has type but no expr is invalid
        self._write_overlays(
            tmp_path,
            {
                "name": "orders",
                "columns": [{"name": "status", "type": "string"}],
            },
        )
        # The loader is built outside the raises block so that only load_all is expected to fail.
        loader = SourceLoader(tmp_path)
        with pytest.raises(ValueError, match="Invalid overlay"):
            loader.load_all()

View file

@ -0,0 +1,373 @@
import pytest
from pydantic import ValidationError
from semantic_layer.models import (
ColumnRole,
ColumnVisibility,
ColumnDbtConstraints,
DefaultTimeDimensionDbt,
FreshnessDbt,
MeasureGroup,
Provenance,
QueryResult,
ResolvedColumn,
ResolvedMeasure,
ResolvedPlan,
SemanticQuery,
SourceColumn,
SourceDefinition,
)
class TestSourceColumn:
    """Defaults, enum coercion and type validation of SourceColumn."""

    def test_defaults(self):
        """Visibility, role and description fall back to PUBLIC, DEFAULT and None."""
        column = SourceColumn(name="id", type="number")
        assert column.visibility == ColumnVisibility.PUBLIC
        assert column.role == ColumnRole.DEFAULT
        assert column.description is None

    def test_all_fields(self):
        """String values for visibility and role compare equal to the enum members."""
        column = SourceColumn(
            name="id",
            type="number",
            visibility="hidden",
            role="time",
            description="PK",
        )
        assert column.visibility == ColumnVisibility.HIDDEN
        assert column.role == ColumnRole.TIME

    def test_invalid_type(self):
        """A raw database type such as "integer" is rejected with a ValidationError."""
        with pytest.raises(ValidationError):
            SourceColumn(name="id", type="integer")
class TestSourceDefinition:
    """Construction, validation and round-tripping of SourceDefinition."""

    def test_table_source(self):
        """A definition with ``table`` set is a table source and not a SQL source."""
        source = SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        assert source.table == "public.orders"
        assert source.sql is None
        assert source.is_table_source
        assert not source.is_sql_source

    def test_sql_source(self):
        """A definition with ``sql`` set is a SQL source and not a table source."""
        source = SourceDefinition(
            name="churn",
            sql="SELECT * FROM x",
            grain=["customer_id"],
            columns=[SourceColumn(name="customer_id", type="number")],
        )
        assert source.sql == "SELECT * FROM x"
        assert source.table is None
        assert source.is_sql_source
        assert not source.is_table_source

    def test_table_and_sql_mutually_exclusive(self):
        """Supplying both ``table`` and ``sql`` fails validation."""
        conflicting = {
            "name": "bad",
            "table": "t",
            "sql": "SELECT 1",
            "grain": ["id"],
            "columns": [SourceColumn(name="id", type="number")],
        }
        with pytest.raises(ValidationError, match="mutually exclusive"):
            SourceDefinition(**conflicting)

    def test_empty_grain_rejected(self):
        """An empty grain list fails validation."""
        grainless = {
            "name": "bad",
            "table": "t",
            "grain": [],
            "columns": [SourceColumn(name="id", type="number")],
        }
        with pytest.raises(ValidationError, match="grain must be non-empty"):
            SourceDefinition(**grainless)

    def test_measures_and_joins(self):
        """Joins and measures given as plain dicts are accessible as attribute-bearing objects."""
        join_spec = {
            "to": "customers",
            "on": "cid = customers.id",
            "relationship": "many_to_one",
        }
        measure_spec = {"name": "revenue", "expr": "sum(amount)"}
        source = SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
            joins=[join_spec],
            measures=[measure_spec],
        )
        assert len(source.joins) == 1
        assert source.joins[0].to == "customers"
        assert len(source.measures) == 1
        assert source.measures[0].name == "revenue"

    def test_default_time_dimension_optional_and_dump(self):
        """default_time_dimension defaults to None and survives dump + validate."""
        bare = SourceDefinition(
            name="orders",
            table="t",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
        )
        assert bare.default_time_dimension is None

        with_dimension = SourceDefinition(
            name="orders",
            table="t",
            grain=["id"],
            columns=[SourceColumn(name="id", type="number")],
            default_time_dimension=DefaultTimeDimensionDbt(dbt="order_date"),
        )
        payload = with_dimension.model_dump(mode="python", exclude_none=True)
        assert payload["default_time_dimension"] == {"dbt": "order_date"}

        restored = SourceDefinition.model_validate(payload)
        expected = DefaultTimeDimensionDbt(dbt="order_date")
        assert restored.default_time_dimension == expected

    def test_dbt_structural_metadata_round_trips(self):
        """dbt constraints, enum values, tests, tags and freshness survive dump + validate."""
        status_column = SourceColumn(
            name="status",
            type="string",
            constraints={"dbt": {"not_null": True, "unique": True}},
            enum_values={"dbt": ["placed", "shipped"]},
            tests={
                "dbt": [{"name": "accepted_values", "package": "dbt"}],
                "dbt_by_package": {"dbt": ["accepted_values"]},
            },
        )
        source = SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[status_column],
            tags={"dbt": ["mart", "finance"]},
            freshness={
                "dbt": {
                    "loaded_at_field": "updated_at",
                    "raw": {"warn_after": {"count": 12, "period": "hour"}},
                }
            },
            default_time_dimension=DefaultTimeDimensionDbt(dbt="updated_at"),
        )

        # Nested dict inputs compare equal to the typed metadata models.
        assert source.columns[0].constraints == {
            "dbt": ColumnDbtConstraints(not_null=True, unique=True)
        }
        assert source.columns[0].enum_values == {"dbt": ["placed", "shipped"]}
        assert source.columns[0].tests is not None
        dumped_tests = source.columns[0].tests.model_dump(
            mode="python", exclude_none=True
        )
        assert dumped_tests == {
            "dbt": [{"name": "accepted_values", "package": "dbt"}],
            "dbt_by_package": {"dbt": ["accepted_values"]},
        }
        assert source.tags == {"dbt": ["mart", "finance"]}
        assert source.freshness == {
            "dbt": FreshnessDbt(
                loaded_at_field="updated_at",
                raw={"warn_after": {"count": 12, "period": "hour"}},
            )
        }

        payload = source.model_dump(mode="python", exclude_none=True)
        restored = SourceDefinition.model_validate(payload)
        assert restored.columns[0].constraints == source.columns[0].constraints
        assert restored.columns[0].enum_values == source.columns[0].enum_values
        assert restored.columns[0].tests == source.columns[0].tests
        assert restored.tags == source.tags
        assert restored.freshness == source.freshness
class TestSemanticQuery:
    """Basic construction of SemanticQuery."""

    def test_minimal(self):
        """With only measures given, dimensions/filters are empty and limit is 1000."""
        query = SemanticQuery(measures=["sum(orders.amount)"])
        assert query.dimensions == []
        assert query.filters == []
        assert query.limit == 1000

    def test_mixed_measures(self):
        """String and dict measure specs keep their original Python types."""
        named_measure = {"expr": "sum(orders.amount)", "name": "total"}
        query = SemanticQuery(measures=["orders.revenue", named_measure])
        first, second = query.measures[0], query.measures[1]
        assert isinstance(first, str)
        assert isinstance(second, dict)

    def test_with_dimensions(self):
        """Plain and granular dimension specs are both retained."""
        granular = {"field": "orders.created_at", "granularity": "month"}
        query = SemanticQuery(
            measures=["sum(orders.amount)"],
            dimensions=["orders.status", granular],
        )
        assert len(query.dimensions) == 2
class TestResolvedModels:
    """Construction and defaults of the resolved-plan models."""

    @staticmethod
    def _plan(columns):
        """Build a single-source plan over ``orders`` with no joins or filters."""
        return ResolvedPlan(
            sources_used=["orders"],
            join_paths=[],
            anchor_grain=["id"],
            fan_out_description="none",
            aggregate_locality=[],
            where_filters=[],
            having_filters=[],
            columns=columns,
        )

    def test_resolved_column(self):
        """The provenance passed in is the provenance read back."""
        column = ResolvedColumn(
            name="revenue", provenance=Provenance.VERIFIED, expr="sum(amount)"
        )
        assert column.provenance == Provenance.VERIFIED

    def test_resolved_measure(self):
        """A measure defaults to COMPOSED provenance and is not derived."""
        measure = ResolvedMeasure(
            name="revenue", expr="sum(amount)", source_name="orders"
        )
        assert measure.provenance == Provenance.COMPOSED
        assert not measure.is_derived

    def test_measure_group(self):
        """A measure group records the source it belongs to."""
        measure = ResolvedMeasure(name="rev", expr="sum(amount)", source_name="orders")
        group = MeasureGroup(source_name="orders", measures=[measure])
        assert group.source_name == "orders"

    def test_resolved_plan(self):
        """has_fan_out defaults to False and measure_groups to an empty list."""
        revenue = ResolvedColumn(name="revenue", provenance=Provenance.COMPOSED)
        plan = self._plan([revenue])
        assert plan.has_fan_out is False
        assert plan.measure_groups == []

    def test_query_result(self):
        """A query result exposes the dialect it was built with."""
        outcome = QueryResult(
            resolved_plan=self._plan([]),
            sql="SELECT 1",
            dialect="postgres",
            columns=[],
        )
        assert outcome.dialect == "postgres"
class TestJoinDeclaration:
    """Alias handling on JoinDeclaration."""

    def test_with_alias(self):
        """An explicit alias is stored without changing the join target."""
        from semantic_layer.models import JoinDeclaration

        declaration = JoinDeclaration(
            to="customers",
            on="billing_customer_id = customers.id",
            relationship="many_to_one",
            alias="billing_customer",
        )
        assert declaration.alias == "billing_customer"
        assert declaration.to == "customers"

    def test_without_alias(self):
        """The alias is None when it is not supplied."""
        from semantic_layer.models import JoinDeclaration

        declaration = JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )
        assert declaration.alias is None
class TestMeasureDefinition:
    """Optional fields on MeasureDefinition."""

    def test_with_filter_and_description(self):
        """Filter and description are stored as given."""
        from semantic_layer.models import MeasureDefinition

        measure = MeasureDefinition(
            name="revenue",
            expr="sum(amount)",
            filter="status != 'refunded'",
            description="Net revenue excluding refunds",
        )
        assert measure.filter == "status != 'refunded'"
        assert measure.description == "Net revenue excluding refunds"

    def test_minimal(self):
        """Filter and description are None when omitted."""
        from semantic_layer.models import MeasureDefinition

        measure = MeasureDefinition(name="total", expr="count(id)")
        assert measure.filter is None
        assert measure.description is None
class TestSemanticQueryExtended:
    """include_empty, order_by and limit on SemanticQuery."""

    def test_include_empty_default(self):
        """include_empty is True unless specified."""
        query = SemanticQuery(measures=["sum(orders.amount)"])
        assert query.include_empty is True

    def test_include_empty_false(self):
        """include_empty=False is honoured."""
        query = SemanticQuery(measures=["sum(orders.amount)"], include_empty=False)
        assert query.include_empty is False

    def test_with_order_by(self):
        """An order_by entry is kept and its direction is readable by key."""
        ordering = {"field": "orders.amount", "direction": "desc"}
        query = SemanticQuery(measures=["sum(orders.amount)"], order_by=[ordering])
        assert len(query.order_by) == 1
        assert query.order_by[0]["direction"] == "desc"

    def test_custom_limit(self):
        """An explicit limit replaces the default."""
        query = SemanticQuery(measures=["sum(orders.amount)"], limit=50)
        assert query.limit == 50
# ── From test_edge_cases.py ──────────────────────────────────────────
class TestModelEdgeCases:
    """Edge cases for SemanticQuery, SourceDefinition and MeasureDefinition."""

    def test_semantic_query_empty_measures(self):
        """An empty measures list is accepted and kept."""
        query = SemanticQuery(measures=[])
        assert query.measures == []

    def test_semantic_query_defaults(self):
        """Fields not passed to the constructor take the values asserted here."""
        query = SemanticQuery(measures=["sum(x.y)"])
        assert query.dimensions == []
        assert query.filters == []
        assert query.order_by == []
        assert query.limit == 1000
        assert query.include_empty is True

    def test_semantic_query_with_order_by(self):
        """One order_by entry in, one order_by entry out."""
        ordering = [{"field": "orders.status", "direction": "desc"}]
        query = SemanticQuery(measures=["sum(orders.amount)"], order_by=ordering)
        assert len(query.order_by) == 1

    def test_table_and_sql_mutually_exclusive(self):
        """Supplying both ``table`` and ``sql`` raises ValidationError."""
        id_column = SourceColumn(name="id", type="number")
        with pytest.raises(ValidationError, match="mutually exclusive"):
            SourceDefinition(
                name="bad",
                table="t",
                sql="SELECT 1",
                grain=["id"],
                columns=[id_column],
            )

    def test_empty_grain_rejected(self):
        """An empty grain list raises ValidationError."""
        id_column = SourceColumn(name="id", type="number")
        with pytest.raises(ValidationError, match="grain must be non-empty"):
            SourceDefinition(
                name="bad",
                table="t",
                grain=[],
                columns=[id_column],
            )

    def test_measure_definition_with_filter(self):
        """The filter string on a measure is kept as given."""
        from semantic_layer.models import MeasureDefinition

        measure = MeasureDefinition(
            name="rev",
            expr="sum(amount)",
            filter="status != 'refunded'",
        )
        assert measure.filter == "status != 'refunded'"

View file

@ -0,0 +1,279 @@
from semantic_layer.parser import ExpressionParser
# Module-level parser built with default constructor arguments; the test
# classes below call it directly.
parser = ExpressionParser()
class TestAggregateDetection:
    """Which expressions are flagged as aggregates, and which function is reported."""

    @staticmethod
    def _expect_aggregate(expression, function_name):
        """Parse *expression*; require an aggregate reported as *function_name*."""
        parsed = parser.parse(expression)
        assert parsed.is_aggregate
        assert parsed.aggregate_function == function_name

    def test_sum(self):
        self._expect_aggregate("sum(orders.amount)", "sum")

    def test_avg(self):
        self._expect_aggregate("avg(score)", "avg")

    def test_count(self):
        self._expect_aggregate("count(orders.id)", "count")

    def test_count_distinct(self):
        self._expect_aggregate("count_distinct(orders.customer_id)", "count_distinct")

    def test_non_aggregate(self):
        parsed = parser.parse("orders.revenue")
        assert not parsed.is_aggregate
        assert parsed.aggregate_function is None

    def test_multiple_aggregates(self):
        # With several aggregates present, the first one found is reported.
        self._expect_aggregate("sum(orders.amount) / count(orders.id)", "sum")

    def test_aggregate_in_scalar_subquery_not_aggregate(self):
        # `col = (SELECT MAX(col) FROM t)` is a plain column predicate, not HAVING-bound
        parsed = parser.parse(
            "orders.created_at = (SELECT MAX(created_at) FROM orders)"
        )
        assert not parsed.is_aggregate
        assert parsed.aggregate_function is None

    def test_aggregate_in_in_subquery_not_aggregate(self):
        parsed = parser.parse("orders.id IN (SELECT COUNT(id) FROM orders)")
        assert not parsed.is_aggregate

    def test_custom_agg_in_subquery_not_aggregate(self):
        expression = (
            "orders.customer_id = (SELECT count_distinct(customer_id) FROM orders)"
        )
        parsed = parser.parse(expression)
        assert not parsed.is_aggregate

    def test_outer_aggregate_with_inner_subquery_still_aggregate(self):
        # Outer SUM on a plain column, even if subquery appears elsewhere
        self._expect_aggregate(
            "sum(orders.amount) > (SELECT AVG(amount) FROM orders)", "sum"
        )
class TestSourceRefs:
    """source_refs / column_refs extracted from dotted ``source.column`` references."""

    def test_single_ref(self):
        parsed = parser.parse("sum(orders.amount)")
        assert parsed.source_refs == {"orders"}
        assert parsed.column_refs == {"orders.amount"}

    def test_multiple_refs(self):
        parsed = parser.parse("sum(orders.revenue) / count(customers.id)")
        assert parsed.source_refs == {"orders", "customers"}
        assert parsed.column_refs == {"orders.revenue", "customers.id"}

    def test_pre_defined_ref(self):
        parsed = parser.parse("orders.revenue")
        assert parsed.source_refs == {"orders"}
        assert parsed.column_refs == {"orders.revenue"}

    def test_no_refs(self):
        # Bare names without a source prefix produce no references.
        parsed = parser.parse("total_rev - total_cost")
        assert parsed.source_refs == set()
        assert parsed.column_refs == set()

    def test_mixed_refs(self):
        parsed = parser.parse("sum(orders.amount) + churn_risk.score")
        assert parsed.source_refs == {"orders", "churn_risk"}
class TestDerivedMeasures:
    """depends_on_measures resolution against a set of known measure names."""

    def test_depends_on_known_measures(self):
        known = {"total_rev", "total_cost"}
        parsed = parser.parse("total_rev - total_cost", known_measure_names=known)
        assert parsed.depends_on_measures == {"total_rev", "total_cost"}
        assert not parsed.is_aggregate

    def test_no_false_positives(self):
        # "sum" should not be detected as a measure dependency
        known = {"sum"}
        parsed = parser.parse("sum(orders.amount)", known_measure_names=known)
        assert parsed.depends_on_measures == set()

    def test_mixed_ref_and_derived(self):
        known = {"total_rev"}
        parsed = parser.parse(
            "total_rev / count(orders.id)", known_measure_names=known
        )
        assert parsed.depends_on_measures == {"total_rev"}
        assert parsed.is_aggregate

    def test_empty_known_measures(self):
        # No known names supplied, so nothing can be a dependency.
        parsed = parser.parse("total_rev - total_cost")
        assert parsed.depends_on_measures == set()
class TestExtractSourceRefs:
    """extract_source_refs returns the set of source names in an expression."""

    def test_basic(self):
        found = parser.extract_source_refs("sum(orders.amount)")
        assert found == {"orders"}

    def test_multiple(self):
        found = parser.extract_source_refs("orders.amount + customers.score")
        assert found == {"orders", "customers"}

    def test_no_refs(self):
        found = parser.extract_source_refs("count(*)")
        assert found == set()
class TestEdgeCases:
    """Parser edge cases: two-argument aggregates, string literals, compound math."""

    def test_percentile(self):
        parsed = parser.parse("percentile(churn_risk.score, 0.9)")
        assert parsed.is_aggregate
        assert parsed.aggregate_function == "percentile"
        assert parsed.source_refs == {"churn_risk"}

    def test_string_literal_not_detected(self):
        # "status != 'refunded'" — 'refunded' should not be a source ref
        parsed = parser.parse("status != 'refunded'")
        assert parsed.source_refs == set()

    def test_complex_expression(self):
        parsed = parser.parse("sum(orders.amount) / count(orders.id) * 100")
        assert parsed.is_aggregate
        assert parsed.source_refs == {"orders"}
        assert parsed.column_refs == {"orders.amount", "orders.id"}
class TestAdditionalAggregates:
    """Further aggregate functions plus nested-function and multi-source expressions."""

    def test_min(self):
        r = parser.parse("min(orders.amount)")
        assert r.is_aggregate
        assert r.aggregate_function == "min"

    def test_max(self):
        r = parser.parse("max(orders.amount)")
        assert r.is_aggregate
        assert r.aggregate_function == "max"

    def test_median(self):
        r = parser.parse("median(orders.amount)")
        assert r.is_aggregate
        assert r.aggregate_function == "median"

    def test_nested_function_not_aggregate(self):
        """abs() is not an aggregate function, but sum() wrapping it is."""
        # The expression must actually nest abs() inside sum(); parsing a bare
        # sum(orders.amount) would not exercise the nested-function case.
        r = parser.parse("sum(abs(orders.amount))")
        assert r.is_aggregate
        # The reported function is the outer aggregate, not the scalar abs().
        assert r.aggregate_function == "sum"
        assert r.source_refs == {"orders"}

    def test_comparison_operators(self):
        """Filter-like expression with comparison."""
        r = parser.parse("orders.status = 'completed'")
        assert not r.is_aggregate
        assert r.source_refs == {"orders"}

    def test_multiple_source_column_refs(self):
        """Expression referencing columns from 3 different sources."""
        r = parser.parse(
            "sum(orders.amount) + count(customers.id) - avg(tickets.score)"
        )
        assert r.is_aggregate
        assert r.source_refs == {"orders", "customers", "tickets"}
# ── From test_edge_cases.py: TestExpressionParserEdgeCases ───────────
class TestExpressionParserEdgeCases:
    """Edge inputs: empty string, ``count(*)``, quoted references, underscored names."""

    def test_empty_string(self):
        parsed = parser.parse("")
        assert parsed.source_refs == set()
        assert parsed.column_refs == set()
        assert not parsed.is_aggregate

    def test_count_star(self):
        parsed = parser.parse("count(*)")
        assert parsed.is_aggregate
        assert parsed.aggregate_function == "count"
        assert parsed.source_refs == set()

    def test_multiple_aggregate_functions(self):
        parsed = parser.parse("sum(orders.amount) + avg(orders.cost)")
        assert parsed.is_aggregate
        assert parsed.aggregate_function == "sum"
        assert parsed.source_refs == {"orders"}

    def test_nested_function_not_aggregate(self):
        # lower() is a scalar function, so the expression is not an aggregate.
        parsed = parser.parse("lower(orders.status)")
        assert not parsed.is_aggregate
        assert parsed.source_refs == {"orders"}

    def test_source_ref_in_string_literal(self):
        # A dotted name inside quotes is a literal, not a reference.
        parsed = parser.parse("'orders.amount'")
        assert "orders" not in parsed.source_refs
        assert len(parsed.column_refs) == 0

    def test_underscore_names(self):
        parsed = parser.parse("sum(order_items.unit_price)")
        assert "order_items" in parsed.source_refs
        assert "order_items.unit_price" in parsed.column_refs

    def test_extract_source_refs_multi(self):
        found = parser.extract_source_refs("orders.amount + customers.score")
        assert found == {"orders", "customers"}
class TestReservedWordHandling:
    """LIMIT 4: Reserved SQL keywords as source or column names."""

    def test_reserved_word_source_name(self):
        """Parse 'sum(where.value)' where 'where' is a source name."""
        parsed = parser.parse("sum(where.value)")
        assert parsed.source_refs == {"where"}
        assert parsed.column_refs == {"where.value"}
        assert parsed.is_aggregate

    def test_reserved_word_column_name(self):
        """Parse 'select.from' where both are reserved words."""
        parsed = parser.parse("select.from")
        assert parsed.source_refs == {"select"}
        assert parsed.column_refs == {"select.from"}

    def test_reserved_word_in_extract_source_refs(self):
        """extract_source_refs should handle reserved words in expressions."""
        found = parser.extract_source_refs("where.value > 10")
        assert found == {"where"}
def test_extract_source_refs_bigquery_native():
    """BigQuery-native filter must not drop source refs due to mis-parse."""
    from semantic_layer.parser import ExpressionParser

    bigquery_parser = ExpressionParser(dialect="bigquery")
    expression = "SAFE_DIVIDE(orders.revenue, customers.count) > 0"
    found = bigquery_parser.extract_source_refs(expression)
    assert found == {"orders", "customers"}
def test_expression_parser_dialect_defaults_to_postgres():
    """Constructor default is postgres — keeps existing tests working."""
    from semantic_layer.parser import ExpressionParser

    default_parser = ExpressionParser()
    assert default_parser.dialect == "postgres"
def test_extract_source_refs_postgres_baseline():
    """Postgres-dialect parser continues to work on postgres syntax."""
    from semantic_layer.parser import ExpressionParser

    postgres_parser = ExpressionParser(dialect="postgres")
    expression = "orders.created_at >= current_date - interval '30 days'"
    found = postgres_parser.extract_source_refs(expression)
    assert found == {"orders"}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,293 @@
"""Tests for named segments — reusable boolean predicates on a source.
Segments are AND-ed into the measure's effective filter via the same CASE WHEN
pathway used by `measure.filter`. They never become a global WHERE clause.
"""
from __future__ import annotations
import pytest
from .conftest import assert_valid_sql, make_engine
def _orders_source(**overrides):
base = {
"name": "orders",
"table": "public.orders",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "amount", "type": "number"},
{"name": "is_paid", "type": "boolean"},
{"name": "is_refunded", "type": "string"},
{"name": "customer_id", "type": "number"},
],
"segments": [
{
"name": "paid_non_refunded",
"expr": "is_paid = true and is_refunded = '0'",
"description": "Settled, not reversed.",
},
],
"measures": [
{
"name": "total_revenue",
"expr": "sum(amount)",
"segments": ["paid_non_refunded"],
},
],
}
base.update(overrides)
return base
def _customers_source(**overrides):
base = {
"name": "customers",
"table": "public.customers",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "is_vip", "type": "boolean"},
],
"measures": [
{"name": "customer_count", "expr": "count(distinct id)"},
],
}
base.update(overrides)
return base
# ── Composition + golden SQL shape ───────────────────────────────────
class TestSegmentComposition:
    """How measure-bound and query-time segments are composed into the generated SQL."""

    def test_measure_segment_lands_in_case_when_wrap(self):
        """A measure-bound segment shows up inside CASE WHEN, not as a global WHERE."""
        engine = make_engine({"orders": _orders_source()})
        result = engine.query({"measures": ["orders.total_revenue"]})
        assert_valid_sql(result.sql)
        sql_upper = result.sql.upper()
        # Filter must be inside CASE WHEN (the measure-filter pathway)
        assert "CASE WHEN" in sql_upper
        assert "is_paid" in result.sql.lower()
        assert "is_refunded" in result.sql.lower()
        # Should NOT show up as a global WHERE. Collapse every whitespace run
        # first so a line break or extra indentation between WHERE and the
        # predicate cannot hide a match (the previous replace() was a no-op).
        normalized = " ".join(sql_upper.split())
        assert "WHERE IS_PAID" not in normalized

    def test_measure_filter_and_segment_both_applied(self):
        """A measure's own filter and its segment are both present and AND-ed."""
        src = _orders_source()
        src["measures"][0]["filter"] = "amount > 0"
        engine = make_engine({"orders": src})
        result = engine.query({"measures": ["orders.total_revenue"]})
        assert_valid_sql(result.sql)
        sql_lower = result.sql.lower()
        # Both predicates appear inside the measure's CASE WHEN wrap
        assert "amount > 0" in sql_lower
        assert "is_paid" in sql_lower
        assert "is_refunded" in sql_lower
        # AND composition: ensure both halves are joined
        assert " and " in sql_lower

    def test_query_time_segment_applies_to_measure(self):
        """A segment named only in the query is applied to the measure."""
        # Measure has no measure-bound segment; segment is applied at query time.
        src = _orders_source()
        src["measures"] = [{"name": "raw_revenue", "expr": "sum(amount)"}]
        engine = make_engine({"orders": src})
        result = engine.query(
            {
                "measures": ["orders.raw_revenue"],
                "segments": ["orders.paid_non_refunded"],
            }
        )
        assert_valid_sql(result.sql)
        sql_lower = result.sql.lower()
        assert "case when" in sql_lower
        assert "is_paid" in sql_lower
        assert "is_refunded" in sql_lower

    def test_measure_and_query_segments_compose(self):
        """Measure-bound and query-time segments are all present in the SQL."""
        # Measure has paid_non_refunded; query adds 'high_value'.
        src = _orders_source()
        src["segments"].append(
            {"name": "high_value", "expr": "amount >= 100"},
        )
        engine = make_engine({"orders": src})
        result = engine.query(
            {
                "measures": ["orders.total_revenue"],
                "segments": ["orders.high_value"],
            }
        )
        assert_valid_sql(result.sql)
        sql_lower = result.sql.lower()
        # All three predicates present
        assert "is_paid" in sql_lower
        assert "is_refunded" in sql_lower
        assert "amount >= 100" in sql_lower
# ── Multi-source query: scope is per-measure, not global ─────────────
class TestSegmentMultiSourceScope:
    """A segment on one source must not filter measures of another source."""

    def test_segment_does_not_apply_to_other_source_measures(self):
        """An orders segment leaves the customers COUNT(DISTINCT ...) unfiltered."""
        import re

        # Orders joins many-to-one to customers and carries an unsegmented measure.
        orders = _orders_source(
            joins=[
                {
                    "to": "customers",
                    "on": "customer_id = customers.id",
                    "relationship": "many_to_one",
                }
            ],
            measures=[
                {"name": "raw_revenue", "expr": "sum(amount)"},
            ],
        )
        engine = make_engine({"orders": orders, "customers": _customers_source()})

        query = {
            "measures": [
                "orders.raw_revenue",
                "customers.customer_count",
            ],
            "segments": ["orders.paid_non_refunded"],
        }
        result = engine.query(query)
        assert_valid_sql(result.sql)

        # The segment predicate did land somewhere (on the orders measure).
        assert "is_paid" in result.sql.lower()

        # Heuristic for "not applied to customers": every COUNT(DISTINCT ...)
        # aggregate without nested parentheses must be free of the predicate.
        count_distinct = re.compile(
            r"COUNT\s*\(\s*DISTINCT[^()]*\)", flags=re.IGNORECASE
        )
        count_aggs = count_distinct.findall(result.sql)
        assert count_aggs, "expected at least one COUNT(DISTINCT ...) aggregate"
        for agg in count_aggs:
            assert "is_paid" not in agg.lower(), (
                f"customer_count aggregate must not be filtered by segment: {agg}"
            )
# ── Error cases ──────────────────────────────────────────────────────
class TestSegmentErrors:
    """Invalid segment references raise ValueError with a recognisable message."""

    @staticmethod
    def _default_engine():
        """Engine over the unmodified orders source."""
        return make_engine({"orders": _orders_source()})

    def test_unknown_bare_name_in_measure_segments(self):
        """A measure naming a segment its source does not define is rejected."""
        orders = _orders_source()
        orders["measures"][0]["segments"] = ["does_not_exist"]
        engine = make_engine({"orders": orders})
        with pytest.raises(ValueError, match="unknown segment 'does_not_exist'"):
            engine.query({"measures": ["orders.total_revenue"]})

    def test_unknown_query_time_segment_name(self):
        """A query-time segment missing from a known source is rejected."""
        engine = self._default_engine()
        query = {
            "measures": ["orders.total_revenue"],
            "segments": ["orders.does_not_exist"],
        }
        with pytest.raises(ValueError, match="unknown segment 'does_not_exist'"):
            engine.query(query)

    def test_unknown_query_time_segment_source(self):
        """A query-time segment on an unknown source is rejected."""
        engine = self._default_engine()
        query = {
            "measures": ["orders.total_revenue"],
            "segments": ["no_such_source.foo"],
        }
        with pytest.raises(ValueError, match="unknown source 'no_such_source'"):
            engine.query(query)

    def test_query_time_segment_must_be_dotted(self):
        """A query-time segment without a source prefix is rejected."""
        engine = self._default_engine()
        query = {
            "measures": ["orders.total_revenue"],
            "segments": ["paid_non_refunded"],  # missing source prefix
        }
        with pytest.raises(ValueError, match="dotted"):
            engine.query(query)

    def test_no_op_query_time_segment_errors(self):
        """A segment whose source has no measure in the query is rejected."""
        # Segment on customers, but no customers measure in the query.
        orders = _orders_source(
            joins=[
                {
                    "to": "customers",
                    "on": "customer_id = customers.id",
                    "relationship": "many_to_one",
                }
            ],
            measures=[{"name": "raw_revenue", "expr": "sum(amount)"}],
        )
        customers = _customers_source(
            segments=[{"name": "vips", "expr": "is_vip = true"}]
        )
        engine = make_engine({"orders": orders, "customers": customers})
        query = {
            "measures": ["orders.raw_revenue"],
            "segments": ["customers.vips"],
        }
        with pytest.raises(ValueError, match="no matching"):
            engine.query(query)
def test_bigquery_native_segment_referenced_by_measure(make_engine_factory):
    """Segment authored in BigQuery dialect, referenced by a measure,
    must not degrade the segment's native syntax when composed."""
    columns = [
        {"name": "id", "type": "number"},
        {"name": "status", "type": "string"},
        {"name": "ts", "type": "time"},
    ]
    segments = [
        {"name": "non_cancelled", "expr": "status != 'cancelled'"},
        {
            "name": "last_30",
            "expr": "ts >= timestamp(date_sub(current_date(), interval 30 day))",
        },
    ]
    dau = {
        "name": "dau",
        "expr": "count(distinct id)",
        "segments": ["non_cancelled", "last_30"],
    }
    source = {
        "name": "fct_orders",
        "table": "fct_orders",
        "grain": ["id"],
        "columns": columns,
        "segments": segments,
        "measures": [dau],
    }

    engine = make_engine_factory({"fct_orders": source}, dialect="bigquery")
    query = {"measures": ["fct_orders.dau"], "dimensions": [], "filters": []}
    sql = engine.query(query).sql

    # Passes when the quoted-number interval form is absent, or when DAY
    # appears somewhere in the SQL.
    assert "INTERVAL '30'" not in sql or "DAY" in sql.upper(), (
        f"INTERVAL unit lost in segment reference:\n{sql}"
    )

View file

@ -0,0 +1,470 @@
"""Comprehensive Snowflake dialect tests covering all major SQL generation code paths."""
from __future__ import annotations
from pathlib import Path
import pytest
import sqlglot
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import SourceColumn, SourceDefinition
# String path of the `sources/ecommerce` directory that sits beside this
# file's parent directory; handed to SemanticEngine by the sf_engine fixture.
SOURCES_DIR = str(Path(__file__).parent.parent / "sources" / "ecommerce")
def assert_valid_snowflake_sql(sql: str):
    """Assert SQL parses as valid Snowflake SQL.

    Fails the current test via ``pytest.fail`` when sqlglot raises while
    parsing, or when the parse yields no statements or a ``None`` statement.
    """
    try:
        statements = sqlglot.parse(sql, read="snowflake")
    except Exception as e:
        pytest.fail(f"SQL is not valid Snowflake: {e}\n\nSQL:\n{sql}")

    # Checked outside the try block so an empty parse gets its own message
    # instead of being reported as a parser exception with no useful text.
    if not statements or any(statement is None for statement in statements):
        pytest.fail(
            "SQL is not valid Snowflake: parser returned no statements "
            f"or an empty statement\n\nSQL:\n{sql}"
        )
@pytest.fixture
def sf_engine():
    """Snowflake-dialect engine loaded from SOURCES_DIR."""
    engine = SemanticEngine(SOURCES_DIR, dialect="snowflake")
    return engine
@pytest.fixture
def chasm_engine():
    """Engine with hub + two fact tables for chasm trap / aggregate locality tests."""

    def number(column_name):
        """Numeric column with the given name."""
        return SourceColumn(name=column_name, type="number")

    def join_to_hub():
        """Fresh many-to-one join dict so the fact sources never share one object."""
        return {"to": "hub", "on": "hub_id = hub.id", "relationship": "many_to_one"}

    hub = SourceDefinition(
        name="hub",
        table="public.hub",
        grain=["id"],
        columns=[number("id"), SourceColumn(name="segment", type="string")],
    )
    fact_a = SourceDefinition(
        name="fact_a",
        table="public.fact_a",
        grain=["id"],
        columns=[
            number("id"),
            number("hub_id"),
            number("val"),
            SourceColumn(name="created_at", type="time"),
        ],
        joins=[join_to_hub()],
    )
    # fact_b additionally carries a predefined, filtered measure.
    fact_b = SourceDefinition(
        name="fact_b",
        table="public.fact_b",
        grain=["id"],
        columns=[number("id"), number("hub_id"), number("val")],
        joins=[join_to_hub()],
        measures=[{"name": "total_val", "expr": "sum(val)", "filter": "val > 0"}],
    )
    sources = {"hub": hub, "fact_a": fact_a, "fact_b": fact_b}
    return SemanticEngine.from_sources(sources, dialect="snowflake")
# ── Basic query patterns ─────────────────────────────────────────────
class TestSnowflakeBasic:
    """Basic query shapes rendered for the Snowflake dialect."""

    def test_simple_single_source(self, sf_engine):
        """One source, one dimension: result is tagged snowflake and grouped."""
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
        }
        outcome = sf_engine.query(query)
        generated = outcome.sql
        assert outcome.dialect == "snowflake"
        assert_valid_snowflake_sql(generated)
        assert "GROUP BY" in generated

    def test_cross_source_m2o(self, sf_engine):
        """Dimensions from other sources bring a JOIN into the SQL."""
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["customers.segment", "regions.name"],
        }
        generated = sf_engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        assert "JOIN" in generated

    def test_predefined_measure_with_filter(self, sf_engine):
        """The predefined orders.revenue measure renders a CASE WHEN on 'refunded'."""
        query = {
            "measures": ["orders.revenue"],
            "dimensions": ["orders.status"],
        }
        generated = sf_engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        assert "CASE WHEN" in generated
        assert "<>" in generated  # sqlglot transpiles != to <>
        assert "'refunded'" in generated

    def test_derived_measures(self, sf_engine):
        """A measure derived from two named measures appears along with both inputs."""
        named_measures = [
            {"expr": "sum(orders.amount)", "name": "total_rev"},
            {"expr": "sum(orders.cost)", "name": "total_cost"},
            {"expr": "total_rev - total_cost", "name": "profit"},
        ]
        query = {"measures": named_measures, "dimensions": ["customers.segment"]}
        generated = sf_engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        assert "profit" in generated.lower()
        assert "total_rev" in generated
        assert "total_cost" in generated

    def test_include_empty_false(self, sf_engine):
        """include_empty=True yields a LEFT JOIN; include_empty=False does not."""

        def run(include_empty):
            return sf_engine.query(
                {
                    "measures": ["sum(orders.amount)"],
                    "dimensions": ["customers.segment"],
                    "include_empty": include_empty,
                }
            )

        result_left = run(True)
        result_inner = run(False)
        assert_valid_snowflake_sql(result_left.sql)
        assert_valid_snowflake_sql(result_inner.sql)
        assert "LEFT JOIN" in result_left.sql.upper()
        assert "LEFT JOIN" not in result_inner.sql.upper()
# ── Time granularity ─────────────────────────────────────────────────
class TestSnowflakeTimeGranularity:
    """Time-dimension granularity rendering for Snowflake."""

    @pytest.mark.parametrize("granularity", ["day", "week", "month", "quarter", "year"])
    def test_date_trunc_uppercase(self, sf_engine, granularity):
        """Each granularity appears upper-cased as the first DATE_TRUNC argument."""
        time_dimension = {"field": "orders.created_at", "granularity": granularity}
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": [time_dimension],
        }
        generated = sf_engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        # Snowflake DATE_TRUNC uses uppercase granularity
        expected_prefix = f"DATE_TRUNC('{granularity.upper()}'"
        assert expected_prefix in generated
# ── Filters ──────────────────────────────────────────────────────────
class TestSnowflakeFilters:
    """Routing of query filters to WHERE and HAVING."""

    def test_having_filter(self, sf_engine):
        """A filter on an aggregate produces HAVING with its threshold."""
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": ["sum(orders.amount) > 10000"],
        }
        generated = sf_engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        assert "HAVING" in generated
        assert "10000" in generated

    def test_where_and_having(self, sf_engine):
        """A column filter plus an aggregate filter yield both WHERE and HAVING."""
        filters = [
            "orders.status != 'cancelled'",
            "sum(orders.amount) > 1000",
        ]
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": filters,
        }
        generated = sf_engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        assert "WHERE" in generated
        assert "HAVING" in generated
# ── SQL sources / CTEs ───────────────────────────────────────────────
class TestSnowflakeCTE:
    """SQL-defined sources rendered as CTEs under the Snowflake dialect."""

    def test_sql_source_as_cte(self, sf_engine):
        """A SQL source appears in a WITH clause under its own name."""
        result = sf_engine.query(
            {
                "measures": ["avg(churn_risk.score)"],
                "dimensions": ["churn_risk.customer_type"],
            }
        )
        sql = result.sql
        assert_valid_snowflake_sql(sql)
        assert "WITH" in sql
        assert "churn_risk" in sql

    def test_cross_source_with_sql_source(self, sf_engine):
        """A SQL source combined with another source yields WITH and JOIN."""
        result = sf_engine.query(
            {
                "measures": ["avg(churn_risk.score)"],
                "dimensions": ["regions.name"],
            }
        )
        sql = result.sql
        assert_valid_snowflake_sql(sql)
        assert "WITH" in sql
        assert "JOIN" in sql

    def test_sql_source_with_datediff(self):
        """DATEDIFF in SQL source must survive transpilation (not become AGE)."""
        import re

        sources = {
            "cohorts": SourceDefinition(
                name="cohorts",
                sql="SELECT id, DATEDIFF(WEEK, start_date, end_date) AS n FROM t",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="n", type="number"),
                ],
            ),
        }
        engine = SemanticEngine.from_sources(sources, dialect="snowflake")
        result = engine.query({"measures": ["sum(cohorts.n)"], "dimensions": []})
        assert_valid_snowflake_sql(result.sql)
        sql_upper = result.sql.upper()
        assert "DATEDIFF" in sql_upper
        # Look for an AGE(...) call rather than the bare substring "AGE", which
        # would also match unrelated identifiers such as PAGE or AVERAGE.
        assert not re.search(r"\bAGE\s*\(", sql_upper)

    def test_sql_source_with_datediff_in_ctes(self):
        """DATEDIFF inside inner CTEs must survive CTE promotion."""
        sources = {
            "retention": SourceDefinition(
                name="retention",
                sql=(
                    "WITH spine AS ("
                    " SELECT DISTINCT cohort_week,"
                    " DATEDIFF(WEEK, cohort_week, period_week) AS n"
                    " FROM adopters"
                    ") SELECT cohort_week, n, COUNT(*) AS cnt FROM spine GROUP BY 1, 2"
                ),
                grain=["cohort_week", "n"],
                columns=[
                    SourceColumn(name="cohort_week", type="time"),
                    SourceColumn(name="n", type="number"),
                    SourceColumn(name="cnt", type="number"),
                ],
            ),
        }
        engine = SemanticEngine.from_sources(sources, dialect="snowflake")
        result = engine.query(
            {"measures": ["sum(retention.cnt)"], "dimensions": ["retention.n"]}
        )
        assert_valid_snowflake_sql(result.sql)
        assert "DATEDIFF" in result.sql.upper()
        # Inner CTE should be promoted with prefix
        assert "retention__spine" in result.sql
# ── Aggregate functions ──────────────────────────────────────────────
class TestSnowflakeAggregateFunctions:
    """Rendering of median, percentile and count_distinct for Snowflake."""

    @staticmethod
    def _sql_by_status(engine, measures):
        """Query *measures* grouped by orders.status; return the validated SQL."""
        generated = engine.query(
            {"measures": measures, "dimensions": ["orders.status"]}
        ).sql
        assert_valid_snowflake_sql(generated)
        return generated

    def test_median_percentile_cont(self, sf_engine):
        """median() renders as PERCENTILE_CONT ... WITHIN GROUP."""
        generated = self._sql_by_status(
            sf_engine, [{"expr": "median(orders.amount)", "name": "median_order"}]
        )
        assert "PERCENTILE_CONT" in generated
        assert "WITHIN GROUP" in generated

    def test_percentile(self, sf_engine):
        """percentile() renders as PERCENTILE_CONT and keeps its fraction."""
        generated = self._sql_by_status(
            sf_engine, [{"expr": "percentile(orders.amount, 0.9)", "name": "p90"}]
        )
        assert "PERCENTILE_CONT" in generated
        assert "0.9" in generated

    def test_count_distinct(self, sf_engine):
        """count_distinct() renders as COUNT(DISTINCT ...)."""
        generated = self._sql_by_status(
            sf_engine, ["count_distinct(orders.customer_id)"]
        )
        assert "COUNT(DISTINCT" in generated
# ── Aggregate locality / chasm traps ─────────────────────────────────
class TestSnowflakeAggregateLocality:
    """Chasm-trap queries: measures from two fact sources sliced by a hub dimension."""

    @staticmethod
    def _sql_by_segment(engine, measures, **extra):
        """Query *measures* by hub.segment (plus any extra keys); return validated SQL."""
        query = {"measures": measures, "dimensions": ["hub.segment"]}
        query.update(extra)
        generated = engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        return generated

    def test_chasm_trap_full_join(self, chasm_engine):
        """Per-fact aggregates are combined with FULL JOIN and COALESCE."""
        generated = self._sql_by_segment(
            chasm_engine, ["sum(fact_a.val)", "sum(fact_b.val)"]
        )
        assert "FULL JOIN" in generated.upper()
        assert "COALESCE" in generated.upper()
        assert "fact_a_agg" in generated
        assert "fact_b_agg" in generated

    def test_chasm_trap_predefined_filtered_measure(self, chasm_engine):
        """The filtered predefined measure keeps its CASE WHEN and name."""
        generated = self._sql_by_segment(
            chasm_engine, ["sum(fact_a.val)", "fact_b.total_val"]
        )
        assert "CASE WHEN" in generated
        assert "total_val" in generated

    def test_chasm_trap_derived_measure(self, chasm_engine):
        """A sum of two per-fact measures is emitted by name, with COALESCE."""
        measures = [
            {"expr": "sum(fact_a.val)", "name": "total_a"},
            {"expr": "sum(fact_b.val)", "name": "total_b"},
            {"expr": "total_a + total_b", "name": "grand_total"},
        ]
        generated = self._sql_by_segment(chasm_engine, measures)
        assert "grand_total" in generated
        assert "COALESCE" in generated.upper()

    def test_chasm_trap_derived_ratio_nullif(self, chasm_engine):
        """A ratio of two per-fact measures is emitted with NULLIF."""
        measures = [
            {"expr": "sum(fact_a.val)", "name": "total_a"},
            {"expr": "sum(fact_b.val)", "name": "total_b"},
            {"expr": "total_a / total_b", "name": "ratio"},
        ]
        generated = self._sql_by_segment(chasm_engine, measures)
        assert "NULLIF" in generated.upper()
        assert "ratio" in generated

    def test_chasm_trap_having(self, chasm_engine):
        """An aggregate filter in locality mode shows up with a WHERE."""
        generated = self._sql_by_segment(
            chasm_engine,
            ["sum(fact_a.val)", "sum(fact_b.val)"],
            filters=["sum(fact_a.val) > 100"],
        )
        assert "100" in generated
        # HAVING in locality mode becomes WHERE on outer query
        assert "WHERE" in generated
# ── ORDER BY + LIMIT ─────────────────────────────────────────────────
class TestSnowflakeOrderByLimit:
    """ORDER BY direction and LIMIT rendering for Snowflake."""

    def test_order_by_desc_with_limit(self, sf_engine):
        """A descending order_by and a limit of 10 both reach the SQL."""
        ordering = [{"field": "sum(orders.amount)", "direction": "desc"}]
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "order_by": ordering,
            "limit": 10,
        }
        generated = sf_engine.query(query).sql
        assert_valid_snowflake_sql(generated)
        assert "DESC" in generated.upper()
        assert "LIMIT 10" in generated
# ── Snowflake reserved words as identifiers ──────────────────────────
class TestSnowflakeReservedWords:
    """Snowflake-specific reserved words (sample, qualify) must be quoted."""

    @staticmethod
    def _sum_sql(source, measure):
        """Build a one-source Snowflake engine, run *measure*, return validated SQL."""
        engine = SemanticEngine.from_sources({source.name: source}, dialect="snowflake")
        outcome = engine.query({"measures": [measure], "dimensions": []})
        assert_valid_snowflake_sql(outcome.sql)
        return outcome.sql

    @pytest.mark.parametrize("source_name", ["sample", "qualify"])
    def test_snowflake_reserved_word_as_source_name(self, source_name):
        """A source named after a reserved word still yields valid SUM SQL."""
        source = SourceDefinition(
            name=source_name,
            table=f"public.{source_name}s",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="val", type="number"),
            ],
        )
        generated = self._sum_sql(source, f"sum({source_name}.val)")
        assert "SUM" in generated.upper()

    @pytest.mark.parametrize("col_name", ["sample", "qualify"])
    def test_snowflake_reserved_word_as_column_name(self, col_name):
        """A column named after a reserved word still yields valid SUM SQL."""
        source = SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name=col_name, type="number"),
            ],
        )
        generated = self._sum_sql(source, f"sum(orders.{col_name})")
        assert "SUM" in generated.upper()

View file

@ -0,0 +1,296 @@
from __future__ import annotations
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import (
JoinDeclaration,
SourceColumn,
SourceDefinition,
)
from semantic_layer.sql_table_extractor import (
extract_table_refs,
normalize_table,
ref_matches_source_table,
)
def _table_src(
    name: str, table: str, columns: list[str] | None = None
) -> SourceDefinition:
    """Build a table-backed source with grain ``["id"]`` and numeric columns.

    When *columns* is ``None`` or empty, a single ``id`` column is used.
    """
    column_names = columns if columns else ["id"]
    numeric_columns = []
    for column_name in column_names:
        numeric_columns.append(SourceColumn(name=column_name, type="number"))
    return SourceDefinition(
        name=name,
        table=table,
        grain=["id"],
        columns=numeric_columns,
    )
def _sql_src(
    name: str,
    sql: str,
    columns: list[str] | None = None,
    joins: list[JoinDeclaration] | None = None,
) -> SourceDefinition:
    """Build a SQL-backed source with grain ``["id"]`` and numeric columns.

    When *columns* is ``None`` or empty, a single ``id`` column is used;
    when *joins* is ``None`` or empty, an empty join list is passed.
    """
    column_names = columns if columns else ["id"]
    declared_joins = joins if joins else []
    numeric_columns = []
    for column_name in column_names:
        numeric_columns.append(SourceColumn(name=column_name, type="number"))
    return SourceDefinition(
        name=name,
        sql=sql,
        grain=["id"],
        columns=numeric_columns,
        joins=declared_joins,
    )
class TestExtractTableRefs:
    """extract_table_refs: physical table references found in a SQL string."""

    def test_simple_select(self):
        """A three-part table name comes back as a three-element tuple."""
        found = extract_table_refs("select id from analytics.marts.listings")
        assert found == [("analytics", "marts", "listings")]

    def test_join_clause(self):
        """Tables from FROM and JOIN are returned in that order."""
        query_text = (
            "\n"
            "select l.id from analytics.marts.listings l\n"
            "join analytics.marts.accounts a on l.account_id = a.id\n"
        )
        expected = [
            ("analytics", "marts", "listings"),
            ("analytics", "marts", "accounts"),
        ]
        assert extract_table_refs(query_text) == expected

    def test_cte_alias_skipped(self):
        """A CTE name is not reported; the tables it and the main query read are."""
        query_text = (
            "\n"
            "with d as (select id from staging.shipments)\n"
            "select * from d join staging.items_shipments i on d.id = i.shipment_id\n"
        )
        # `d` is a CTE — must not appear. `staging.shipments` and
        # `staging.items_shipments` both should.
        found = extract_table_refs(query_text)
        assert ("staging", "shipments") in found
        assert ("staging", "items_shipments") in found
        assert all(ref != ("d",) for ref in found)

    def test_dedup(self):
        """The same table referenced twice is reported once."""
        query_text = (
            "\n"
            "select * from analytics.marts.listings l1\n"
            "join analytics.marts.listings l2 on l1.id = l2.id\n"
        )
        assert extract_table_refs(query_text) == [("analytics", "marts", "listings")]

    def test_unparseable_returns_empty(self):
        """Input that is not SQL yields an empty list."""
        assert extract_table_refs("not valid sql !!!") == []
class TestRefMatching:
    """normalize_table and ref_matches_source_table against a three-part table name."""

    def test_normalize_strips_quotes_and_lowercases(self):
        normalized = normalize_table('"ANALYTICS"."MARTS"."LISTINGS"')
        assert normalized == ("analytics", "marts", "listings")

    def test_full_match(self):
        reference = ("analytics", "marts", "listings")
        assert ref_matches_source_table(reference, "ANALYTICS.MARTS.LISTINGS")

    def test_two_part_suffix_matches_three_part_table(self):
        reference = ("marts", "listings")
        assert ref_matches_source_table(reference, "ANALYTICS.MARTS.LISTINGS")

    def test_bare_name_matches_three_part_table(self):
        reference = ("listings",)
        assert ref_matches_source_table(reference, "ANALYTICS.MARTS.LISTINGS")

    def test_db_mismatch_blocks_match(self):
        # Same table name, different qualifier in front of it.
        reference = ("staging", "listings")
        assert not ref_matches_source_table(reference, "ANALYTICS.MARTS.LISTINGS")

    def test_longer_ref_does_not_match_shorter_table(self):
        reference = ("analytics", "marts", "listings")
        assert not ref_matches_source_table(reference, "marts.listings")
class TestSqlJoinCoverage:
    """Join-coverage checks reported by ``engine.validate`` for SQL sources.

    Most tests add a SQL-backed source named ``my_source`` next to the
    table-backed ``LISTINGS`` / ``ACCOUNTS`` sources and then inspect
    ``report.errors``.
    """

    # Query touching both manifest tables; shared by the coverage tests.
    _LISTINGS_ACCOUNTS_SQL = """
select l.id, a.name
from ANALYTICS.MARTS.LISTINGS l
join ANALYTICS.MARTS.ACCOUNTS a on l.account_id = a.id
"""

    @staticmethod
    def _many_to_one(to: str, on: str) -> JoinDeclaration:
        """Return a ``many_to_one`` join declaration towards ``to``."""
        return JoinDeclaration(to=to, on=on, relationship="many_to_one")

    def _build_engine(
        self,
        listings_table: str = "ANALYTICS.MARTS.LISTINGS",
        accounts_table: str = "ANALYTICS.MARTS.ACCOUNTS",
        new_source_sql: str | None = None,
        new_source_joins: list[JoinDeclaration] | None = None,
    ) -> SemanticEngine:
        """Create an engine with LISTINGS, ACCOUNTS and optionally ``my_source``.

        ``my_source`` is only registered when ``new_source_sql`` is given.
        """
        catalog = {
            "LISTINGS": _table_src("LISTINGS", listings_table),
            "ACCOUNTS": _table_src("ACCOUNTS", accounts_table),
        }
        if new_source_sql is not None:
            catalog["my_source"] = _sql_src(
                "my_source", sql=new_source_sql, joins=new_source_joins
            )
        return SemanticEngine.from_sources(catalog)

    def test_coverage_gap_emitted_as_error(self):
        """SQL joining two modelled tables without joins[] is invalid."""
        engine = self._build_engine(
            new_source_sql=self._LISTINGS_ACCOUNTS_SQL, new_source_joins=[]
        )
        report = engine.validate(recently_touched={"my_source"})
        assert not report.valid
        assert any(
            "my_source" in err and "LISTINGS" in err and "ACCOUNTS" in err
            for err in report.errors
        ), (
            f"Expected coverage error mentioning LISTINGS and ACCOUNTS, got: {report.errors}"
        )

    def test_declared_join_satisfies_coverage(self):
        """Declaring joins to both tables removes the coverage error."""
        declared = [
            self._many_to_one("LISTINGS", "my_source.listing_id = LISTINGS.id"),
            self._many_to_one("ACCOUNTS", "my_source.account_id = ACCOUNTS.id"),
        ]
        engine = self._build_engine(
            new_source_sql=self._LISTINGS_ACCOUNTS_SQL, new_source_joins=declared
        )
        report = engine.validate(recently_touched={"my_source"})
        joins_errors = [
            err for err in report.errors if "my_source" in err and "joins[]" in err
        ]
        assert joins_errors == []

    def test_partial_coverage_lists_only_missing(self):
        """With only LISTINGS declared, the error names ACCOUNTS alone."""
        declared = [
            self._many_to_one("LISTINGS", "my_source.listing_id = LISTINGS.id"),
        ]
        engine = self._build_engine(
            new_source_sql=self._LISTINGS_ACCOUNTS_SQL, new_source_joins=declared
        )
        report = engine.validate(recently_touched={"my_source"})
        accounts_gaps = [
            err for err in report.errors if "my_source" in err and "ACCOUNTS" in err
        ]
        assert accounts_gaps, f"Expected ACCOUNTS gap, got: {report.errors}"
        assert not any("LISTINGS]" in err for err in accounts_gaps), (
            f"LISTINGS should be satisfied: {report.errors}"
        )

    def test_unmapped_table_does_not_trigger_coverage_error(self):
        """A table with no manifest entry is ignored by the coverage check."""
        # staging.foo is not modelled, so the validator stays silent. (The
        # agent is still expected to write a wiki note, but that is outside
        # the validator's scope.)
        engine = self._build_engine(new_source_sql="select id from staging.foo")
        report = engine.validate(recently_touched={"my_source"})
        flagged = [
            err for err in report.errors if "my_source" in err and "joins[]" in err
        ]
        assert not flagged, f"Unmapped table must not be flagged: {report.errors}"

    def test_quoted_identifiers_match(self):
        """Double-quoted identifiers still resolve to the manifest tables."""
        quoted_sql = (
            'select * from "ANALYTICS"."MARTS"."LISTINGS" l '
            'join "ANALYTICS"."MARTS"."ACCOUNTS" a on l.account_id = a.id'
        )
        engine = self._build_engine(new_source_sql=quoted_sql, new_source_joins=[])
        report = engine.validate(recently_touched={"my_source"})
        assert any(
            "my_source" in err and "LISTINGS" in err and "ACCOUNTS" in err
            for err in report.errors
        ), f"Quoted identifiers should match: {report.errors}"

    def test_cte_self_reference_not_flagged(self):
        """A CTE name is never reported; tables inside the CTE still are."""
        cte_sql = """
with d as (select id from ANALYTICS.MARTS.LISTINGS)
select * from d
"""
        engine = self._build_engine(new_source_sql=cte_sql, new_source_joins=[])
        report = engine.validate(recently_touched={"my_source"})
        source_errors = [err for err in report.errors if "my_source" in err]
        # LISTINGS lives inside the CTE body and has a manifest entry, so it
        # must still be flagged.
        assert any("LISTINGS" in err for err in source_errors)
        # The alias `d` itself must not show up as a missing join.
        assert not any("'d'" in err or " d " in err for err in source_errors), (
            f"CTE alias 'd' must not be flagged: {source_errors}"
        )

    def test_two_part_suffix_match(self):
        """A two-part ref (`MARTS.LISTINGS`) matches the three-part entry."""
        engine = self._build_engine(
            new_source_sql="select id from MARTS.LISTINGS", new_source_joins=[]
        )
        report = engine.validate(recently_touched={"my_source"})
        assert any(
            "my_source" in err and "LISTINGS" in err for err in report.errors
        ), f"Two-part suffix should match: {report.errors}"

    def test_not_recently_touched_means_no_check(self):
        """Without ``recently_touched`` the coverage check is skipped."""
        # The SQL has an uncovered join, but the source is not listed as
        # recently touched, so no joins[] error may appear.
        uncovered_sql = """
select l.id from ANALYTICS.MARTS.LISTINGS l
join ANALYTICS.MARTS.ACCOUNTS a on l.account_id = a.id
"""
        engine = self._build_engine(new_source_sql=uncovered_sql, new_source_joins=[])
        report = engine.validate(recently_touched=None)
        joins_errors = [
            err for err in report.errors if "my_source" in err and "joins[]" in err
        ]
        assert joins_errors == []

    def test_table_only_source_skipped(self):
        """A source defined by `table:` (no SQL) cannot be coverage-checked."""
        catalog = {
            "LISTINGS": _table_src("LISTINGS", "ANALYTICS.MARTS.LISTINGS"),
            "bare": _table_src("bare", "public.bare", columns=["id"]),
        }
        report = SemanticEngine.from_sources(catalog).validate(
            recently_touched={"bare"}
        )
        flagged = [err for err in report.errors if "bare" in err and "joins[]" in err]
        assert not flagged, f"Table-only source must not be flagged: {report.errors}"

    def test_self_reference_not_flagged(self):
        """A source whose SQL names its own table is not flagged against itself."""
        # Not realistic for SQL sources, but self-references should be
        # filtered defensively.
        own_table = _sql_src("my_source", sql="select id from public.my_source")
        engine = SemanticEngine.from_sources({"my_source": own_table})
        report = engine.validate(recently_touched={"my_source"})
        flagged = [
            err for err in report.errors if "my_source" in err and "joins[]" in err
        ]
        assert not flagged

View file

@ -0,0 +1,77 @@
from semantic_layer.table_identifier_parser import (
ParseTableIdentifierItem,
parse_table_identifier_batch,
parse_table_identifier_one,
)
def test_parse_table_identifier_supported_dialects_and_aliases() -> None:
    """Postgres, BigQuery and Snowflake identifiers are all parsed."""
    items = [
        ParseTableIdentifierItem(
            key="pg", sql_table_name="public.orders AS o", dialect="postgres"
        ),
        ParseTableIdentifierItem(
            key="bq", sql_table_name="analytics.orders", dialect="bigquery"
        ),
        ParseTableIdentifierItem(
            key="sf", sql_table_name="RAW.PUBLIC.ORDERS", dialect="snowflake"
        ),
    ]
    response = parse_table_identifier_batch(items)

    # Postgres: the `AS o` alias must not leak into the parsed parts.
    postgres = response["pg"]
    assert postgres.ok is True
    assert postgres.schema_ == "public"
    assert postgres.name == "orders"
    assert postgres.canonical_table == "public.orders"

    # BigQuery: two-part identifier.
    bigquery = response["bq"]
    assert bigquery.ok is True
    assert bigquery.schema_ == "analytics"
    assert bigquery.name == "orders"

    # Snowflake: three-part identifier, original casing kept.
    snowflake = response["sf"]
    assert snowflake.ok is True
    assert snowflake.catalog == "RAW"
    assert snowflake.schema_ == "PUBLIC"
    assert snowflake.name == "ORDERS"
def test_parse_table_identifier_rejects_non_physical_inputs() -> None:
    """Each unsupported input is rejected with its specific reason code."""
    # (sql_table_name, dialect, expected reason)
    rejected_cases = [
        ("${orders.SQL_TABLE_NAME}", "postgres", "looker_template_unresolved"),
        (
            "(select * from public.orders)",
            "postgres",
            "derived_table_not_supported",
        ),
        (
            "public.orders join public.users on true",
            "postgres",
            "multiple_table_references",
        ),
        ("public.orders", "not-a-dialect", "unsupported_dialect"),
    ]
    for sql_table_name, dialect, expected_reason in rejected_cases:
        outcome = parse_table_identifier_one(sql_table_name, dialect)
        assert outcome.reason == expected_reason
def test_parse_table_identifier_preserves_batch_keys() -> None:
    """Batch results are keyed by item key, in submission order."""
    batch = [
        ParseTableIdentifierItem(
            key="z", sql_table_name="public.z", dialect="postgres"
        ),
        ParseTableIdentifierItem(
            key="a", sql_table_name="public.a", dialect="postgres"
        ),
    ]
    response = parse_table_identifier_batch(batch)
    # "z" was submitted first, so it must come first (no sorting by key).
    assert list(response) == ["z", "a"]
    assert response["z"].name == "z"
    assert response["a"].name == "a"

View file

@ -0,0 +1,360 @@
"""TPC-H schema tests: loading, graph, planning, and SQL execution against DuckDB."""
from __future__ import annotations
from pathlib import Path
import pytest
from semantic_layer.engine import SemanticEngine
from semantic_layer.graph import JoinGraph
from semantic_layer.loader import SourceLoader
from semantic_layer.models import SourceDefinition
# Directory with the TPC-H source definitions: ../sources/tpch relative to
# the directory containing this file.
TPCH_DIR = Path(__file__).parent.parent.joinpath("sources", "tpch")

# Names of the eight TPC-H tables used throughout this module.
TPCH_TABLES = [
    "region", "nation", "supplier", "customer",
    "part", "partsupp", "orders", "lineitem",
]
# duckdb is optional: when it cannot be imported, HAS_DUCKDB is False and
# the tests that need it are skipped.
try:
    import duckdb
except ImportError:
    HAS_DUCKDB = False
else:
    HAS_DUCKDB = True
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(scope="module")
def sources() -> dict[str, SourceDefinition]:
    """Load every source definition found under ``TPCH_DIR`` (once per module)."""
    loader = SourceLoader(TPCH_DIR)
    return loader.load_all()
@pytest.fixture(scope="module")
def graph(sources: dict[str, SourceDefinition]) -> JoinGraph:
    """Return a built ``JoinGraph`` over the loaded TPC-H sources."""
    join_graph = JoinGraph(sources)
    join_graph.build()
    return join_graph
@pytest.fixture(scope="module")
def engine() -> SemanticEngine:
    """Return a DuckDB-dialect engine built from the ``TPCH_DIR`` path."""
    sources_path = str(TPCH_DIR)
    return SemanticEngine(sources_path, dialect="duckdb")
@pytest.fixture(scope="module")
def tpch_conn():
    """Yield an in-memory DuckDB connection populated with TPC-H data.

    The data is generated with the ``tpch`` extension at scale factor 0.01.
    Each generated table in ``main`` is also exposed as a view of the same
    name in a ``public`` schema. The connection is closed when the module's
    tests finish, and also if any setup step fails.

    Skips the requesting test when duckdb is not installed.
    """
    if not HAS_DUCKDB:
        pytest.skip("duckdb not installed")
    conn = duckdb.connect()
    try:
        conn.execute("INSTALL tpch; LOAD tpch")
        conn.execute("CALL dbgen(sf=0.01)")
        conn.execute("CREATE SCHEMA IF NOT EXISTS public")
        for t in TPCH_TABLES:
            # Table names come from the module-level constant, not user input.
            conn.execute(f"CREATE VIEW public.{t} AS SELECT * FROM main.{t}")
        yield conn
    finally:
        # Release the connection even when setup raised part-way through.
        conn.close()
# ── Loader Tests ─────────────────────────────────────────────────────
class TestTpchLoader:
    """Checks on the source definitions provided by the ``sources`` fixture."""

    def test_all_sources_loaded(self, sources):
        """Exactly the eight TPC-H tables are loaded."""
        loaded_names = set(sources.keys())
        assert loaded_names == set(TPCH_TABLES)

    def test_lineitem_columns(self, sources):
        """lineitem has 16 columns, including key, price and ship date."""
        lineitem = sources["lineitem"]
        present = {column.name for column in lineitem.columns}
        assert {"l_orderkey", "l_extendedprice", "l_shipdate"} <= present
        assert len(lineitem.columns) == 16

    def test_lineitem_composite_grain(self, sources):
        """lineitem is keyed by order key plus line number."""
        expected_grain = ["l_orderkey", "l_linenumber"]
        assert sources["lineitem"].grain == expected_grain

    def test_partsupp_composite_grain(self, sources):
        """partsupp is keyed by part key plus supplier key."""
        expected_grain = ["ps_partkey", "ps_suppkey"]
        assert sources["partsupp"].grain == expected_grain

    def test_lineitem_measures(self, sources):
        """lineitem defines 8 measures, including the three spot-checked ones."""
        measures = sources["lineitem"].measures
        declared = {measure.name for measure in measures}
        assert {"revenue", "returned_revenue", "charge"} <= declared
        assert len(sources["lineitem"].measures) == 8

    def test_returned_revenue_has_filter(self, sources):
        """returned_revenue is restricted to rows with return flag 'R'."""
        returned = next(
            measure
            for measure in sources["lineitem"].measures
            if measure.name == "returned_revenue"
        )
        assert returned.filter == "l_returnflag = 'R'"

    def test_lineitem_joins(self, sources):
        """lineitem joins to orders, part and supplier, and nothing else."""
        targets = {join.to for join in sources["lineitem"].joins}
        assert targets == {"orders", "part", "supplier"}

    def test_region_is_leaf(self, sources):
        """region declares neither joins nor measures."""
        region = sources["region"]
        assert region.joins == []
        assert sources["region"].measures == []

    def test_orders_measures(self, sources):
        """orders exposes exactly three measures."""
        declared = {measure.name for measure in sources["orders"].measures}
        assert declared == {"order_count", "total_price", "avg_order_value"}
# ── Graph Tests ──────────────────────────────────────────────────────
class TestTpchGraph:
    """Path-finding and connectivity checks on the TPC-H join graph."""

    def test_all_sources_in_graph(self, graph):
        """Every TPC-H table has an adjacency entry."""
        nodes = set(graph.adjacency.keys())
        assert nodes >= set(TPCH_TABLES)

    def test_lineitem_to_region_path(self, graph):
        """Shortest path: lineitem → supplier → nation → region (3 hops)."""
        route = graph.find_path("lineitem", "region")
        assert route is not None
        # Rebuild the ordered list of sources the path walks through.
        visited = [route.edges[0].from_source]
        visited.extend(edge.to_source for edge in route.edges)
        assert "lineitem" in visited
        assert "region" in visited
        assert len(route.edges) == 3

    def test_lineitem_to_part_direct(self, graph):
        """lineitem reaches part over a single edge."""
        route = graph.find_path("lineitem", "part")
        assert route is not None
        assert len(route.edges) == 1

    def test_part_to_supplier_via_lineitem(self, graph):
        """Shortest path: part → lineitem → supplier (2 hops, shorter than via partsupp)."""
        route = graph.find_path("part", "supplier")
        assert route is not None
        assert len(route.edges) == 2

    def test_partsupp_bridges_part_and_supplier(self, graph):
        """partsupp has direct edges to both part and supplier."""
        to_part = graph.find_path("partsupp", "part")
        to_supplier = graph.find_path("partsupp", "supplier")
        for route in (to_part, to_supplier):
            assert route is not None and len(route.edges) == 1

    def test_graph_is_single_component(self, graph):
        """All sources belong to one connected component."""
        assert len(graph.find_components()) == 1
# ── Plan-only Tests (no DuckDB needed) ───────────────────────────────
class TestTpchPlanning:
    """Planner and suggestion checks that never execute SQL."""

    def test_q1_plan(self, engine):
        """A lineitem-only query is anchored on lineitem and uses one source."""
        request = {
            "measures": ["lineitem.revenue"],
            "dimensions": ["lineitem.l_returnflag", "lineitem.l_linestatus"],
        }
        plan = engine.plan_only(request)
        assert plan.anchor_source == "lineitem"
        assert len(plan.sources_used) == 1

    def test_q5_plan_multi_hop(self, engine):
        """Measure, dimension and filter sources all appear in the plan."""
        request = {
            "measures": ["lineitem.revenue"],
            "dimensions": ["nation.n_name"],
            "filters": ["region.r_name = 'ASIA'"],
        }
        plan = engine.plan_only(request)
        for required in ("lineitem", "nation", "region"):
            assert required in plan.sources_used

    def test_filtered_measure_plan(self, engine):
        """At least one planned measure carries a filter."""
        request = {
            "measures": ["lineitem.returned_revenue"],
            "dimensions": ["customer.c_name"],
        }
        plan = engine.plan_only(request)
        assert any(measure.filter for measure in plan.measures)

    def test_time_granularity_plan(self, engine):
        """A time dimension keeps its name and records the granularity."""
        request = {
            "measures": ["lineitem.revenue"],
            "dimensions": [{"field": "orders.o_orderdate", "granularity": "month"}],
        }
        plan = engine.plan_only(request)
        planned_names = [column.name for column in plan.columns]
        # The column keeps the name "o_orderdate"; the granularity is
        # carried as metadata on the column.
        assert "o_orderdate" in planned_names
        order_date = next(
            column for column in plan.columns if column.name == "o_orderdate"
        )
        assert order_date.granularity == "month"

    def test_suggest_valid_query(self, engine):
        """A resolvable query is reported as successful."""
        request = {
            "measures": ["lineitem.revenue"],
            "dimensions": ["lineitem.l_returnflag"],
        }
        outcome = engine.suggest(request)
        assert outcome["success"] is True

    def test_suggest_missing_source(self, engine):
        """A dimension on an unknown source is reported as a failure."""
        request = {
            "measures": ["sum(lineitem.l_quantity)"],
            "dimensions": ["nonexistent.col"],
        }
        outcome = engine.suggest(request)
        assert outcome["success"] is False
# ── Execution Tests (require DuckDB) ────────────────────────────────
@pytest.mark.skipif(not HAS_DUCKDB, reason="duckdb not installed")
class TestTpchExecution:
    """Execute engine-generated SQL on the ``tpch_conn`` DuckDB connection."""

    @staticmethod
    def _fetch(tpch_conn, engine, spec):
        """Compile ``spec`` with the engine and return every resulting row."""
        compiled = engine.query(spec)
        return tpch_conn.execute(compiled.sql).fetchall()

    def test_q1_pricing_summary(self, tpch_conn, engine):
        """Several lineitem measures grouped by return flag and line status."""
        spec = {
            "measures": [
                "lineitem.revenue",
                "lineitem.total_quantity",
                "lineitem.avg_discount",
                "lineitem.line_count",
            ],
            "dimensions": ["lineitem.l_returnflag", "lineitem.l_linestatus"],
        }
        rows = self._fetch(tpch_conn, engine, spec)
        # TPC-H has exactly 4 combinations: A/F, N/F, N/O, R/F
        assert 0 < len(rows) <= 4

    def test_q5_revenue_by_nation_asia(self, tpch_conn, engine):
        """3-hop join with filter: lineitem→supplier→nation→region."""
        spec = {
            "measures": ["lineitem.revenue"],
            "dimensions": ["nation.n_name"],
            "filters": ["region.r_name = 'ASIA'"],
        }
        rows = self._fetch(tpch_conn, engine, spec)
        # ASIA has 5 nations
        assert 0 < len(rows) <= 5

    def test_q3_revenue_by_month(self, tpch_conn, engine):
        """DATE_TRUNC + multi-table filter."""
        spec = {
            "measures": ["lineitem.revenue"],
            "dimensions": [{"field": "orders.o_orderdate", "granularity": "month"}],
            "filters": ["customer.c_mktsegment = 'BUILDING'"],
            "limit": 12,
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert 0 < len(rows) <= 12

    def test_q10_returned_revenue(self, tpch_conn, engine):
        """Filtered measure with CASE WHEN."""
        spec = {
            "measures": ["lineitem.returned_revenue"],
            "dimensions": ["customer.c_name"],
            "limit": 10,
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert 0 < len(rows) <= 10

    def test_order_count(self, tpch_conn, engine):
        """Per-status order counts add up to the full order count."""
        spec = {
            "measures": ["orders.order_count"],
            "dimensions": ["orders.o_orderstatus"],
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert len(rows) > 0
        # Sum of counts should equal total orders at SF=0.01
        grand_total = sum(row[1] for row in rows)
        assert grand_total == 15000  # SF=0.01 → 15000 orders

    def test_supply_cost_by_nation(self, tpch_conn, engine):
        """Bridge table path: partsupp → supplier → nation."""
        spec = {
            "measures": ["partsupp.total_supply_cost"],
            "dimensions": ["nation.n_name"],
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert len(rows) == 25  # 25 nations

    def test_avg_order_value(self, tpch_conn, engine):
        """Average order value is positive for each market segment."""
        spec = {
            "measures": ["orders.avg_order_value"],
            "dimensions": ["customer.c_mktsegment"],
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert len(rows) == 5  # 5 market segments
        # avg values should be positive
        assert all(row[1] > 0 for row in rows)

    def test_lineitem_charge(self, tpch_conn, engine):
        """Complex expression: sum(price * (1 - discount) * (1 + tax))."""
        spec = {
            "measures": ["lineitem.charge"],
            "dimensions": ["lineitem.l_returnflag"],
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert len(rows) > 0
        assert all(row[1] > 0 for row in rows)

    def test_order_by_desc(self, tpch_conn, engine):
        """Descending order_by with a limit returns the rows sorted."""
        spec = {
            "measures": ["lineitem.revenue"],
            "dimensions": ["nation.n_name"],
            "order_by": [{"field": "lineitem.revenue", "direction": "desc"}],
            "limit": 5,
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert len(rows) == 5
        # Revenue should be descending
        revenue_column = [row[1] for row in rows]
        assert revenue_column == sorted(revenue_column, reverse=True)

    def test_multiple_filters(self, tpch_conn, engine):
        """Two filters on different sources still return rows."""
        spec = {
            "measures": ["lineitem.revenue"],
            "dimensions": ["orders.o_orderpriority"],
            "filters": [
                "customer.c_mktsegment = 'BUILDING'",
                "nation.n_name = 'FRANCE'",
            ],
        }
        rows = self._fetch(tpch_conn, engine, spec)
        assert len(rows) > 0

View file

@ -0,0 +1,299 @@
from __future__ import annotations
import pytest
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import (
JoinDeclaration,
SourceColumn,
SourceDefinition,
)
def _src(
    name: str,
    columns: list[str] | None = None,
    grain: list[str] | None = None,
    joins: list[JoinDeclaration] | None = None,
) -> SourceDefinition:
    """Create a table-backed source on ``public.<name>`` for validator tests.

    ``columns`` and ``grain`` each fall back to ``["id"]`` when missing or
    empty, ``joins`` falls back to an empty list, and every column is
    declared with type ``"number"``.
    """
    column_defs = [
        SourceColumn(name=column_name, type="number")
        for column_name in (columns if columns else ["id"])
    ]
    return SourceDefinition(
        name=name,
        table=f"public.{name}",
        grain=grain if grain else ["id"],
        columns=column_defs,
        joins=joins if joins else [],
    )
class TestValidatorValid:
    """A well-formed, fully connected model validates with no findings."""

    def test_valid_connected_model(self):
        """orders → customers yields a valid report, no errors or warnings."""
        customer_join = JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )
        model = {
            "orders": _src(
                "orders", columns=["id", "customer_id"], joins=[customer_join]
            ),
            "customers": _src("customers"),
        }
        report = SemanticEngine.from_sources(model).validate()
        assert report.valid
        assert report.errors == []
        assert report.warnings == []
class TestOrphanJoinTarget:
    """Joins that point at a source which was never defined."""

    def test_orphan_join_target_is_error(self):
        """Validation reports an error naming both ends of the orphan join."""
        customer_join = JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )
        orders = _src("orders", columns=["id", "customer_id"], joins=[customer_join])
        # `customers` is deliberately left out of the model.
        report = SemanticEngine.from_sources({"orders": orders}).validate()
        assert not report.valid
        assert any(
            "orders" in err and "customers" in err and "not defined" in err
            for err in report.errors
        )

    def test_query_with_orphan_target_raises_before_sql(self):
        """The query path must reject orphan targets instead of silently
        emitting SQL that references the undefined table name (which could
        read a real, unmodeled table sharing that name)."""
        customer_join = JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )
        orders = _src(
            "orders",
            columns=["id", "amount", "customer_id"],
            joins=[customer_join],
        )
        engine = SemanticEngine.from_sources({"orders": orders})
        request = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["customers.id"],
        }
        with pytest.raises(ValueError) as exc:
            engine.query(request)
        message = str(exc.value)
        for fragment in ("orders", "customers", "not defined"):
            assert fragment in message
class TestInvalidGrain:
    """Grain entries must refer to declared columns."""

    def test_grain_column_missing_from_columns(self):
        """A grain column absent from ``columns`` makes the report invalid."""
        broken = _src("bad", columns=["id"], grain=["nonexistent_col"])
        report = SemanticEngine.from_sources({"bad": broken}).validate()
        assert not report.valid
        assert any(
            "bad" in err and "nonexistent_col" in err for err in report.errors
        )
class TestDisconnectedComponents:
    """Warnings produced when the join graph splits into several components."""

    def test_two_components_produce_warning_not_error(self):
        """Two unjoined sources stay valid but produce a component warning."""
        model = {"a": _src("a"), "b": _src("b")}
        report = SemanticEngine.from_sources(model).validate()
        assert report.valid
        assert report.errors == []
        assert len(report.warnings) >= 1
        matching = [w for w in report.warnings if "disconnected components" in w]
        assert matching
        disconnection = matching[0]
        for fragment in ("2 disconnected components", "Component 1", "Component 2"):
            assert fragment in disconnection

    def test_aliases_do_not_create_false_disconnection(self):
        """Two aliases of the same base source must count as one component
        with the base, not as separate islands."""
        billing_join = JoinDeclaration(
            to="customers",
            alias="billing_customer",
            on="billing_customer_id = billing_customer.id",
            relationship="many_to_one",
        )
        shipping_join = JoinDeclaration(
            to="customers",
            alias="shipping_customer",
            on="shipping_customer_id = shipping_customer.id",
            relationship="many_to_one",
        )
        # `_src` supplies the same table name, default grain and number-typed
        # columns that a hand-written definition of this source would use.
        orders = _src(
            "orders",
            columns=[
                "id",
                "amount",
                "billing_customer_id",
                "shipping_customer_id",
            ],
            joins=[billing_join, shipping_join],
        )
        model = {
            "orders": orders,
            "customers": _src("customers", columns=["id", "segment"]),
        }
        report = SemanticEngine.from_sources(model).validate()
        assert report.valid
        assert not any("disconnected components" in w for w in report.warnings)

    def test_large_component_is_truncated(self):
        """A 10-source component is abbreviated in the warning text."""
        chain = {f"s{i}": _src(f"s{i}") for i in range(10)}
        # Link s0 → s1 → … → s9 so they form one big component.
        for i in range(9):
            link = JoinDeclaration(
                to=f"s{i + 1}",
                on=f"id = s{i + 1}.id",
                relationship="many_to_one",
            )
            chain[f"s{i}"].joins.append(link)
        chain["island"] = _src("island")
        report = SemanticEngine.from_sources(chain).validate()
        disconnection = next(
            w for w in report.warnings if "disconnected components" in w
        )
        for fragment in ("(10 sources)", "... (+8 more)", "(1 sources): island"):
            assert fragment in disconnection

    def test_singleton_component_warning_names_recently_touched_source(self):
        """A recently touched, unjoined source gets its own warning."""
        customer_join = JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )
        model = {
            "orders": _src(
                "orders", columns=["id", "customer_id"], joins=[customer_join]
            ),
            "customers": _src("customers"),
            "lonely_source": _src("lonely_source"),
        }
        report = SemanticEngine.from_sources(model).validate(
            recently_touched={"lonely_source"}
        )
        assert report.per_source_warnings["lonely_source"]
        first_warning = report.per_source_warnings["lonely_source"][0]
        assert "lonely_source" in first_warning
        lowered = first_warning.lower()
        assert "singleton" in lowered or "no joins" in lowered

    def test_no_per_source_warning_for_connected_recently_touched_source(self):
        """A recently touched source that is joined gets no per-source warning."""
        customer_join = JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )
        model = {
            "orders": _src(
                "orders", columns=["id", "customer_id"], joins=[customer_join]
            ),
            "customers": _src("customers"),
        }
        report = SemanticEngine.from_sources(model).validate(
            recently_touched={"orders"}
        )
        assert report.per_source_warnings.get("orders", []) == []

    def test_recently_touched_default_none_preserves_existing_behavior(self):
        """Without ``recently_touched`` only the global warning is produced."""
        model = {"lonely": _src("lonely"), "other": _src("other")}
        report = SemanticEngine.from_sources(model).validate()
        assert any("disconnected components" in w for w in report.warnings)
        assert report.per_source_warnings == {}
class TestEcommerceSmoke:
    """Smoke test over the sources from the ``ecommerce_sources`` fixture."""

    def test_ecommerce_fixtures_validate_cleanly(self, ecommerce_sources):
        """The fixture sources validate with no errors and no warnings."""
        report = SemanticEngine.from_sources(ecommerce_sources).validate()
        assert report.valid, f"Expected clean report, got errors: {report.errors}"
        assert report.warnings == [], f"Expected no warnings, got: {report.warnings}"
class TestMultipleIssuesCollected:
    """The validator reports all findings rather than stopping at the first."""

    def test_errors_and_warnings_coexist(self):
        """A bad grain, an orphan join and an isolated source are all reported."""
        orphan_join = JoinDeclaration(
            to="doesnt_exist",
            on="fk = doesnt_exist.id",
            relationship="many_to_one",
        )
        model = {
            "bad_grain": _src("bad_grain", columns=["id"], grain=["missing"]),
            "with_orphan": _src(
                "with_orphan", columns=["id", "fk"], joins=[orphan_join]
            ),
            "isolated": _src("isolated"),
        }
        report = SemanticEngine.from_sources(model).validate()
        assert not report.valid
        assert len(report.errors) >= 2
        for fragment in ("missing", "doesnt_exist"):
            assert any(fragment in err for err in report.errors)
        assert len(report.warnings) >= 1