mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
361 lines
12 KiB
Python
361 lines
12 KiB
Python
|
|
"""TPC-H schema tests: loading, graph, planning, and SQL execution against DuckDB."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from semantic_layer.engine import SemanticEngine
|
||
|
|
from semantic_layer.graph import JoinGraph
|
||
|
|
from semantic_layer.loader import SourceLoader
|
||
|
|
from semantic_layer.models import SourceDefinition
|
||
|
|
|
||
|
|
TPCH_DIR = Path(__file__).parent.parent / "sources" / "tpch"
|
||
|
|
TPCH_TABLES = [
|
||
|
|
"region",
|
||
|
|
"nation",
|
||
|
|
"supplier",
|
||
|
|
"customer",
|
||
|
|
"part",
|
||
|
|
"partsupp",
|
||
|
|
"orders",
|
||
|
|
"lineitem",
|
||
|
|
]
|
||
|
|
|
||
|
|
try:
|
||
|
|
import duckdb
|
||
|
|
|
||
|
|
HAS_DUCKDB = True
|
||
|
|
except ImportError:
|
||
|
|
HAS_DUCKDB = False
|
||
|
|
|
||
|
|
|
||
|
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(scope="module")
|
||
|
|
def sources() -> dict[str, SourceDefinition]:
|
||
|
|
return SourceLoader(TPCH_DIR).load_all()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(scope="module")
|
||
|
|
def graph(sources: dict[str, SourceDefinition]) -> JoinGraph:
|
||
|
|
g = JoinGraph(sources)
|
||
|
|
g.build()
|
||
|
|
return g
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(scope="module")
|
||
|
|
def engine() -> SemanticEngine:
|
||
|
|
return SemanticEngine(str(TPCH_DIR), dialect="duckdb")
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(scope="module")
|
||
|
|
def tpch_conn():
|
||
|
|
if not HAS_DUCKDB:
|
||
|
|
pytest.skip("duckdb not installed")
|
||
|
|
conn = duckdb.connect()
|
||
|
|
conn.execute("INSTALL tpch; LOAD tpch")
|
||
|
|
conn.execute("CALL dbgen(sf=0.01)")
|
||
|
|
conn.execute("CREATE SCHEMA IF NOT EXISTS public")
|
||
|
|
for t in TPCH_TABLES:
|
||
|
|
conn.execute(f"CREATE VIEW public.{t} AS SELECT * FROM main.{t}")
|
||
|
|
return conn
|
||
|
|
|
||
|
|
|
||
|
|
# ── Loader Tests ─────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestTpchLoader:
|
||
|
|
def test_all_sources_loaded(self, sources):
|
||
|
|
assert set(sources.keys()) == set(TPCH_TABLES)
|
||
|
|
|
||
|
|
def test_lineitem_columns(self, sources):
|
||
|
|
li = sources["lineitem"]
|
||
|
|
col_names = {c.name for c in li.columns}
|
||
|
|
assert "l_orderkey" in col_names
|
||
|
|
assert "l_extendedprice" in col_names
|
||
|
|
assert "l_shipdate" in col_names
|
||
|
|
assert len(li.columns) == 16
|
||
|
|
|
||
|
|
def test_lineitem_composite_grain(self, sources):
|
||
|
|
assert sources["lineitem"].grain == ["l_orderkey", "l_linenumber"]
|
||
|
|
|
||
|
|
def test_partsupp_composite_grain(self, sources):
|
||
|
|
assert sources["partsupp"].grain == ["ps_partkey", "ps_suppkey"]
|
||
|
|
|
||
|
|
def test_lineitem_measures(self, sources):
|
||
|
|
measure_names = {m.name for m in sources["lineitem"].measures}
|
||
|
|
assert "revenue" in measure_names
|
||
|
|
assert "returned_revenue" in measure_names
|
||
|
|
assert "charge" in measure_names
|
||
|
|
assert len(sources["lineitem"].measures) == 8
|
||
|
|
|
||
|
|
def test_returned_revenue_has_filter(self, sources):
|
||
|
|
m = next(
|
||
|
|
m for m in sources["lineitem"].measures if m.name == "returned_revenue"
|
||
|
|
)
|
||
|
|
assert m.filter == "l_returnflag = 'R'"
|
||
|
|
|
||
|
|
def test_lineitem_joins(self, sources):
|
||
|
|
join_targets = {j.to for j in sources["lineitem"].joins}
|
||
|
|
assert join_targets == {"orders", "part", "supplier"}
|
||
|
|
|
||
|
|
def test_region_is_leaf(self, sources):
|
||
|
|
assert sources["region"].joins == []
|
||
|
|
assert sources["region"].measures == []
|
||
|
|
|
||
|
|
def test_orders_measures(self, sources):
|
||
|
|
measure_names = {m.name for m in sources["orders"].measures}
|
||
|
|
assert measure_names == {"order_count", "total_price", "avg_order_value"}
|
||
|
|
|
||
|
|
|
||
|
|
# ── Graph Tests ──────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestTpchGraph:
|
||
|
|
def test_all_sources_in_graph(self, graph):
|
||
|
|
assert set(graph.adjacency.keys()) >= set(TPCH_TABLES)
|
||
|
|
|
||
|
|
def test_lineitem_to_region_path(self, graph):
|
||
|
|
"""Shortest path: lineitem → supplier → nation → region (3 hops)."""
|
||
|
|
path = graph.find_path("lineitem", "region")
|
||
|
|
assert path is not None
|
||
|
|
source_chain = [path.edges[0].from_source] + [e.to_source for e in path.edges]
|
||
|
|
assert "lineitem" in source_chain
|
||
|
|
assert "region" in source_chain
|
||
|
|
assert len(path.edges) == 3
|
||
|
|
|
||
|
|
def test_lineitem_to_part_direct(self, graph):
|
||
|
|
path = graph.find_path("lineitem", "part")
|
||
|
|
assert path is not None
|
||
|
|
assert len(path.edges) == 1
|
||
|
|
|
||
|
|
def test_part_to_supplier_via_lineitem(self, graph):
|
||
|
|
"""Shortest path: part → lineitem → supplier (2 hops, shorter than via partsupp)."""
|
||
|
|
path = graph.find_path("part", "supplier")
|
||
|
|
assert path is not None
|
||
|
|
assert len(path.edges) == 2
|
||
|
|
|
||
|
|
def test_partsupp_bridges_part_and_supplier(self, graph):
|
||
|
|
"""partsupp has direct edges to both part and supplier."""
|
||
|
|
path_to_part = graph.find_path("partsupp", "part")
|
||
|
|
path_to_supplier = graph.find_path("partsupp", "supplier")
|
||
|
|
assert path_to_part is not None and len(path_to_part.edges) == 1
|
||
|
|
assert path_to_supplier is not None and len(path_to_supplier.edges) == 1
|
||
|
|
|
||
|
|
def test_graph_is_single_component(self, graph):
|
||
|
|
components = graph.find_components()
|
||
|
|
assert len(components) == 1
|
||
|
|
|
||
|
|
|
||
|
|
# ── Plan-only Tests (no DuckDB needed) ───────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
class TestTpchPlanning:
|
||
|
|
def test_q1_plan(self, engine):
|
||
|
|
plan = engine.plan_only(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": ["lineitem.l_returnflag", "lineitem.l_linestatus"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
assert plan.anchor_source == "lineitem"
|
||
|
|
assert len(plan.sources_used) == 1
|
||
|
|
|
||
|
|
def test_q5_plan_multi_hop(self, engine):
|
||
|
|
plan = engine.plan_only(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": ["nation.n_name"],
|
||
|
|
"filters": ["region.r_name = 'ASIA'"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
assert "lineitem" in plan.sources_used
|
||
|
|
assert "nation" in plan.sources_used
|
||
|
|
assert "region" in plan.sources_used
|
||
|
|
|
||
|
|
def test_filtered_measure_plan(self, engine):
|
||
|
|
plan = engine.plan_only(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.returned_revenue"],
|
||
|
|
"dimensions": ["customer.c_name"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
assert any(m.filter for m in plan.measures)
|
||
|
|
|
||
|
|
def test_time_granularity_plan(self, engine):
|
||
|
|
plan = engine.plan_only(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": [{"field": "orders.o_orderdate", "granularity": "month"}],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
col_names = [c.name for c in plan.columns]
|
||
|
|
# Column may be named "o_orderdate" with granularity metadata
|
||
|
|
assert "o_orderdate" in col_names
|
||
|
|
dim_col = next(c for c in plan.columns if c.name == "o_orderdate")
|
||
|
|
assert dim_col.granularity == "month"
|
||
|
|
|
||
|
|
def test_suggest_valid_query(self, engine):
|
||
|
|
result = engine.suggest(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": ["lineitem.l_returnflag"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
assert result["success"] is True
|
||
|
|
|
||
|
|
def test_suggest_missing_source(self, engine):
|
||
|
|
result = engine.suggest(
|
||
|
|
{
|
||
|
|
"measures": ["sum(lineitem.l_quantity)"],
|
||
|
|
"dimensions": ["nonexistent.col"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
assert result["success"] is False
|
||
|
|
|
||
|
|
|
||
|
|
# ── Execution Tests (require DuckDB) ────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.skipif(not HAS_DUCKDB, reason="duckdb not installed")
|
||
|
|
class TestTpchExecution:
|
||
|
|
def test_q1_pricing_summary(self, tpch_conn, engine):
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": [
|
||
|
|
"lineitem.revenue",
|
||
|
|
"lineitem.total_quantity",
|
||
|
|
"lineitem.avg_discount",
|
||
|
|
"lineitem.line_count",
|
||
|
|
],
|
||
|
|
"dimensions": ["lineitem.l_returnflag", "lineitem.l_linestatus"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) > 0
|
||
|
|
# TPC-H has exactly 4 combinations: A/F, N/F, N/O, R/F
|
||
|
|
assert len(rows) <= 4
|
||
|
|
|
||
|
|
def test_q5_revenue_by_nation_asia(self, tpch_conn, engine):
|
||
|
|
"""4-hop join with filter: lineitem→supplier→nation→region."""
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": ["nation.n_name"],
|
||
|
|
"filters": ["region.r_name = 'ASIA'"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) > 0
|
||
|
|
# ASIA has 5 nations
|
||
|
|
assert len(rows) <= 5
|
||
|
|
|
||
|
|
def test_q3_revenue_by_month(self, tpch_conn, engine):
|
||
|
|
"""DATE_TRUNC + multi-table filter."""
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": [{"field": "orders.o_orderdate", "granularity": "month"}],
|
||
|
|
"filters": ["customer.c_mktsegment = 'BUILDING'"],
|
||
|
|
"limit": 12,
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) > 0
|
||
|
|
assert len(rows) <= 12
|
||
|
|
|
||
|
|
def test_q10_returned_revenue(self, tpch_conn, engine):
|
||
|
|
"""Filtered measure with CASE WHEN."""
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.returned_revenue"],
|
||
|
|
"dimensions": ["customer.c_name"],
|
||
|
|
"limit": 10,
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) > 0
|
||
|
|
assert len(rows) <= 10
|
||
|
|
|
||
|
|
def test_order_count(self, tpch_conn, engine):
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["orders.order_count"],
|
||
|
|
"dimensions": ["orders.o_orderstatus"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) > 0
|
||
|
|
# Sum of counts should equal total orders at SF=0.01
|
||
|
|
total = sum(r[1] for r in rows)
|
||
|
|
assert total == 15000 # SF=0.01 → 15000 orders
|
||
|
|
|
||
|
|
def test_supply_cost_by_nation(self, tpch_conn, engine):
|
||
|
|
"""Bridge table path: partsupp → supplier → nation."""
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["partsupp.total_supply_cost"],
|
||
|
|
"dimensions": ["nation.n_name"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) == 25 # 25 nations
|
||
|
|
|
||
|
|
def test_avg_order_value(self, tpch_conn, engine):
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["orders.avg_order_value"],
|
||
|
|
"dimensions": ["customer.c_mktsegment"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) == 5 # 5 market segments
|
||
|
|
# avg values should be positive
|
||
|
|
for row in rows:
|
||
|
|
assert row[1] > 0
|
||
|
|
|
||
|
|
def test_lineitem_charge(self, tpch_conn, engine):
|
||
|
|
"""Complex expression: sum(price * (1 - discount) * (1 + tax))."""
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.charge"],
|
||
|
|
"dimensions": ["lineitem.l_returnflag"],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) > 0
|
||
|
|
for row in rows:
|
||
|
|
assert row[1] > 0
|
||
|
|
|
||
|
|
def test_order_by_desc(self, tpch_conn, engine):
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": ["nation.n_name"],
|
||
|
|
"order_by": [{"field": "lineitem.revenue", "direction": "desc"}],
|
||
|
|
"limit": 5,
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) == 5
|
||
|
|
# Revenue should be descending
|
||
|
|
revenues = [r[1] for r in rows]
|
||
|
|
assert revenues == sorted(revenues, reverse=True)
|
||
|
|
|
||
|
|
def test_multiple_filters(self, tpch_conn, engine):
|
||
|
|
result = engine.query(
|
||
|
|
{
|
||
|
|
"measures": ["lineitem.revenue"],
|
||
|
|
"dimensions": ["orders.o_orderpriority"],
|
||
|
|
"filters": [
|
||
|
|
"customer.c_mktsegment = 'BUILDING'",
|
||
|
|
"nation.n_name = 'FRANCE'",
|
||
|
|
],
|
||
|
|
}
|
||
|
|
)
|
||
|
|
rows = tpch_conn.execute(result.sql).fetchall()
|
||
|
|
assert len(rows) > 0
|