# Mirror of https://github.com/Kaelio/ktx.git (synced 2026-06-07 07:55:13 +02:00).
# 741 lines, 25 KiB, Python.
|
|
"""Tests targeting specific coverage gaps in planner.py, generator.py, models.py, engine.py."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from pydantic import ValidationError
|
||
|
|
|
||
|
|
from semantic_layer.generator import SqlGenerator
|
||
|
|
from semantic_layer.graph import JoinGraph
|
||
|
|
from semantic_layer.models import (
|
||
|
|
JoinDeclaration,
|
||
|
|
MeasureDefinition,
|
||
|
|
SemanticQuery,
|
||
|
|
SourceColumn,
|
||
|
|
SourceDefinition,
|
||
|
|
)
|
||
|
|
from semantic_layer.planner import QueryPlanner
|
||
|
|
|
||
|
|
from conftest import assert_valid_sql, make_engine
|
||
|
|
|
||
|
|
|
||
|
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def _make_planner(sources: dict[str, SourceDefinition]) -> QueryPlanner:
    """Return a ``QueryPlanner`` over *sources* backed by a freshly built ``JoinGraph``."""
    join_graph = JoinGraph(sources)
    # The graph is built before it is handed to the planner.
    join_graph.build()
    planner = QueryPlanner(sources, join_graph)
    return planner
|
||
|
|
|
||
|
|
|
||
|
|
def _plan_and_generate(sources: dict[str, SourceDefinition], query_dict: dict) -> str:
    """Plan *query_dict* against *sources* and return the generated SQL text.

    The query is built with ``SemanticQuery(**query_dict)``, planned by a
    planner from ``_make_planner``, rendered by a ``SqlGenerator`` using the
    ``"postgres"`` dialect, and passed through ``assert_valid_sql`` before
    being returned.
    """
    query_planner = _make_planner(sources)
    sql_generator = SqlGenerator(dialect="postgres")
    query_plan = query_planner.plan(SemanticQuery(**query_dict))
    rendered = sql_generator.generate(query_plan, sources)
    assert_valid_sql(rendered)
    return rendered
|
||
|
|
|
||
|
|
|
||
|
|
# ── Source fixtures ──────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def _simple_sources() -> dict[str, SourceDefinition]:
    """Build the two-source fixture: ``orders`` joined many_to_one to ``customers``.

    ``orders`` declares two measures: ``revenue`` (``sum(amount)`` with a
    ``status != 'refunded'`` filter) and ``order_count`` (``count(id)``).
    """
    customer_columns = [
        SourceColumn(name="id", type="number"),
        SourceColumn(name="segment", type="string"),
    ]
    order_columns = [
        SourceColumn(name="id", type="number"),
        SourceColumn(name="customer_id", type="number"),
        SourceColumn(name="amount", type="number"),
        SourceColumn(name="status", type="string"),
    ]
    orders_to_customers = JoinDeclaration(
        to="customers",
        on="customer_id = customers.id",
        relationship="many_to_one",
    )
    revenue = MeasureDefinition(
        name="revenue", expr="sum(amount)", filter="status != 'refunded'"
    )
    order_count = MeasureDefinition(name="order_count", expr="count(id)")

    return {
        "customers": SourceDefinition(
            name="customers",
            table="public.customers",
            grain=["id"],
            columns=customer_columns,
        ),
        "orders": SourceDefinition(
            name="orders",
            table="public.orders",
            grain=["id"],
            columns=order_columns,
            joins=[orders_to_customers],
            measures=[revenue, order_count],
        ),
    }
|
||
|
|
|
||
|
|
|
||
|
|
def _chasm_sources() -> dict[str, SourceDefinition]:
    """Build the chasm-trap fixture: ``orders`` and ``tickets`` both join ``customers``.

    Each fact source declares its own many_to_one join to ``customers`` and one
    measure: ``orders.revenue`` (``sum(amount)``) and ``tickets.ticket_count``
    (``count(id)``).
    """

    def join_to_customers() -> JoinDeclaration:
        # A distinct declaration per fact source, so no instance is shared.
        return JoinDeclaration(
            to="customers",
            on="customer_id = customers.id",
            relationship="many_to_one",
        )

    result: dict[str, SourceDefinition] = {}
    result["customers"] = SourceDefinition(
        name="customers",
        table="public.customers",
        grain=["id"],
        columns=[
            SourceColumn(name="id", type="number"),
            SourceColumn(name="segment", type="string"),
        ],
    )
    result["orders"] = SourceDefinition(
        name="orders",
        table="public.orders",
        grain=["id"],
        columns=[
            SourceColumn(name="id", type="number"),
            SourceColumn(name="customer_id", type="number"),
            SourceColumn(name="amount", type="number"),
        ],
        joins=[join_to_customers()],
        measures=[MeasureDefinition(name="revenue", expr="sum(amount)")],
    )
    result["tickets"] = SourceDefinition(
        name="tickets",
        table="public.tickets",
        grain=["id"],
        columns=[
            SourceColumn(name="id", type="number"),
            SourceColumn(name="customer_id", type="number"),
            SourceColumn(name="priority", type="string"),
        ],
        joins=[join_to_customers()],
        measures=[MeasureDefinition(name="ticket_count", expr="count(id)")],
    )
    return result
|
||
|
|
|
||
|
|
|
||
|
|
def _chain_sources_with_derived() -> dict[str, SourceDefinition]:
    """Build the chain fixture: ``orders`` -> ``customers`` -> ``tiers`` (many_to_one).

    ``orders`` declares ``revenue`` (filtered sum), ``order_count`` and the
    derived ``avg_order`` whose expression is ``revenue / order_count``.
    """
    tier_columns = [
        SourceColumn(name="id", type="number"),
        SourceColumn(name="level", type="string"),
    ]
    customer_columns = [
        SourceColumn(name="id", type="number"),
        SourceColumn(name="tier_id", type="number"),
        SourceColumn(name="segment", type="string"),
    ]
    order_columns = [
        SourceColumn(name="id", type="number"),
        SourceColumn(name="customer_id", type="number"),
        SourceColumn(name="amount", type="number"),
        SourceColumn(name="status", type="string"),
    ]
    order_measures = [
        MeasureDefinition(
            name="revenue", expr="sum(amount)", filter="status != 'refunded'"
        ),
        MeasureDefinition(name="order_count", expr="count(id)"),
        # Derived measure: its expression refers to the two measures above by name.
        MeasureDefinition(name="avg_order", expr="revenue / order_count"),
    ]

    tiers_source = SourceDefinition(
        name="tiers",
        table="public.tiers",
        grain=["id"],
        columns=tier_columns,
    )
    customers_source = SourceDefinition(
        name="customers",
        table="public.customers",
        grain=["id"],
        columns=customer_columns,
        joins=[
            JoinDeclaration(
                to="tiers", on="tier_id = tiers.id", relationship="many_to_one"
            )
        ],
    )
    orders_source = SourceDefinition(
        name="orders",
        table="public.orders",
        grain=["id"],
        columns=order_columns,
        joins=[
            JoinDeclaration(
                to="customers",
                on="customer_id = customers.id",
                relationship="many_to_one",
            )
        ],
        measures=order_measures,
    )
    return {
        "tiers": tiers_source,
        "customers": customers_source,
        "orders": orders_source,
    }
|
||
|
|
|
||
|
|
|
||
|
|
# ── Planner: nested aggregation (lines 432-440) ─────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestNestedAggregation:
    """Measures that wrap one aggregate in another are expected to be rejected."""

    @staticmethod
    def _expect_nested_rejection(measure_expr):
        """Plan *measure_expr* by ``orders.status`` and expect the nested-aggregation error."""
        planner = _make_planner(_simple_sources())
        query_kwargs = {
            "measures": [measure_expr],
            "dimensions": ["orders.status"],
        }
        # The query is constructed inside the context, as the planner call is.
        with pytest.raises(ValueError, match="Nested aggregation is not supported"):
            planner.plan(SemanticQuery(**query_kwargs))

    def test_nested_aggregation_raises(self):
        """``avg(sum(orders.amount))`` must raise ``ValueError``."""
        self._expect_nested_rejection("avg(sum(orders.amount))")

    def test_nested_max_count_raises(self):
        """``max(count(orders.id))`` must raise ``ValueError``."""
        self._expect_nested_rejection("max(count(orders.id))")
|
||
|
|
|
||
|
|
|
||
|
|
# ── Planner: OR filter mixing (lines 810-833) ───────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestOrFilterMixing:
    """``OR`` filters: mixed aggregate/non-aggregate operands versus homogeneous ones."""

    def test_or_mixing_agg_and_nonagg_raises(self):
        """An ``OR`` joining a row-level and an aggregate condition must raise."""
        planner = _make_planner(_simple_sources())
        query_kwargs = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": ["orders.amount > 100 OR sum(orders.amount) > 5000"],
        }
        with pytest.raises(ValueError, match="mixes aggregate and non-aggregate"):
            planner.plan(SemanticQuery(**query_kwargs))

    def test_or_pure_where_ok(self):
        """An ``OR`` of two non-aggregate conditions yields SQL containing ``OR``."""
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": ["orders.amount > 100 OR orders.amount < 10"],
        }
        generated = _plan_and_generate(_simple_sources(), query)
        assert "OR" in generated.upper()

    def test_or_pure_having_ok(self):
        """An ``OR`` of two aggregate conditions yields SQL containing ``HAVING``."""
        query = {
            "measures": ["sum(orders.amount)"],
            "dimensions": ["orders.status"],
            "filters": ["sum(orders.amount) > 1000 OR count(orders.id) > 5"],
        }
        generated = _plan_and_generate(_simple_sources(), query)
        assert "HAVING" in generated.upper()
|
||
|
|
|
||
|
|
|
||
|
|
# ── Planner: empty source refs (line 62) ─────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestEmptySourceRef:
    """Queries whose measures and dimensions name no source at all."""

    def test_no_source_refs_raises(self):
        """``sum(1)`` with no dimensions must raise the no-source ``ValueError``."""
        planner = _make_planner(_simple_sources())
        query_kwargs = {"measures": ["sum(1)"], "dimensions": []}
        with pytest.raises(ValueError, match="does not reference any source"):
            planner.plan(SemanticQuery(**query_kwargs))
|
||
|
|
|
||
|
|
|
||
|
|
# ── Planner: predefined measure dependency chains (lines 189-194, 237, 281-282) ──
|
||
|
|
|
||
|
|
|
||
|
|
class TestPredefinedMeasureDeps:
    """Derived predefined measures (``avg_order`` = ``revenue / order_count``)."""

    def test_derived_measure_resolves_dependencies(self):
        """Requesting ``orders.avg_order`` puts it and both of its inputs in the plan."""
        planner = _make_planner(_chain_sources_with_derived())
        query = SemanticQuery(
            measures=["orders.avg_order"],
            dimensions=["orders.status"],
        )
        plan = planner.plan(query)

        planned_names = {measure.name for measure in plan.measures}
        for expected in ("avg_order", "revenue", "order_count"):
            assert expected in planned_names

    def test_derived_measure_generates_valid_sql(self):
        """``orders.avg_order`` grouped by ``customers.segment`` renders a ``GROUP BY``."""
        query = {
            "measures": ["orders.avg_order"],
            "dimensions": ["customers.segment"],
        }
        generated = _plan_and_generate(_chain_sources_with_derived(), query)
        assert "GROUP BY" in generated.upper()
|
||
|
|
|
||
|
|
|
||
|
|
# ── Planner: fan-out with one_to_many to dimension sources (lines 595-643) ──
|
||
|
|
|
||
|
|
|
||
|
|
class TestFanOutEdgeCases:
    """Plans that cross a ``one_to_many`` join between measure and dimension/filter."""

    def test_single_source_fan_out_to_dimension(self):
        """Measure on ``hub`` with a dimension on its one_to_many ``detail`` flags fan-out."""
        hub_to_detail = JoinDeclaration(
            to="detail", on="id = detail.hub_id", relationship="one_to_many"
        )
        sources = {
            "hub": SourceDefinition(
                name="hub",
                table="public.hub",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="name", type="string"),
                ],
                joins=[hub_to_detail],
            ),
            "detail": SourceDefinition(
                name="detail",
                table="public.detail",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="hub_id", type="number"),
                    SourceColumn(name="category", type="string"),
                ],
            ),
        }
        query = SemanticQuery(
            measures=["sum(hub.id)"],
            dimensions=["detail.category"],
        )
        plan = _make_planner(sources).plan(query)
        assert plan.has_fan_out

    def test_merged_groups_fan_out_to_dimension(self):
        """Measure on ``child`` (m2o to ``parent``) with a dimension on ``dim`` flags fan-out.

        ``parent`` reaches ``dim`` through a one_to_many join.
        """
        dim_source = SourceDefinition(
            name="dim",
            table="public.dim",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="label", type="string"),
            ],
        )
        parent_source = SourceDefinition(
            name="parent",
            table="public.parent",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="val", type="number"),
            ],
            joins=[
                JoinDeclaration(to="dim", on="id = dim.id", relationship="one_to_many")
            ],
        )
        child_source = SourceDefinition(
            name="child",
            table="public.child",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="parent_id", type="number"),
                SourceColumn(name="amount", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="parent", on="parent_id = parent.id", relationship="many_to_one"
                )
            ],
        )
        sources = {"dim": dim_source, "parent": parent_source, "child": child_source}
        query = SemanticQuery(
            measures=["sum(child.amount)"],
            dimensions=["dim.label"],
        )
        plan = _make_planner(sources).plan(query)
        assert plan.has_fan_out

    def test_filter_fan_out_one_to_many_raises(self):
        """Filtering on ``child``, reached one_to_many from the measured ``parent``, must raise."""
        parent_source = SourceDefinition(
            name="parent",
            table="public.parent",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="val", type="number"),
            ],
            joins=[
                JoinDeclaration(
                    to="child", on="id = child.parent_id", relationship="one_to_many"
                )
            ],
        )
        child_source = SourceDefinition(
            name="child",
            table="public.child",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="parent_id", type="number"),
                SourceColumn(name="category", type="string"),
            ],
        )
        planner = _make_planner({"parent": parent_source, "child": child_source})
        query_kwargs = {
            "measures": ["sum(parent.val)"],
            "dimensions": [],
            "filters": ["child.category = 'A'"],
        }
        with pytest.raises(ValueError, match="one_to_many join"):
            planner.plan(SemanticQuery(**query_kwargs))
|
||
|
|
|
||
|
|
|
||
|
|
# ── Generator: NULL dimension in multi-CTE (lines 385-388) ──────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestNullDimensionInCTE:
    """Multi-fact query over a shared hub dimension."""

    def test_dimension_not_in_any_cte_gets_null(self):
        """Two facts measured by ``hub.name`` must render SQL with ``FULL`` or ``WITH``.

        NOTE(review): the test name refers to NULL generation for an
        unreachable dimension, but the assertion only looks for ``FULL`` or
        ``WITH`` in the SQL -- confirm whether a NULL check was intended.
        """

        def join_to_hub() -> JoinDeclaration:
            # Each fact gets its own declaration of the same many_to_one join.
            return JoinDeclaration(
                to="hub", on="hub_id = hub.id", relationship="many_to_one"
            )

        hub_source = SourceDefinition(
            name="hub",
            table="public.hub",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="name", type="string"),
            ],
        )
        fact_a_source = SourceDefinition(
            name="fact_a",
            table="public.fact_a",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="hub_id", type="number"),
                SourceColumn(name="val", type="number"),
                SourceColumn(name="extra", type="string"),
            ],
            joins=[join_to_hub()],
        )
        fact_b_source = SourceDefinition(
            name="fact_b",
            table="public.fact_b",
            grain=["id"],
            columns=[
                SourceColumn(name="id", type="number"),
                SourceColumn(name="hub_id", type="number"),
                SourceColumn(name="val", type="number"),
            ],
            joins=[join_to_hub()],
        )
        sources = {"hub": hub_source, "fact_a": fact_a_source, "fact_b": fact_b_source}
        query = {
            "measures": ["sum(fact_a.val)", "sum(fact_b.val)"],
            "dimensions": ["hub.name"],
        }
        generated = _plan_and_generate(sources, query)

        # Either keyword is accepted as evidence of the CTE-based plan.
        upper_sql = generated.upper()
        assert "FULL" in upper_sql or "WITH" in upper_sql
|
||
|
|
|
||
|
|
|
||
|
|
# ── Generator: CTE alias collision (lines 202-206) ──────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestCTEAliasCollision:
    """A source literally named ``orders_agg`` alongside an ``orders`` fact."""

    def test_alias_collision_resolved(self):
        """SQL stays valid when a source is named ``orders_agg``.

        The name is chosen to clash with the CTE alias for ``orders``
        (presumably ``orders_agg`` -- verify against the generator).
        The test only checks that the resulting SQL is valid.
        """

        def join_to_orders_agg() -> JoinDeclaration:
            return JoinDeclaration(
                to="orders_agg",
                on="customer_id = orders_agg.id",
                relationship="many_to_one",
            )

        sources = {
            "orders_agg": SourceDefinition(
                name="orders_agg",
                table="public.orders_agg",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="segment", type="string"),
                ],
            ),
            "orders": SourceDefinition(
                name="orders",
                table="public.orders",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="customer_id", type="number"),
                    SourceColumn(name="amount", type="number"),
                ],
                joins=[join_to_orders_agg()],
            ),
            "tickets": SourceDefinition(
                name="tickets",
                table="public.tickets",
                grain=["id"],
                columns=[
                    SourceColumn(name="id", type="number"),
                    SourceColumn(name="customer_id", type="number"),
                    SourceColumn(name="priority", type="string"),
                ],
                joins=[join_to_orders_agg()],
            ),
        }
        query = {
            "measures": ["sum(orders.amount)", "count(tickets.id)"],
            "dimensions": ["orders_agg.segment"],
        }
        generated = _plan_and_generate(sources, query)
        # Validity is the only property checked for the colliding name.
        assert_valid_sql(generated)
|
||
|
|
|
||
|
|
|
||
|
|
# ── Models: negative limit (line 95) ────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestNegativeLimit:
    """Validation of ``SemanticQuery.limit`` at construction time."""

    def test_negative_limit_raises(self):
        """``limit=-1`` must fail with a ``ValidationError`` mentioning ``limit``."""
        query_kwargs = {"measures": ["sum(orders.amount)"], "limit": -1}
        with pytest.raises(ValidationError, match="limit"):
            SemanticQuery(**query_kwargs)

    def test_zero_limit_allowed(self):
        """``limit=0`` is accepted and stored unchanged."""
        query = SemanticQuery(measures=["sum(orders.amount)"], limit=0)
        assert query.limit == 0
|
||
|
|
|
||
|
|
|
||
|
|
# ── Engine: suggest with missing sources (lines 100-106, 127) ────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestEngineSuggest:
    """``engine.suggest`` on queries that name sources the engine was not given."""

    @staticmethod
    def _orders_only_config() -> dict:
        """Return a fresh engine config holding only an ``orders`` source."""
        return {
            "orders": {
                "name": "orders",
                "table": "public.orders",
                "grain": ["id"],
                "columns": [
                    {"name": "id", "type": "number"},
                    {"name": "amount", "type": "number"},
                ],
            },
        }

    def test_suggest_with_missing_source(self):
        """An unknown source yields a failed result with a suggestion that mentions it."""
        engine = make_engine(self._orders_only_config())
        outcome = engine.suggest(
            {
                "measures": ["sum(unknown_source.val)"],
                "dimensions": ["orders.id"],
            }
        )
        assert not outcome["success"]

        def mentions_missing(description):
            # Accept the word "missing" in any case, or the literal source name.
            return "missing" in description.lower() or "unknown_source" in description

        assert any(
            mentions_missing(suggestion["description"])
            for suggestion in outcome.get("suggestions", [])
        )

    def test_suggest_with_dict_measure_and_dimension(self):
        """Dict-form measures and dimensions on an unknown source yield a failed result."""
        engine = make_engine(self._orders_only_config())
        # Nested aggregate over an unknown source. Per the original author's note
        # this is meant to reach the dict-handling code on the failure path.
        request = {
            "measures": [{"expr": "avg(sum(missing.val))", "name": "total"}],
            "dimensions": [{"field": "missing.category"}],
        }
        outcome = engine.suggest(request)
        assert not outcome["success"]
|
||
|
|
|
||
|
|
|
||
|
|
# ── Planner: order_by resolution formats (lines 113-116) ────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestOrderByResolution:
    """``order_by`` entries given as a dict or as a bare string."""

    def test_order_by_as_dict(self):
        """A dict entry yields one plan ``order_by`` item that keeps its direction."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(
            measures=["sum(orders.amount)"],
            dimensions=["orders.status"],
            order_by=[{"field": "orders.status", "direction": "desc"}],
        )
        plan = planner.plan(query)

        ordering = plan.order_by
        assert len(ordering) == 1
        assert ordering[0].direction == "desc"

    def test_order_by_as_string(self):
        """A bare string entry yields exactly one plan ``order_by`` item."""
        planner = _make_planner(_simple_sources())
        query = SemanticQuery(
            measures=["sum(orders.amount)"],
            dimensions=["orders.status"],
            order_by=["orders.status"],
        )
        plan = planner.plan(query)
        assert len(plan.order_by) == 1
