ktx/python/ktx-sl/tests/test_generator.py

2303 lines
75 KiB
Python
Raw Permalink Normal View History

2026-05-10 23:12:26 +02:00
from pathlib import Path
import pytest
import sqlglot
from semantic_layer.engine import SemanticEngine
from semantic_layer.generator import SqlGenerator
from semantic_layer.graph import JoinGraph
from semantic_layer.models import (
JoinDeclaration,
MeasureDefinition,
SemanticQuery,
SourceColumn,
SourceDefinition,
)
from semantic_layer.planner import QueryPlanner
@pytest.fixture
def planner(ecommerce_sources):
    """Return a ``QueryPlanner`` over the shared ecommerce sources.

    The join graph is built eagerly so every test receives a ready planner.
    """
    join_graph = JoinGraph(ecommerce_sources)
    join_graph.build()
    return QueryPlanner(ecommerce_sources, join_graph)
@pytest.fixture
def generator():
    """Return the Postgres ``SqlGenerator`` used by the fixture-based tests."""
    postgres_generator = SqlGenerator(dialect="postgres")
    return postgres_generator
def generate_sql(planner, generator, query_dict, sources):
    """Plan ``query_dict`` and render the resulting plan to SQL.

    Args:
        planner: Planner whose ``plan`` method turns a ``SemanticQuery`` into a plan.
        generator: Generator whose ``generate`` method renders that plan.
        query_dict: Keyword arguments used to build the ``SemanticQuery``.
        sources: Source definitions forwarded to ``generator.generate``.

    Returns:
        Whatever ``generator.generate`` returns (the SQL text in these tests).
    """
    semantic_query = SemanticQuery(**query_dict)
    return generator.generate(planner.plan(semantic_query), sources)
def assert_valid_sql(sql: str, dialect=None):
    """Assert that ``sql`` is non-empty and syntactically valid per sqlglot.

    Args:
        sql: The SQL text to check.
        dialect: Optional sqlglot dialect name (e.g. ``"postgres"``) used as the
            ``read`` dialect. ``None`` (the default) parses with sqlglot's
            generic dialect, matching the previous behaviour.

    Fails the current test via ``pytest.fail`` when the SQL is blank or when
    sqlglot raises one of its own errors while tokenizing/parsing it.
    """
    # A blank string parses to nothing without raising, so it would otherwise
    # be reported as "valid".
    if not sql or not sql.strip():
        pytest.fail(f"Generated SQL is empty.\n\nSQL:\n{sql!r}")
    try:
        sqlglot.parse(sql, read=dialect)
    except sqlglot.errors.SqlglotError as e:
        # SqlglotError is the common base of ParseError / TokenError; anything
        # else is an unexpected crash and should surface with its traceback.
        pytest.fail(f"Generated SQL is not valid: {e}\n\nSQL:\n{sql}")
class TestSimpleSingleSource:
    """Test 1: Simple single source."""

    def test_basic_aggregation(self, planner, generator, ecommerce_sources):
        sql = generate_sql(
            planner,
            generator,
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["orders.status"],
            },
            ecommerce_sources,
        )
        assert_valid_sql(sql)
        upper = sql.upper()
        # Case-insensitive match: the needle must itself be upper-case to be
        # found in the upper-cased SQL.
        assert "SUM(ORDERS.AMOUNT)" in upper
        assert "status" in sql.lower()
        assert "GROUP BY" in upper
        assert "public.orders" in sql
class TestCrossSourceM2O:
    """Test 2: Cross-source, all m2o (the LATAM query)."""

    def test_churn_risk_by_region(self, planner, generator, ecommerce_sources):
        query = {
            "measures": ["churn_risk.avg_risk"],
            "dimensions": ["churn_risk.customer_type", "regions.name"],
            "filters": ["regions.name = 'LATAM'", "churn_risk.score > 0.7"],
        }
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        upper = sql.upper()
        # churn_risk is a SQL-defined source, so it is emitted as a CTE that
        # carries its SQL body.
        assert "churn_risk AS" in sql
        assert "calculate_churn_score" in sql
        assert "AVG" in upper
        assert "WHERE" in upper
        assert "LATAM" in sql
        assert "GROUP BY" in upper
class TestFanOut:
    """Test 3: Fanout (aggregate locality)."""

    def test_orders_by_region_no_fanout(self, planner, generator, ecommerce_sources):
        """orders → customers → regions is all m2o. No fanout needed."""
        sql = generate_sql(
            planner,
            generator,
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["regions.name"],
            },
            ecommerce_sources,
        )
        assert_valid_sql(sql)
        upper = sql.upper()
        assert "SUM" in upper
        assert "JOIN" in upper
        # Should NOT have aggregate locality CTEs
        assert "_agg" not in sql
class TestChasmTrap:
    """Test 4: Chasm trap (two o2m from same dimension source)."""

    def test_chasm_trap_generates_locality(self):
        # customers is the shared "one" side; both orders and tickets declare a
        # many_to_one join to it.
        customers = SourceDefinition(
            name="customers",
            table="public.customers",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="segment", type="string"),
            ],
        )
        orders = SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="customer_id", type="number"),
                SourceColumn(name="amount", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="customers",
                    on="customer_id = customers.id",
                    relationship="many_to_one",
                )
            ],
        )
        tickets = SourceDefinition(
            name="tickets",
            table="public.tickets",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="customer_id", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="customers",
                    on="customer_id = customers.id",
                    relationship="many_to_one",
                )
            ],
        )
        model = {"customers": customers, "orders": orders, "tickets": tickets}
        join_graph = JoinGraph(model)
        join_graph.build()
        local_planner = QueryPlanner(model, join_graph)
        sql_gen = SqlGenerator(dialect="postgres")
        semantic_query = SemanticQuery(
            measures=["sum(orders.amount)", "count(tickets.id)"],
            dimensions=["customers.segment"],
        )
        sql = sql_gen.generate(local_planner.plan(semantic_query), model)
        assert_valid_sql(sql)
        upper = sql.upper()
        # Should have pre-aggregation CTEs stitched back together.
        assert "orders_agg" in sql
        assert "tickets_agg" in sql
        assert "FULL JOIN" in upper
        assert "COALESCE" in upper
class TestDerivedExpression:
    """Test 5: Derived expression."""

    def test_profit_calculation(self, planner, generator, ecommerce_sources):
        measures = [
            {"expr": "sum(orders.amount)", "name": "total_rev"},
            {"expr": "sum(orders.cost)", "name": "total_cost"},
            {"expr": "total_rev - total_cost", "name": "profit"},
        ]
        sql = generate_sql(
            planner,
            generator,
            {"measures": measures, "dimensions": ["orders.status"]},
            ecommerce_sources,
        )
        assert_valid_sql(sql)
        for alias in ("total_rev", "total_cost", "profit"):
            assert alias in sql
        # NOTE: whether the derived measure inlines its aggregate inputs is not
        # asserted here; only the aliases and the grouping are checked.
        assert "GROUP BY" in sql.upper()