|
||
|
|
|
||
|
|
|
||
|
|
# ── Planner: measure with no source refs (line 343) ─────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestMeasureNoSourceRef:
    """A measure expression that names no source, beside a dimension that does."""

    def test_bare_column_no_aggregate_raises(self):
        """``sum(1)`` grouped by ``orders.status`` must raise ``ValueError``.

        Either of the two messages in the pattern below is accepted.
        """
        planner = _make_planner(_simple_sources())
        expected_message = "does not reference any source|not a pre-defined measure"
        query_kwargs = {
            "measures": ["sum(1)"],
            "dimensions": ["orders.status"],
        }
        with pytest.raises(ValueError, match=expected_message):
            planner.plan(SemanticQuery(**query_kwargs))
|
||
|
|
|
||
|
|
|
||
|
|
# ── Generator: custom aggregate parsing (lines 614-617) ─────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestCustomAggregates:
    """Aggregate expressions beyond a plain ``agg(column)``."""

    def test_count_distinct_generates_valid_sql(self):
        """``count(distinct orders.id)`` renders a ``COUNT(DISTINCT`` call in the SQL."""
        query = {
            "measures": ["count(distinct orders.id)"],
            "dimensions": ["orders.status"],
        }
        upper_sql = _plan_and_generate(_simple_sources(), query).upper()
        # Accept the rendering with or without a space before the parenthesis.
        accepted_spellings = ("COUNT(DISTINCT", "COUNT (DISTINCT")
        assert any(spelling in upper_sql for spelling in accepted_spellings)