class TestAutoHaving:
    """Test 6: Auto-HAVING (an aggregate filter lands in HAVING)."""

    def test_having_filter(self, planner, generator, ecommerce_sources):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": ["sum(orders.amount) > 10000"],
        }
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        assert "HAVING" in sql.upper()
        assert "10000" in sql
class TestTimeGranularity:
    """Test 7: Time granularity."""

    def test_month_truncation(self, planner, generator, ecommerce_sources):
        monthly = {"field": "orders.created_at", "granularity": "month"}
        query = {"measures": ["sum(orders.amount)"], "dimensions": [monthly]}
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        assert "DATE_TRUNC" in sql.upper()
        assert "month" in sql.lower()
class TestPreDefinedMeasureWithFilter:
    """Test 8: Pre-defined measure with filter."""

    def test_revenue_filter(self, planner, generator, ecommerce_sources):
        query = {"measures": ["orders.revenue"], "dimensions": ["orders.status"]}
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        upper = sql.upper()
        # Revenue has filter: status != 'refunded', so the aggregate is expected
        # to wrap a CASE, e.g. SUM(CASE WHEN status != 'refunded' THEN amount END).
        assert "CASE WHEN" in upper
        assert "REFUNDED" in upper
        assert "SUM" in upper
class TestDialectValidation:
    """``SqlGenerator`` validates its ``dialect`` argument at construction."""

    def test_invalid_dialect_raises(self):
        with pytest.raises(ValueError, match="Unknown SQL dialect"):
            SqlGenerator(dialect="not_real")

    def test_valid_dialect_postgres(self):
        assert SqlGenerator(dialect="postgres").dialect == "postgres"

    def test_valid_dialect_bigquery(self):
        assert SqlGenerator(dialect="bigquery").dialect == "bigquery"

    def test_valid_dialect_snowflake(self):
        assert SqlGenerator(dialect="snowflake").dialect == "snowflake"
class TestDialectTranspilation:
    """Test 9: Dialect transpilation."""

    def test_bigquery(self, planner, ecommerce_sources):
        gen = SqlGenerator(dialect="bigquery")
        sql = generate_sql(
            planner,
            gen,
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": [{"field": "orders.created_at", "granularity": "month"}],
            },
            ecommerce_sources,
        )
        assert sql
        # BigQuery: col is a timestamp, so sqlglot emits TIMESTAMP_TRUNC(col, MONTH).
        # Either form is valid BQ; both must have MONTH as an unquoted part.
        assert "DATE_TRUNC(" in sql or "TIMESTAMP_TRUNC(" in sql
        assert ", MONTH)" in sql
        assert "DATE_TRUNC('month'" not in sql
        assert "TIMESTAMP_TRUNC('month'" not in sql

    def test_snowflake(self, planner, ecommerce_sources):
        gen = SqlGenerator(dialect="snowflake")
        sql = generate_sql(
            planner,
            gen,
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["orders.status"],
            },
            ecommerce_sources,
        )
        assert sql
        # A truthiness check alone accepts any non-empty output; also require
        # that the dimensioned aggregate survived transpilation.
        upper = sql.upper()
        assert "SUM" in upper
        assert "GROUP BY" in upper
        assert "status" in sql.lower()
class TestSqlSourceAsCte:
    """Test 10: SQL source as CTE."""

    def test_churn_risk_cte(self, planner, generator, ecommerce_sources):
        query = {
            "measures": ["churn_risk.avg_risk"],
            "dimensions": ["churn_risk.customer_type"],
        }
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        upper = sql.upper()
        # churn_risk is a SQL source, so it must be materialised as a CTE.
        assert "WITH" in upper
        assert "churn_risk AS" in sql
        assert "customer_type" in sql
        assert "AVG" in upper
class TestLimitClause:
    """Test: LIMIT appears in generated SQL."""

    def test_limit_in_sql(self, planner, generator, ecommerce_sources):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "limit": 50,
        }
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        assert "LIMIT 50" in sql.upper()

    def test_default_limit(self, planner, generator, ecommerce_sources):
        # No explicit limit: the generated SQL carries a LIMIT 1000 default.
        query = {"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        assert "LIMIT 1000" in sql.upper()
class TestOrderByClause:
    """Test: ORDER BY appears in generated SQL."""

    def test_order_by_dimensions(self, planner, generator, ecommerce_sources):
        query = {"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        assert "ORDER BY" in sql.upper()
class TestMultipleWhereFilters:
    """Test: Multiple WHERE filters combined with AND."""

    def test_two_where_filters(self, planner, generator, ecommerce_sources):
        row_filters = ["orders.status = 'completed'", "orders.amount > 100"]
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": row_filters,
        }
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        upper = sql.upper()
        assert "WHERE" in upper
        assert "completed" in sql
        assert "100" in sql
        assert "AND" in upper
class TestCombinedWhereHaving:
    """Test: Both WHERE and HAVING in same query."""

    def test_where_and_having(self, planner, generator, ecommerce_sources):
        # One row-level filter and one aggregate filter in the same query.
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": [
                "orders.status = 'completed'",
                "sum(orders.amount) > 10000",
            ],
        }
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        upper = sql.upper()
        assert "WHERE" in upper
        assert "HAVING" in upper
        assert "completed" in sql
        assert "10000" in sql
class TestMultiplePreDefinedMeasures:
    """Test: Multiple pre-defined measures from same source."""

    def test_revenue_and_order_count(self, planner, generator, ecommerce_sources):
        query = {
            "measures": ["orders.revenue", "orders.order_count"],
            "dimensions": ["orders.status"],
        }
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        lowered = sql.lower()
        assert "revenue" in lowered
        assert "order_count" in lowered
        # Revenue is a filtered measure, so a CASE WHEN must appear.
        assert "CASE WHEN" in sql.upper()
class TestRuntimeAggregationCrossSource:
    """Test: Runtime aggregation across joined sources."""

    def test_runtime_agg_by_region(self, planner, generator, ecommerce_sources):
        runtime_measure = {"expr": "count(orders.id)", "name": "order_count"}
        query = {"measures": [runtime_measure], "dimensions": ["regions.name"]}
        sql = generate_sql(planner, generator, query, ecommerce_sources)
        assert_valid_sql(sql)
        upper = sql.upper()
        assert "COUNT" in upper
        assert "regions" in sql.lower()
        assert "JOIN" in upper