|
||
|
|
|
||
|
|
|
||
|
|
# ── Generator: qualified predefined expressions via multi-hop joins (lines 925-931) ──
|
||
|
|
|
||
|
|
|
||
|
|
class TestQualifiedPredefinedExpr:
    """Filtered predefined measure grouped by a dimension two joins away."""

    def test_predefined_filter_with_joined_column(self):
        """``orders.revenue`` by ``tiers.level`` renders valid SQL containing ``CASE WHEN``."""
        query = {
            "measures": ["orders.revenue"],
            "dimensions": ["tiers.level"],
        }
        generated = _plan_and_generate(_chain_sources_with_derived(), query)
        assert_valid_sql(generated)
        assert "CASE WHEN" in generated.upper()
|
||
|
|
|
||
|
|
|
||
|
|
# ── End-to-end: chasm trap with aggregate locality ───────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestChasmTrapEndToEnd:
    """End-to-end SQL generation for two facts sharing the ``customers`` hub."""

    def test_two_fact_tables_produce_valid_sql(self):
        """Measures from ``orders`` and ``tickets`` render SQL with ``WITH`` and a join."""
        query = {
            "measures": ["sum(orders.amount)", "count(tickets.id)"],
            "dimensions": ["customers.segment"],
        }
        upper_sql = _plan_and_generate(_chasm_sources(), query).upper()
        assert "WITH" in upper_sql
        assert "FULL" in upper_sql or "JOIN" in upper_sql

    def test_chasm_with_filter_on_hub(self):
        """A filter on ``customers.segment`` keeps its literal in the SQL, which stays valid."""
        query = {
            "measures": ["sum(orders.amount)", "count(tickets.id)"],
            "dimensions": ["customers.segment"],
            "filters": ["customers.segment = 'enterprise'"],
        }
        generated = _plan_and_generate(_chasm_sources(), query)
        assert "enterprise" in generated
        assert_valid_sql(generated)
|