class TestChasmTrapWithDerived:
    """Test: Chasm trap with derived measures referencing different CTEs."""

    def test_derived_across_ctes(self):
        customers = SourceDefinition(
            name="customers",
            table="public.customers",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="segment", type="string"),
            ],
        )
        orders = SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="customer_id", type="number"),
                SourceColumn(name="amount", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="customers",
                    on="customer_id = customers.id",
                    relationship="many_to_one",
                )
            ],
        )
        tickets = SourceDefinition(
            name="tickets",
            table="public.tickets",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="customer_id", type="number"),
                SourceColumn(name="cost", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="customers",
                    on="customer_id = customers.id",
                    relationship="many_to_one",
                )
            ],
        )
        sources = {"customers": customers, "orders": orders, "tickets": tickets}
        # JoinGraph comes from the module-level import; no local re-import.
        graph = JoinGraph(sources)
        graph.build()
        planner = QueryPlanner(sources, graph)
        gen = SqlGenerator(dialect="postgres")
        # ``profit`` combines aggregates that live in two different locality CTEs.
        query = SemanticQuery(
            measures=[
                {"expr": "sum(orders.amount)", "name": "total_rev"},
                {"expr": "sum(tickets.cost)", "name": "total_cost"},
                {"expr": "total_rev - total_cost", "name": "profit"},
            ],
            dimensions=["customers.segment"],
        )
        plan = planner.plan(query)
        sql = gen.generate(plan, sources)
        assert_valid_sql(sql)
        assert "orders_agg" in sql
        assert "tickets_agg" in sql
        assert "profit" in sql
# On-disk ecommerce source definitions: <tests dir>/../sources/ecommerce.
SOURCES_DIR = str(Path(__file__).parent.parent.joinpath("sources", "ecommerce"))
# ── From test_edge_cases.py: generator edge cases ───────────────────
class TestGeneratorEdgeCases:
    """Generator edge cases exercised through ``SemanticEngine``."""

    @staticmethod
    def _postgres_engine():
        """Build a fresh Postgres engine over the on-disk ecommerce sources."""
        return SemanticEngine(SOURCES_DIR, dialect="postgres")

    @staticmethod
    def _status_result(dialect):
        """Run sum(orders.amount) by orders.status on an engine for ``dialect``."""
        engine = SemanticEngine(SOURCES_DIR, dialect=dialect)
        return engine.query(
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["orders.status"],
            }
        )

    def test_no_dimensions_no_group_by(self):
        result = self._postgres_engine().query({"measures": ["sum(orders.amount)"]})
        assert "GROUP BY" not in result.sql
        assert_valid_sql(result.sql)

    def test_multiple_time_dimensions(self):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": [
                {"field": "orders.created_at", "granularity": "month"},
                {"field": "customers.created_at", "granularity": "year"},
            ],
        }
        sql = self._postgres_engine().query(query).sql
        assert "DATE_TRUNC('month'" in sql
        assert "DATE_TRUNC('year'" in sql
        assert_valid_sql(sql)

    def test_limit_zero(self):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "limit": 0,
        }
        result = self._postgres_engine().query(query)
        assert_valid_sql(result.sql)

    def test_very_large_limit(self):
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "limit": 999999,
        }
        result = self._postgres_engine().query(query)
        assert "LIMIT 999999" in result.sql
        assert_valid_sql(result.sql)

    def test_chasm_trap_no_dimensions(self):
        from conftest import make_engine

        # Two facts (orders, tickets) both many_to_one onto customers, queried
        # without any dimension.
        model = {
            "customers": {
                "name": "customers",
                "table": "t",
                "grain": ["id"],
                "columns": [{"name": "id", "type": "number"}],
            },
            "orders": {
                "name": "orders",
                "table": "t2",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    }
                ],
            },
            "tickets": {
                "name": "tickets",
                "table": "t3",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    }
                ],
            },
        }
        engine = make_engine(model)
        result = engine.query({"measures": ["sum(orders.amount)", "count(tickets.id)"]})
        assert_valid_sql(result.sql)

    def test_sql_source_with_chasm_trap(self):
        from conftest import make_engine

        # Same chasm shape, but orders is defined by a SQL body instead of a table.
        model = {
            "customers": {
                "name": "customers",
                "table": "public.customers",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "segment", "type": "string"},
                ],
            },
            "orders": {
                "name": "orders",
                "sql": "SELECT id, customer_id, amount FROM raw_orders WHERE amount > 0",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    }
                ],
            },
            "tickets": {
                "name": "tickets",
                "table": "public.tickets",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    }
                ],
            },
        }
        engine = make_engine(model)
        sql = engine.query(
            {
                "measures": ["sum(orders.amount)", "count(tickets.id)"],
                "dimensions": ["customers.segment"],
            }
        ).sql
        assert "orders AS" in sql
        assert "orders_agg" in sql
        assert_valid_sql(sql)

    def test_dialect_duckdb(self):
        result = self._status_result("duckdb")
        assert result.dialect == "duckdb"
        assert result.sql

    def test_dialect_mysql(self):
        result = self._status_result("mysql")
        assert result.dialect == "mysql"
        assert result.sql

    def test_pre_defined_measure_cross_source_join(self):
        result = self._postgres_engine().query(
            {
                "measures": ["orders.revenue"],
                "dimensions": ["regions.name"],
            }
        )
        assert "CASE WHEN" in result.sql
        assert "regions" in result.sql.lower()
        assert_valid_sql(result.sql)
# ── From test_edge_cases.py: duplicate aliases, granularity, ORDER BY
class TestDuplicateColumnAliases:
    """Same column name from two sources must get distinct output aliases."""

    def test_same_column_name_different_sources(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        sql = engine.query(
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["orders.created_at", "customers.created_at"],
            }
        ).sql
        assert_valid_sql(sql)
        for alias in ("orders_created_at", "customers_created_at"):
            assert alias in sql

    def test_same_column_name_one_with_granularity(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        dims = [
            {"field": "orders.created_at", "granularity": "month"},
            "customers.created_at",
        ]
        sql = engine.query({"measures": ["sum(orders.amount)"], "dimensions": dims}).sql
        assert_valid_sql(sql)
        for alias in ("orders_created_at_month", "customers_created_at"):
            assert alias in sql
class TestEmptyGranularity:
    """An empty granularity string behaves like no granularity at all."""

    def test_empty_granularity_treated_as_no_granularity(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        untruncated = {"field": "orders.created_at", "granularity": ""}
        sql = engine.query(
            {"measures": ["sum(orders.amount)"], "dimensions": [untruncated]}
        ).sql
        assert "DATE_TRUNC" not in sql
        assert "orders.created_at" in sql
        assert_valid_sql(sql)
class TestOrderBySupported:
    """Explicit and default ORDER BY handling."""

    def test_order_by_desc(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        ordering = [{"field": "sum(orders.amount)", "direction": "desc"}]
        sql = engine.query(
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["orders.status"],
                "order_by": ordering,
            }
        ).sql
        assert "ORDER BY" in sql
        assert "DESC" in sql.upper()
        assert_valid_sql(sql)

    def test_order_by_multiple_fields(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        ordering = [
            {"field": "orders.status", "direction": "asc"},
            {"field": "sum(orders.amount)", "direction": "desc"},
        ]
        sql = engine.query(
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["orders.status"],
                "order_by": ordering,
            }
        ).sql
        assert "ORDER BY" in sql
        assert "DESC" in sql.upper()
        assert_valid_sql(sql)

    def test_default_order_by_when_not_specified(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        sql = engine.query(
            {"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}
        ).sql
        # Without order_by the generator falls back to ordering by position 1.
        assert "ORDER BY 1" in sql
class TestMeasureNameCollision:
    """Measures whose output names would collide still yield valid SQL."""

    def test_two_measures_same_auto_name(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        sql = engine.query(
            {
                "measures": ["sum(orders.amount)", "sum(orders.amount)"],
                "dimensions": ["orders.status"],
            }
        ).sql
        assert_valid_sql(sql)
        lowered = sql.lower()
        assert "sum_orders_amount" in lowered or "sum(orders.amount)" in lowered

    def test_runtime_name_matches_predefined(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        # A runtime measure deliberately named like the predefined one.
        clashing = {"expr": "sum(orders.cost)", "name": "revenue"}
        sql = engine.query(
            {
                "measures": ["orders.revenue", clashing],
                "dimensions": ["orders.status"],
            }
        ).sql
        assert_valid_sql(sql)
class TestChainedJoins:
    """Measures and dimensions several join hops apart."""

    def test_four_hop_join(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        sql = engine.query(
            {
                "measures": ["sum(order_items.quantity)"],
                "dimensions": ["regions.name"],
            }
        ).sql
        assert_valid_sql(sql)
        lowered = sql.lower()
        # Every source on the order_items -> ... -> regions path must appear.
        for source_name in ("order_items", "orders", "customers", "regions"):
            assert source_name in lowered

    def test_measure_from_leaf_dim_from_root(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        sql = engine.query(
            {
                "measures": ["sum(order_items.price)"],
                "dimensions": ["products.category"],
            }
        ).sql
        assert_valid_sql(sql)
        assert "products" in sql.lower()
# ── From test_edge_cases.py: locality CTE filters, join types ────────
class TestWhereFilterInLocalityCTE:
    """A dimension filter must be pushed into each locality CTE."""

    def test_where_filter_in_both_ctes(self):
        from conftest import make_engine

        join_to_customers = {
            "to": "customers",
            "on": "customer_id = customers.id",
            "relationship": "many_to_one",
        }
        model = {
            "customers": {
                "name": "customers",
                "table": "public.customers",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "segment", "type": "string"},
                ],
            },
            "orders": {
                "name": "orders",
                "table": "public.orders",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
                "joins": [dict(join_to_customers)],
            },
            "tickets": {
                "name": "tickets",
                "table": "public.tickets",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                ],
                "joins": [dict(join_to_customers)],
            },
        }
        engine = make_engine(model)
        sql = engine.query(
            {
                "measures": ["sum(orders.amount)", "count(tickets.id)"],
                "dimensions": ["customers.segment"],
                "filters": ["customers.segment = 'Enterprise'"],
            }
        ).sql
        assert_valid_sql(sql)
        assert "Enterprise" in sql
        # At least one occurrence per CTE (orders and tickets).
        assert sql.count("Enterprise") >= 2
class TestThreeCteFullJoinCoalesce:
    """FULL JOIN chaining of locality CTEs and where COALESCE shows up."""

    @staticmethod
    def _fact(name):
        """Return fact source ``name`` with a many_to_one join onto ``hub``."""
        return SourceDefinition(
            name=name,
            table=f"public.{name}",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="hub_id", type="number"),
                SourceColumn(name="val", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="hub", on="hub_id = hub.id", relationship="many_to_one"
                )
            ],
        )

    def test_three_cte_join_uses_coalesce(self):
        sources = {
            "hub": SourceDefinition(
                name="hub",
                table="public.hub",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="segment", type="string"),
                ],
            ),
        }
        for fact_name in ("fact_a", "fact_b", "fact_c"):
            sources[fact_name] = self._fact(fact_name)
        engine = SemanticEngine.from_sources(sources)
        result = engine.query(
            {
                "measures": ["sum(fact_a.val)", "sum(fact_b.val)", "sum(fact_c.val)"],
                "dimensions": ["hub.segment"],
            }
        )
        sql_upper = result.sql.upper()
        # Three CTEs are stitched with exactly two FULL JOINs, and COALESCE is
        # required somewhere in the combined query.
        assert "COALESCE(" in sql_upper
        assert sql_upper.count("FULL JOIN") == 2

    def test_two_cte_join_no_coalesce_needed(self):
        engine = SemanticEngine(SOURCES_DIR, dialect="postgres")
        result = engine.query(
            {
                "measures": ["sum(orders.amount)", "avg(churn_risk.score)"],
                "dimensions": ["customers.segment"],
            }
        )
        sql = result.sql
        full_join_lines = [
            line.strip() for line in sql.split("\n") if "FULL JOIN" in line.upper()
        ]
        # Without this guard the loop below passes vacuously when no FULL JOIN
        # was generated at all.
        assert full_join_lines, f"expected a FULL JOIN between two CTEs:\n{sql}"
        for line in full_join_lines:
            assert "COALESCE" not in line
# ── From test_bug_fixes.py ───────────────────────────────────────────
# Minimal two-source model shared by the regression tests below (passed to
# ``make_engine``). ``orders`` joins many_to_one onto ``customers`` and declares
# a filtered ``revenue`` measure plus an unfiltered ``order_count`` measure.
BUG_FIX_SOURCES = {
    "orders": {
        "name": "orders",
        "table": "public.orders",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "amount", "type": "number"},
            {"name": "cost", "type": "number"},
            {"name": "status", "type": "string"},
            {"name": "customer_id", "type": "number"},
            {"name": "created_at", "type": "time", "role": "time"},
        ],
        "joins": [
            {
                "to": "customers",
                "on": "customer_id = customers.id",
                "relationship": "many_to_one",
            },
        ],
        "measures": [
            {
                "name": "revenue",
                "expr": "sum(amount)",
                # Row-level condition attached to the measure itself.
                "filter": "status != 'refunded'",
            },
            {"name": "order_count", "expr": "count(id)"},
        ],
    },
    "customers": {
        "name": "customers",
        "table": "public.customers",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "name", "type": "string"},
            {"name": "segment", "type": "string"},
        ],
    },
}
class TestPercentileAlias:
    """Auto-generated measure aliases must be plain identifiers (no commas)."""

    def test_percentile_alias_has_no_comma(self):
        from conftest import make_engine

        result = make_engine(BUG_FIX_SOURCES).query(
            {
                "measures": ["percentile(orders.amount, 0.9)"],
                "dimensions": ["orders.status"],
            }
        )
        assert_valid_sql(result.sql)
        alias = result.resolved_plan.measures[0].name
        assert "," not in alias
        assert alias == "percentile_orders_amount_0_9"

    def test_median_alias_clean(self):
        from conftest import make_engine

        result = make_engine(BUG_FIX_SOURCES).query(
            {
                "measures": ["median(orders.amount)"],
                "dimensions": ["orders.status"],
            }
        )
        assert_valid_sql(result.sql)
        assert "," not in result.resolved_plan.measures[0].name
class TestCountDistinct:
    """``count_distinct(x)`` is rendered as ``COUNT(DISTINCT x)``."""

    def test_count_distinct_translated(self):
        from conftest import make_engine

        sql = make_engine(BUG_FIX_SOURCES).query(
            {
                "measures": ["count_distinct(orders.customer_id)"],
                "dimensions": ["orders.status"],
            }
        ).sql
        assert_valid_sql(sql)
        assert "COUNT(DISTINCT" in sql.upper()
        # The semantic-layer function name must not leak into the SQL.
        assert "count_distinct(" not in sql.lower()

    def test_count_distinct_in_chasm_cte(self):
        from conftest import make_engine

        hub_join = {
            "to": "hub",
            "on": "hub_id = hub.id",
            "relationship": "many_to_one",
        }
        model = {
            "hub": {
                "name": "hub",
                "table": "public.hub",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "segment", "type": "string"},
                ],
            },
            "fact_a": {
                "name": "fact_a",
                "table": "public.fact_a",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "hub_id", "type": "number"},
                    {"name": "val", "type": "number"},
                ],
                "joins": [dict(hub_join)],
            },
            "fact_b": {
                "name": "fact_b",
                "table": "public.fact_b",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "hub_id", "type": "number"},
                    {"name": "user_id", "type": "number"},
                ],
                "joins": [dict(hub_join)],
            },
        }
        sql = make_engine(model).query(
            {
                "measures": ["sum(fact_a.val)", "count_distinct(fact_b.user_id)"],
                "dimensions": ["hub.segment"],
            }
        ).sql
        assert_valid_sql(sql)
        assert "COUNT(DISTINCT" in sql.upper()
class TestColumnValidation:
    """Unknown columns are rejected with a descriptive ``ValueError``."""

    @staticmethod
    def _engine():
        """Return an engine over ``BUG_FIX_SOURCES``."""
        from conftest import make_engine

        return make_engine(BUG_FIX_SOURCES)

    def test_nonexistent_column_in_measure(self):
        engine = self._engine()
        with pytest.raises(ValueError, match="does not exist in source"):
            engine.query(
                {
                    "measures": ["sum(orders.nonexistent_column)"],
                    "dimensions": ["orders.status"],
                }
            )

    def test_nonexistent_column_in_dimension(self):
        engine = self._engine()
        with pytest.raises(ValueError, match="does not exist in source"):
            engine.query(
                {
                    "measures": ["sum(orders.amount)"],
                    "dimensions": ["orders.nonexistent_dim"],
                }
            )

    def test_nonexistent_column_in_filter(self):
        engine = self._engine()
        with pytest.raises(ValueError, match="does not exist in source"):
            engine.query(
                {
                    "measures": ["sum(orders.amount)"],
                    "dimensions": ["orders.status"],
                    "filters": ["orders.nonexistent_col = 'x'"],
                }
            )

    def test_valid_columns_pass(self):
        result = self._engine().query(
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["orders.status"],
                "filters": ["orders.status = 'completed'"],
            }
        )
        assert_valid_sql(result.sql)

    def test_error_lists_available_columns(self):
        engine = self._engine()
        with pytest.raises(ValueError, match="Available:.*amount"):
            engine.query(
                {
                    "measures": ["sum(orders.bogus)"],
                    "dimensions": ["orders.status"],
                }
            )
class TestCrossSourceMeasureFilter:
    """A measure filter referencing another source brings that source in."""

    @staticmethod
    def _cross_sources():
        """Return a fresh orders/customers model whose measure filters on customers."""
        return {
            "orders": {
                "name": "orders",
                "table": "public.orders",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "amount", "type": "number"},
                    {"name": "status", "type": "string"},
                    {"name": "customer_id", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    },
                ],
                "measures": [
                    {
                        "name": "vip_revenue",
                        "expr": "sum(amount)",
                        "filter": "customers.segment = 'VIP'",
                    },
                ],
            },
            "customers": {
                "name": "customers",
                "table": "public.customers",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "segment", "type": "string"},
                ],
            },
        }

    def _query_vip_revenue(self):
        """Run ``orders.vip_revenue`` by ``orders.status`` on a fresh engine."""
        from conftest import make_engine

        engine = make_engine(self._cross_sources())
        return engine.query(
            {
                "measures": ["orders.vip_revenue"],
                "dimensions": ["orders.status"],
            }
        )

    def test_measure_filter_adds_join(self):
        result = self._query_vip_revenue()
        assert_valid_sql(result.sql)
        assert "customers" in result.resolved_plan.sources_used
        assert "JOIN" in result.sql.upper()
        assert "customers" in result.sql.lower()

    def test_measure_filter_produces_case_when(self):
        result = self._query_vip_revenue()
        assert "CASE WHEN" in result.sql.upper()
        assert "VIP" in result.sql
# ── From test_brainstorm_cases.py ────────────────────────────────────
class TestPredefinedMeasureWithFilterWrapping:
    """A measure filter wraps the whole measure expression in CASE WHEN."""

    def test_non_aggregate_predefined_formula_with_filter_wraps_entire_expr(self):
        orders = SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="amount", type="number"),
                SourceColumn(name="status", type="string"),
            ],
            measures=[
                MeasureDefinition(
                    name="completed_amount_twice",
                    expr="amount * 2",
                    filter="status = 'completed'",
                )
            ],
        )
        sources = {"orders": orders}
        graph = JoinGraph(sources)
        graph.build()
        planner_local = QueryPlanner(sources, graph)
        gen = SqlGenerator(dialect="postgres")
        plan = planner_local.plan(
            SemanticQuery(measures=["orders.completed_amount_twice"])
        )
        sql = gen.generate(plan, sources)
        # The filter must wrap the entire ``amount * 2`` expression, not just a column.
        assert "CASE WHEN orders.status = 'completed' THEN orders.amount * 2 END" in sql
        # Use the shared helper so a parse failure reports the offending SQL.
        assert_valid_sql(sql)
# ── From test_spec_gaps.py ───────────────────────────────────────────
class TestIncludeEmpty:
    """``include_empty`` selects the join type used to reach dimension sources.

    ``True`` is expected to yield outer joins: LEFT JOIN on a plain
    many-to-one path, FULL JOIN between per-measure CTEs in a chasm model.
    ``False`` is expected to avoid them.
    """

    @staticmethod
    def _make_engine(sources):
        """Build an engine from raw source dicts via the shared conftest factory."""
        from conftest import make_engine

        return make_engine(sources)

    @staticmethod
    def _customer_orders_sources():
        """Return a fresh model where ``orders`` is many-to-one to ``customers``.

        A new dict is built on every call so no test can see mutations left
        behind by another one.
        """
        return {
            "customers": {
                "name": "customers",
                "table": "public.customers",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "segment", "type": "string"},
                ],
            },
            "orders": {
                "name": "orders",
                "table": "public.orders",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    }
                ],
            },
        }

    @classmethod
    def _chasm_sources(cls):
        """Return a fresh chasm model: ``orders`` and ``tickets`` both m2o ``customers``."""
        sources = cls._customer_orders_sources()
        sources["tickets"] = {
            "name": "tickets",
            "table": "public.tickets",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "customer_id", "type": "number"},
            ],
            "joins": [
                {
                    "to": "customers",
                    "on": "customer_id = customers.id",
                    "relationship": "many_to_one",
                }
            ],
        }
        return sources

    @staticmethod
    def _segment_query(measures, include_empty):
        """Build a query grouping *measures* by ``customers.segment``."""
        return {
            "measures": list(measures),
            "dimensions": ["customers.segment"],
            "include_empty": include_empty,
        }

    def test_include_empty_true_uses_left_join(self):
        engine = self._make_engine(self._customer_orders_sources())
        result = engine.query(self._segment_query(["sum(orders.amount)"], True))
        assert "LEFT JOIN" in result.sql.upper()
        assert_valid_sql(result.sql)

    def test_include_empty_false_uses_inner_join(self):
        engine = self._make_engine(self._customer_orders_sources())
        result = engine.query(self._segment_query(["sum(orders.amount)"], False))
        sql_upper = result.sql.upper()
        assert "LEFT JOIN" not in sql_upper
        assert "JOIN" in sql_upper
        assert_valid_sql(result.sql)

    def test_include_empty_true_different_from_false(self):
        engine = self._make_engine(self._customer_orders_sources())
        result_true = engine.query(
            self._segment_query(["sum(orders.amount)"], True)
        )
        result_false = engine.query(
            self._segment_query(["sum(orders.amount)"], False)
        )
        # Both variants must be valid SQL; otherwise "different" is meaningless.
        assert_valid_sql(result_true.sql)
        assert_valid_sql(result_false.sql)
        assert result_true.sql != result_false.sql

    def test_include_empty_in_resolved_plan(self):
        engine = self._make_engine(self._customer_orders_sources())
        plan = engine.plan_only(self._segment_query(["sum(orders.amount)"], True))
        assert plan.include_empty is True

    def test_include_empty_locality_uses_full_join(self):
        engine = self._make_engine(self._chasm_sources())
        result = engine.query(
            self._segment_query(["sum(orders.amount)", "count(tickets.id)"], True)
        )
        assert "FULL JOIN" in result.sql.upper()
        assert_valid_sql(result.sql)

    def test_include_empty_false_locality_uses_inner_join_between_ctes(self):
        engine = self._make_engine(self._chasm_sources())
        result = engine.query(
            self._segment_query(["sum(orders.amount)", "count(tickets.id)"], False)
        )
        sql_upper = result.sql.upper()
        assert "FULL JOIN" not in sql_upper
        assert_valid_sql(result.sql)
class TestFilterPushDown:
    """WHERE filters are pushed into the per-measure CTEs that reference them."""

    @staticmethod
    def _chasm_engine():
        """Build an engine where ``orders`` and ``tickets`` are both m2o ``customers``."""
        from conftest import make_engine

        chasm = {
            "customers": {
                "name": "customers",
                "table": "public.customers",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "segment", "type": "string"},
                ],
            },
            "orders": {
                "name": "orders",
                "table": "public.orders",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    }
                ],
            },
            "tickets": {
                "name": "tickets",
                "table": "public.tickets",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "customer_id", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "customer_id = customers.id",
                        "relationship": "many_to_one",
                    }
                ],
            },
        }
        return make_engine(chasm)

    @staticmethod
    def _cte_body(sql, cte_name):
        """Return the text inside ``<cte_name> AS ( ... )`` in *sql*.

        The closing parenthesis is found by tracking nesting depth.

        Raises:
            AssertionError: if the CTE header is missing, or its closing
                parenthesis is never found. Returning an empty body instead
                would let negative ``not in`` checks pass vacuously.
        """
        header = f"{cte_name} AS ("
        start = sql.find(header)
        assert start >= 0, f"CTE {cte_name!r} not found in SQL:\n{sql}"
        body_start = start + len(header)
        depth = 0
        for i, c in enumerate(sql[body_start:], body_start):
            if c == "(":
                depth += 1
            elif c == ")":
                if depth == 0:
                    return sql[body_start:i]
                depth -= 1
        raise AssertionError(
            f"Closing parenthesis for CTE {cte_name!r} not found:\n{sql}"
        )

    def test_where_filter_only_pushed_to_relevant_cte(self):
        engine = self._chasm_engine()
        result = engine.query(
            {
                "measures": ["sum(orders.amount)", "count(tickets.id)"],
                "dimensions": ["customers.segment"],
                "filters": ["orders.amount > 100"],
            }
        )
        sql = result.sql
        assert_valid_sql(sql)
        tickets_body = self._cte_body(sql, "tickets_agg")
        assert "orders.amount" not in tickets_body
        orders_body = self._cte_body(sql, "orders_agg")
        assert "orders.amount > 100" in orders_body

    def test_filter_on_shared_dimension_pushed_to_all_ctes(self):
        engine = self._chasm_engine()
        result = engine.query(
            {
                "measures": ["sum(orders.amount)", "count(tickets.id)"],
                "dimensions": ["customers.segment"],
                "filters": ["customers.segment = 'Enterprise'"],
            }
        )
        sql = result.sql
        assert_valid_sql(sql)
        count = sql.count("customers.segment = 'Enterprise'")
        assert count >= 2, (
            f"Shared-dimension filter found {count} time(s), expected >= 2:\n{sql}"
        )
class TestJoinAliases:
    """A join ``alias`` must show up as an identifier in the generated SQL."""

    def test_alias_used_in_sql_generation(self):
        import re

        from conftest import make_engine

        sources = {
            "orders": {
                "name": "orders",
                "table": "public.orders",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "billing_customer_id", "type": "number"},
                    {"name": "shipping_customer_id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "customers",
                        "on": "billing_customer_id = customers.id",
                        "relationship": "many_to_one",
                        "alias": "billing_customer",
                    },
                    {
                        "to": "customers",
                        "on": "shipping_customer_id = customers.id",
                        "relationship": "many_to_one",
                        "alias": "shipping_customer",
                    },
                ],
            },
            "customers": {
                "name": "customers",
                "table": "public.customers",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "name", "type": "string"},
                ],
            },
        }
        engine = make_engine(sources)
        result = engine.query(
            {
                "measures": ["sum(orders.amount)"],
                "dimensions": ["billing_customer.name"],
            }
        )
        sql = result.sql
        assert_valid_sql(sql)
        # A plain substring check would also match the ``billing_customer_id``
        # column that the join condition references, even if no alias were
        # emitted. Match the alias as a whole identifier instead.
        assert re.search(r"\bbilling_customer\b", sql), (
            f"Alias 'billing_customer' not used in SQL:\n{sql}"
        )

    def test_self_referencing_alias(self):
        import re

        from conftest import make_engine

        sources = {
            "employees": {
                "name": "employees",
                "table": "public.employees",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "name", "type": "string"},
                    {"name": "manager_id", "type": "number"},
                ],
                "joins": [
                    {
                        "to": "employees",
                        "on": "manager_id = employees.id",
                        "relationship": "many_to_one",
                        "alias": "manager",
                    },
                ],
            },
        }
        engine = make_engine(sources)
        result = engine.query(
            {
                "measures": ["count(employees.id)"],
                "dimensions": ["manager.name"],
            }
        )
        sql = result.sql
        assert_valid_sql(sql)
        # Same pitfall as above: ``manager_id`` contains "manager".
        assert re.search(r"\bmanager\b", sql), (
            f"Alias 'manager' not used in SQL:\n{sql}"
        )
class TestMedianPercentile:
    """``median`` / ``percentile`` ad-hoc measures compile to PERCENTILE_CONT."""

    @staticmethod
    def _query_scores_by_category(measure_expr, measure_name):
        """Run one named ad-hoc measure over ``scores`` grouped by category."""
        from conftest import make_engine

        scores_model = {
            "scores": {
                "name": "scores",
                "table": "public.scores",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "value", "type": "number"},
                    {"name": "category", "type": "string"},
                ],
            },
        }
        scores_engine = make_engine(scores_model)
        return scores_engine.query(
            {
                "measures": [{"expr": measure_expr, "name": measure_name}],
                "dimensions": ["scores.category"],
            }
        )

    def test_median_generates_percentile_cont(self):
        generated = self._query_scores_by_category(
            "median(scores.value)", "med_val"
        ).sql
        # A median is the 0.5 continuous percentile.
        assert "PERCENTILE_CONT" in generated.upper()
        assert "0.5" in generated
        assert_valid_sql(generated)

    def test_percentile_generates_percentile_cont(self):
        generated = self._query_scores_by_category(
            "percentile(scores.value, 0.9)", "p90"
        ).sql
        assert "PERCENTILE_CONT" in generated.upper()
        assert "0.9" in generated
        assert_valid_sql(generated)
class TestSqlSourceWithClause:
    """BUG 3: SQL sources with inner WITH clauses should not produce nested WITH."""

    @staticmethod
    def _sum_amount_sql(source_name, body_sql):
        """Return the SQL generated for ``sum(<source_name>.amount)``.

        The model has a single ``sql:`` source named *source_name*. Its ``sql``
        entry is *body_sql*, its grain is ``id``, and it has ``id`` and
        ``amount`` columns.
        """
        from conftest import make_engine

        sources = {
            source_name: {
                "name": source_name,
                "sql": body_sql,
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
            },
        }
        engine = make_engine(sources)
        result = engine.query(
            {
                "measures": [f"sum({source_name}.amount)"],
                "dimensions": [],
            }
        )
        return result.sql

    def test_sql_source_with_inner_with_clause(self):
        """SQL source containing WITH base AS (...) SELECT ... should not produce nested WITH."""
        import re

        sql = self._sum_amount_sql(
            "enriched_orders",
            "WITH base AS (SELECT id, amount FROM raw_orders WHERE amount > 0) SELECT id, amount FROM base",
        )
        assert_valid_sql(sql)
        # Should NOT have nested WITH (WITH ... WITH ...): the inner WITH should
        # be promoted/flattened, leaving exactly one WITH at the top level.
        # Count WITH as a whole keyword in any case. A literal "WITH " search
        # misses a WITH followed by a newline and also counts identifiers that
        # merely end in "with".
        with_count = len(re.findall(r"\bWITH\b", sql, flags=re.IGNORECASE))
        assert with_count == 1, f"Expected 1 WITH, got {with_count}. SQL:\n{sql}"
        # The inner CTE name should be promoted with prefix
        assert "enriched_orders__base" in sql

    def test_sql_source_without_with_unchanged(self):
        """Regular SQL source (no inner WITH) should work as before."""
        sql = self._sum_amount_sql(
            "simple_view", "SELECT id, amount FROM raw_orders WHERE amount > 0"
        )
        assert_valid_sql(sql)
        assert "simple_view AS" in sql
        assert "raw_orders" in sql

    def test_sql_source_inner_cte_unaliased_references(self):
        """Inner CTEs referenced without explicit aliases should get AS old_name after promotion."""
        sql = self._sum_amount_sql(
            "analysis",
            (
                "WITH q AS (SELECT id, amount, status FROM raw_data), "
                "filtered AS (SELECT q.id, q.amount FROM q WHERE q.status = 'active') "
                "SELECT filtered.id, filtered.amount FROM filtered"
            ),
        )
        assert_valid_sql(sql)
        assert "analysis__q" in sql
        assert "analysis__filtered" in sql
        # Unaliased CTE refs should get AS old_name so column qualifiers work
        assert "analysis__q AS q" in sql or 'analysis__q AS "q"' in sql
        assert (
            "analysis__filtered AS filtered" in sql
            or 'analysis__filtered AS "filtered"' in sql
        )
class TestSqliteTimeGranularity:
    """SQLite uses strftime/date arithmetic instead of DATE_TRUNC."""

    @staticmethod
    def _plan_order_count_by_created_at(sources, granularity):
        """Plan ``count(orders.id)`` bucketed by ``orders.created_at`` at *granularity*."""
        join_graph = JoinGraph(sources)
        join_graph.build()
        query_planner = QueryPlanner(sources, join_graph)
        bucketed_query = SemanticQuery(
            measures=["count(orders.id)"],
            dimensions=[{"field": "orders.created_at", "granularity": granularity}],
        )
        return query_planner.plan(bucketed_query)

    @pytest.fixture
    def sqlite_generator(self):
        return SqlGenerator(dialect="sqlite")

    @pytest.mark.parametrize(
        "granularity,expected_fragment",
        [
            ("year", "STRFTIME('%Y-01-01'"),
            ("month", "STRFTIME('%Y-%m-01'"),
            ("day", "STRFTIME('%Y-%m-%d'"),
            ("hour", "STRFTIME('%Y-%m-%d %H:00:00'"),
            ("quarter", "PRINTF('%02d'"),
            ("week", "weekday 1"),
        ],
    )
    def test_granularity(
        self,
        ecommerce_sources,
        sqlite_generator,
        granularity,
        expected_fragment,
    ):
        """Each granularity maps to its SQLite idiom and never to DATE_TRUNC."""
        bucket_plan = self._plan_order_count_by_created_at(
            ecommerce_sources, granularity
        )
        sql = sqlite_generator.generate(bucket_plan, ecommerce_sources)
        assert expected_fragment in sql, f"Expected '{expected_fragment}' in:\n{sql}"
        assert "DATE_TRUNC" not in sql

    def test_postgres_unchanged(self, ecommerce_sources):
        """Postgres still generates DATE_TRUNC as before."""
        postgres_generator = SqlGenerator(dialect="postgres")
        monthly_plan = self._plan_order_count_by_created_at(
            ecommerce_sources, "month"
        )
        emitted = postgres_generator.generate(monthly_plan, ecommerce_sources)
        assert "DATE_TRUNC" in emitted
class TestTranspileWithNativeCtes:
    """Native-dialect ``sql:`` bodies survive verbatim; engine scaffolding is translated."""

    def test_bigquery_hyphenated_project_ref_survives_in_sql_source(
        self, make_engine_factory
    ):
        """sql: source body with BigQuery-specific hyphenated project ref must survive verbatim."""
        # A project ref like `my-project.dataset.table` is not valid postgres,
        # so feeding it to a postgres parser at transpile time would fail.
        raw_events = {
            "name": "raw_events",
            "sql": "SELECT id, user_id FROM `my-project.analytics.events`",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "user_id", "type": "number"},
            ],
            "measures": [{"name": "event_count", "expr": "count(*)"}],
        }
        bq_engine = make_engine_factory({"raw_events": raw_events}, dialect="bigquery")
        count_query = {
            "measures": ["raw_events.event_count"],
            "dimensions": [],
            "filters": [],
        }
        outcome = bq_engine.query(count_query)
        assert "my-project.analytics.events" in outcome.sql, (
            f"Hyphenated BigQuery project ref was rewritten:\n{outcome.sql}"
        )

    def test_postgres_only_idiom_in_outer_gets_translated(self, make_engine_factory):
        """Postgres-only idioms in the engine scaffold are translated to target."""
        events = {
            "name": "events",
            "table": "events",
            "grain": ["id"],
            "columns": [
                {"name": "id", "type": "number"},
                {"name": "amount", "type": "number"},
                {"name": "created_at", "type": "time"},
            ],
            "measures": [{"name": "total", "expr": "sum(amount)"}],
        }
        bq_engine = make_engine_factory({"events": events}, dialect="bigquery")
        monthly_query = {
            "measures": ["events.total"],
            "dimensions": [{"field": "events.created_at", "granularity": "month"}],
            "filters": [],
        }
        # BigQuery's DATE_TRUNC (or TIMESTAMP_TRUNC) uses (col, MONTH) order.
        upper_sql = bq_engine.query(monthly_query).sql.upper()
        assert any(fn in upper_sql for fn in ("DATE_TRUNC(", "TIMESTAMP_TRUNC("))
        assert ", MONTH)" in upper_sql
class TestNativeDialectExprPreservation:
"""User-authored expr: in native dialect must survive composition intact.
Regression coverage for the fct_orders.daily_active_orders 400 bug:
BigQuery segments authored with `INTERVAL 30 DAY` were being parsed as
postgres, silently dropping the `DAY` unit.
"""
def test_bigquery_segment_with_interval_day_preserves_unit(
    self, make_bq_fct_orders_engine
):
    """Production repro: segment with `interval 30 day` must emit `INTERVAL 30 DAY`.

    Accepts quoted or unquoted quantity (``INTERVAL 30 DAY``,
    ``INTERVAL '30' DAY``, ``INTERVAL '30 DAY'``) in any letter case.
    """
    import re

    engine = make_bq_fct_orders_engine
    result = engine.query(
        {
            "measures": ["fct_orders.daily_active_orders"],
            "dimensions": [],
            "filters": [],
        }
    )
    sql = result.sql
    assert "INTERVAL '30'" not in sql or "INTERVAL '30' DAY" in sql, (
        f"BigQuery INTERVAL unit was dropped.\nSQL:\n{sql}"
    )
    # A bare ``"DAY" in sql`` check is satisfied by any identifier, alias or
    # function containing "day". Require DAY directly after the interval
    # quantity.
    assert re.search(r"INTERVAL\s+'?30'?\s+DAY\b", sql, flags=re.IGNORECASE), (
        f"DAY unit missing from emitted INTERVAL:\n{sql}"
    )
def test_bigquery_measure_filter_with_timestamp_sub(self, make_engine_factory):
    """Measure filter using TIMESTAMP_SUB(INTERVAL 7 DAY) must preserve unit."""
    recent_users_measure = {
        "name": "recent_users",
        "expr": "count(distinct user_id)",
        "filter": "ts >= TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 7 DAY)",
    }
    events = {
        "name": "events",
        "table": "events",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "user_id", "type": "number"},
            {"name": "ts", "type": "time"},
        ],
        "measures": [recent_users_measure],
    }
    bq_engine = make_engine_factory({"events": events}, dialect="bigquery")
    upper_sql = bq_engine.query(
        {"measures": ["events.recent_users"], "dimensions": [], "filters": []}
    ).sql.upper()
    assert "TIMESTAMP_SUB" in upper_sql
    # Quoted and unquoted quantities are both acceptable; the DAY unit is not optional.
    assert any(
        form in upper_sql for form in ("INTERVAL 7 DAY", "INTERVAL '7' DAY")
    )
def test_snowflake_segment_with_dateadd(self, make_engine_factory):
    """Snowflake DATEADD(day, -30, CURRENT_TIMESTAMP()) must survive."""
    recent_segment = {
        "name": "recent",
        "expr": "placed_at >= DATEADD(day, -30, CURRENT_TIMESTAMP())",
    }
    orders = {
        "name": "orders",
        "table": "orders",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "placed_at", "type": "time"},
        ],
        "segments": [recent_segment],
        "measures": [{"name": "cnt", "expr": "count(*)", "segments": ["recent"]}],
    }
    sf_engine = make_engine_factory({"orders": orders}, dialect="snowflake")
    emitted = sf_engine.query(
        {"measures": ["orders.cnt"], "dimensions": [], "filters": []}
    ).sql
    assert "DATEADD" in emitted.upper()
def test_postgres_interval_baseline_still_works(self, make_engine_factory):
    """Baseline: postgres INTERVAL '30 days' round-trips correctly."""
    recent_segment = {
        "name": "recent",
        "expr": "placed_at >= current_date - interval '30 days'",
    }
    orders = {
        "name": "orders",
        "table": "orders",
        "grain": ["id"],
        "columns": [
            {"name": "id", "type": "number"},
            {"name": "placed_at", "type": "time"},
        ],
        "segments": [recent_segment],
        "measures": [{"name": "cnt", "expr": "count(*)", "segments": ["recent"]}],
    }
    pg_engine = make_engine_factory({"orders": orders}, dialect="postgres")
    upper_sql = pg_engine.query(
        {"measures": ["orders.cnt"], "dimensions": [], "filters": []}
    ).sql.upper()
    assert "INTERVAL" in upper_sql
    # Either the original plural literal or a normalized '30' DAY form is fine.
    assert any(form in upper_sql for form in ("30 DAYS", "30' DAY"))