mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-13 08:15:14 +02:00
Initial open-source release
This commit is contained in:
commit
1a42152e6f
1199 changed files with 257054 additions and 0 deletions
0
python/klo-sl/tests/__init__.py
Normal file
0
python/klo-sl/tests/__init__.py
Normal file
90
python/klo-sl/tests/conftest.py
Normal file
90
python/klo-sl/tests/conftest.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import sqlglot
|
||||
import yaml
|
||||
|
||||
from semantic_layer.engine import SemanticEngine
|
||||
from semantic_layer.loader import SourceLoader
|
||||
from semantic_layer.models import SourceDefinition
|
||||
|
||||
SOURCES_DIR = Path(__file__).parent.parent / "sources" / "ecommerce"
|
||||
TPCH_DIR = Path(__file__).parent.parent / "sources" / "tpch"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ecommerce_sources() -> dict[str, SourceDefinition]:
|
||||
loader = SourceLoader(SOURCES_DIR)
|
||||
return loader.load_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tpch_sources() -> dict[str, SourceDefinition]:
|
||||
loader = SourceLoader(TPCH_DIR)
|
||||
return loader.load_all()
|
||||
|
||||
|
||||
# ── Shared test helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def make_engine(
|
||||
sources_dict: dict[str, dict], dialect: str = "postgres"
|
||||
) -> SemanticEngine:
|
||||
"""Build a SemanticEngine from inline source dicts (writes temp YAML files)."""
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
for name, data in sources_dict.items():
|
||||
with open(Path(tmpdir) / f"{name}.yaml", "w") as f:
|
||||
yaml.dump(data, f)
|
||||
return SemanticEngine(tmpdir, dialect=dialect)
|
||||
|
||||
|
||||
def assert_valid_sql(sql: str):
|
||||
try:
|
||||
sqlglot.parse(sql)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Generated SQL is not valid: {e}\n\nSQL:\n{sql}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_bq_fct_orders_engine() -> SemanticEngine:
|
||||
"""BigQuery-dialect engine with fct_orders source mirroring the production YAML."""
|
||||
source = {
|
||||
"name": "fct_orders",
|
||||
"table": "analytics.fct_orders",
|
||||
"grain": ["order_id"],
|
||||
"columns": [
|
||||
{"name": "order_id", "type": "number"},
|
||||
{"name": "status", "type": "string"},
|
||||
{"name": "transaction_date", "type": "time"},
|
||||
],
|
||||
"segments": [
|
||||
{"name": "non_cancelled", "expr": "status != 'cancelled'"},
|
||||
{
|
||||
"name": "last_30_days",
|
||||
"expr": "transaction_date >= timestamp(date_sub(current_date(), interval 30 day))",
|
||||
},
|
||||
],
|
||||
"measures": [
|
||||
{
|
||||
"name": "daily_active_orders",
|
||||
"expr": "count(distinct order_id)",
|
||||
"segments": ["non_cancelled", "last_30_days"],
|
||||
},
|
||||
],
|
||||
}
|
||||
return make_engine({"fct_orders": source}, dialect="bigquery")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_engine_factory():
|
||||
"""Factory fixture: pass a sources-dict + dialect, get a SemanticEngine."""
|
||||
|
||||
def _make(
|
||||
sources_dict: dict[str, dict], dialect: str = "postgres"
|
||||
) -> SemanticEngine:
|
||||
return make_engine(sources_dict, dialect=dialect)
|
||||
|
||||
return _make
|
||||
1735
python/klo-sl/tests/test_aggregate_locality.py
Normal file
1735
python/klo-sl/tests/test_aggregate_locality.py
Normal file
File diff suppressed because it is too large
Load diff
447
python/klo-sl/tests/test_cli.py
Normal file
447
python/klo-sl/tests/test_cli.py
Normal file
|
|
@ -0,0 +1,447 @@
|
|||
"""Tests for the CLI interface (semantic_layer.cli)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_layer.cli import main, print_plan
|
||||
from semantic_layer.graph import JoinGraph
|
||||
from semantic_layer.models import (
|
||||
JoinDeclaration,
|
||||
SemanticQuery,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
from semantic_layer.planner import QueryPlanner
|
||||
|
||||
SOURCES_DIR = str(Path(__file__).parent.parent / "sources" / "ecommerce")
|
||||
|
||||
|
||||
# ── From test_edge_cases.py: TestCliParserArgs ───────────────────────
|
||||
|
||||
|
||||
class TestCliParserArgs:
|
||||
def test_no_args_errors(self):
|
||||
with pytest.raises(SystemExit):
|
||||
main([])
|
||||
|
||||
def test_sources_only_no_query(self, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
main(["--sources", SOURCES_DIR])
|
||||
|
||||
def test_list_sources_no_measures_needed(self, capsys):
|
||||
main(["--sources", SOURCES_DIR, "--list-sources"])
|
||||
output = capsys.readouterr().out
|
||||
assert "orders" in output
|
||||
assert "customers" in output
|
||||
|
||||
def test_plan_only_mode(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
),
|
||||
"--plan-only",
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "Resolved Plan" in output
|
||||
assert "Anchor" in output
|
||||
|
||||
def test_plan_and_sql(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
),
|
||||
"--plan",
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "Resolved Plan" in output
|
||||
assert "SELECT" in output
|
||||
|
||||
def test_compact_mode(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
),
|
||||
"--compact",
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "SELECT" in output
|
||||
assert "-- dialect:" not in output
|
||||
|
||||
def test_json_input(self, capsys):
|
||||
query_json = json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
)
|
||||
with patch("sys.stdin", StringIO(query_json)):
|
||||
main(["--sources", SOURCES_DIR, "--json"])
|
||||
output = capsys.readouterr().out
|
||||
assert "SELECT" in output
|
||||
|
||||
def test_json_input_with_filters(self, capsys):
|
||||
query_json = json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"filters": ["orders.status = 'completed'"],
|
||||
}
|
||||
)
|
||||
with patch("sys.stdin", StringIO(query_json)):
|
||||
main(["--sources", SOURCES_DIR, "--json"])
|
||||
output = capsys.readouterr().out
|
||||
assert "completed" in output
|
||||
|
||||
def test_json_input_with_order_by(self, capsys):
|
||||
query_json = json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"order_by": [{"field": "orders.status", "direction": "desc"}],
|
||||
}
|
||||
)
|
||||
with patch("sys.stdin", StringIO(query_json)):
|
||||
main(["--sources", SOURCES_DIR, "--json"])
|
||||
output = capsys.readouterr().out
|
||||
assert "SELECT" in output
|
||||
|
||||
def test_measures_with_alias(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": [
|
||||
{"expr": "sum(orders.amount)", "name": "total_rev"}
|
||||
],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "total_rev" in output
|
||||
|
||||
def test_dimension_with_granularity_cli(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": [
|
||||
{"field": "orders.created_at", "granularity": "month"}
|
||||
],
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "DATE_TRUNC" in output
|
||||
|
||||
def test_multiple_filters_cli(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"filters": [
|
||||
"orders.status = 'completed'",
|
||||
"orders.amount > 100",
|
||||
],
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "WHERE" in output
|
||||
|
||||
def test_limit_cli(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"limit": 50,
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "LIMIT 50" in output
|
||||
|
||||
def test_dialect_cli(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
),
|
||||
"--dialect",
|
||||
"bigquery",
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "bigquery" in output
|
||||
|
||||
|
||||
# ── From test_edge_cases.py: TestCLISuggest ──────────────────────────
|
||||
|
||||
|
||||
class TestCliSuggest:
|
||||
def test_suggest_valid_query(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
),
|
||||
"--suggest",
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "valid" in output.lower()
|
||||
|
||||
def test_suggest_invalid_query(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(nonexistent.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
),
|
||||
"--suggest",
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "failed" in output.lower() or "Suggestion" in output
|
||||
|
||||
|
||||
# ── From test_edge_cases.py: TestCLIOrderBy ──────────────────────────
|
||||
|
||||
|
||||
class TestCliOrderBy:
|
||||
def test_order_by_desc(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"order_by": [
|
||||
{"field": "sum(orders.amount)", "direction": "desc"}
|
||||
],
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "DESC" in output
|
||||
|
||||
def test_order_by_asc(self, capsys):
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
SOURCES_DIR,
|
||||
"-q",
|
||||
json.dumps(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"order_by": [{"field": "orders.status", "direction": "asc"}],
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
assert "ORDER BY" in output
|
||||
|
||||
|
||||
# ── From test_brainstorm_cases.py: TestBrainstormCliOutput ───────────
|
||||
|
||||
|
||||
def _build_chasm_sources() -> dict[str, SourceDefinition]:
|
||||
customers = SourceDefinition(
|
||||
name="customers",
|
||||
table="public.customers",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="segment", type="string"),
|
||||
],
|
||||
)
|
||||
orders = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
tickets = SourceDefinition(
|
||||
name="tickets",
|
||||
table="public.tickets",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="cost", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
return {"customers": customers, "orders": orders, "tickets": tickets}
|
||||
|
||||
|
||||
def _write_sources(sources_dict: dict[str, dict]) -> str:
|
||||
import tempfile
|
||||
import yaml
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
for name, data in sources_dict.items():
|
||||
with open(Path(tmpdir) / f"{name}.yaml", "w") as f:
|
||||
yaml.dump(data, f)
|
||||
return tmpdir
|
||||
|
||||
|
||||
class TestCliPlanOutput:
|
||||
def test_print_plan_includes_join_locality_where_and_having(self, capsys):
|
||||
sources = _build_chasm_sources()
|
||||
graph = JoinGraph(sources)
|
||||
graph.build()
|
||||
planner = QueryPlanner(sources, graph)
|
||||
|
||||
query = SemanticQuery(
|
||||
measures=["sum(orders.amount)", "count(tickets.id)"],
|
||||
dimensions=["customers.segment"],
|
||||
filters=["customers.segment = 'SMB'", "sum(orders.amount) > 10000"],
|
||||
)
|
||||
plan = planner.plan(query)
|
||||
|
||||
print_plan(plan)
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Resolved Plan" in output
|
||||
assert "Joins:" in output
|
||||
assert "Locality:" in output
|
||||
assert "WHERE:" in output
|
||||
assert "HAVING:" in output
|
||||
assert "customers.segment" in output
|
||||
|
||||
def test_suggest_cli_surfaces_graph_errors(self, capsys):
|
||||
tmpdir = _write_sources(
|
||||
{
|
||||
"a": {
|
||||
"name": "a",
|
||||
"table": "t",
|
||||
"grain": ["id"],
|
||||
"columns": [{"name": "id", "type": "number"}],
|
||||
},
|
||||
"b": {
|
||||
"name": "b",
|
||||
"table": "t2",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "val", "type": "number"},
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
main(
|
||||
[
|
||||
"--sources",
|
||||
tmpdir,
|
||||
"-q",
|
||||
json.dumps({"measures": ["sum(a.id)"], "dimensions": ["b.val"]}),
|
||||
"--suggest",
|
||||
]
|
||||
)
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Query failed:" in output
|
||||
assert "Graph error:" in output
|
||||
assert "Disconnected components" in output
|
||||
assert "Suggestion:" in output
|
||||
|
||||
def test_list_sources_includes_join_and_filtered_measure_details(self, capsys):
|
||||
main(["--sources", SOURCES_DIR, "--list-sources"])
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "joins:" in output
|
||||
assert "→ customers (many_to_one) on customer_id = customers.id" in output
|
||||
assert "revenue: sum(amount) (filter: status != 'refunded')" in output
|
||||
313
python/klo-sl/tests/test_computed_columns.py
Normal file
313
python/klo-sl/tests/test_computed_columns.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
"""Tests for computed column (expr) support on table sources."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from semantic_layer.models import SourceColumn
|
||||
|
||||
from .conftest import assert_valid_sql, make_engine
|
||||
|
||||
|
||||
def _lineitem_source(**overrides):
|
||||
base = {
|
||||
"name": "lineitem",
|
||||
"table": "public.lineitem",
|
||||
"grain": ["l_orderkey", "l_linenumber"],
|
||||
"columns": [
|
||||
{"name": "l_orderkey", "type": "number"},
|
||||
{"name": "l_linenumber", "type": "number"},
|
||||
{"name": "l_extendedprice", "type": "number"},
|
||||
{"name": "l_discount", "type": "number"},
|
||||
{"name": "l_quantity", "type": "number"},
|
||||
{"name": "l_returnflag", "type": "string"},
|
||||
{
|
||||
"name": "net_price",
|
||||
"type": "number",
|
||||
"expr": "l_extendedprice * (1 - l_discount)",
|
||||
},
|
||||
],
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
class TestComputedColumnDimension:
|
||||
def test_computed_column_in_select_and_group_by(self):
|
||||
engine = make_engine({"lineitem": _lineitem_source()})
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(lineitem.l_quantity)"],
|
||||
"dimensions": ["lineitem.net_price"],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
assert "l_extendedprice" in result.sql
|
||||
assert "l_discount" in result.sql
|
||||
assert "AS net_price" in result.sql
|
||||
|
||||
def test_date_trunc_on_computed_column(self):
|
||||
engine = make_engine(
|
||||
{
|
||||
"events": {
|
||||
"name": "events",
|
||||
"table": "public.events",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "created_at", "type": "time", "role": "time"},
|
||||
{"name": "offset_hours", "type": "number"},
|
||||
{
|
||||
"name": "local_time",
|
||||
"type": "time",
|
||||
"role": "time",
|
||||
"expr": "created_at + offset_hours * INTERVAL '1 hour'",
|
||||
},
|
||||
{"name": "value", "type": "number"},
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(events.value)"],
|
||||
"dimensions": [{"field": "events.local_time", "granularity": "month"}],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
assert "DATE_TRUNC" in result.sql
|
||||
assert "created_at" in result.sql
|
||||
assert "offset_hours" in result.sql
|
||||
|
||||
|
||||
class TestComputedColumnInMeasure:
|
||||
def test_runtime_aggregate_on_computed_column(self):
|
||||
engine = make_engine({"lineitem": _lineitem_source()})
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(lineitem.net_price)"],
|
||||
"dimensions": [],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
assert "SUM" in result.sql.upper()
|
||||
assert "l_extendedprice" in result.sql
|
||||
assert "l_discount" in result.sql
|
||||
|
||||
def test_predefined_measure_referencing_computed_column(self):
|
||||
source = _lineitem_source(
|
||||
measures=[
|
||||
{"name": "total_net", "expr": "sum(net_price)"},
|
||||
]
|
||||
)
|
||||
engine = make_engine({"lineitem": source})
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["lineitem.total_net"],
|
||||
"dimensions": [],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
assert "l_extendedprice" in result.sql
|
||||
assert "l_discount" in result.sql
|
||||
|
||||
|
||||
class TestComputedColumnInFilter:
|
||||
def test_computed_column_in_where_filter(self):
|
||||
engine = make_engine({"lineitem": _lineitem_source()})
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(lineitem.l_extendedprice)"],
|
||||
"dimensions": [],
|
||||
"filters": ["lineitem.net_price > 100"],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
assert "WHERE" in result.sql
|
||||
assert "l_extendedprice" in result.sql
|
||||
assert "l_discount" in result.sql
|
||||
|
||||
|
||||
class TestComputedColumnWithJoins:
|
||||
def test_join_on_uses_physical_columns(self):
|
||||
engine = make_engine(
|
||||
{
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "customer_id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
{"name": "discount", "type": "number"},
|
||||
{
|
||||
"name": "net_amount",
|
||||
"type": "number",
|
||||
"expr": "amount * (1 - discount)",
|
||||
},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
},
|
||||
"customers": {
|
||||
"name": "customers",
|
||||
"table": "public.customers",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "segment", "type": "string"},
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.net_amount)"],
|
||||
"dimensions": ["customers.segment"],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
# JOIN ON should use physical columns
|
||||
assert "orders.customer_id" in result.sql
|
||||
assert "customers.id" in result.sql
|
||||
# Measure should be expanded
|
||||
assert "orders.amount" in result.sql
|
||||
assert "orders.discount" in result.sql
|
||||
|
||||
|
||||
class TestComputedColumnLocality:
|
||||
def test_computed_column_in_aggregate_locality(self):
|
||||
engine = make_engine(
|
||||
{
|
||||
"hub": {
|
||||
"name": "hub",
|
||||
"table": "public.hub",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "segment", "type": "string"},
|
||||
],
|
||||
},
|
||||
"fact_a": {
|
||||
"name": "fact_a",
|
||||
"table": "public.fact_a",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "hub_id", "type": "number"},
|
||||
{"name": "price", "type": "number"},
|
||||
{"name": "qty", "type": "number"},
|
||||
{
|
||||
"name": "total",
|
||||
"type": "number",
|
||||
"expr": "price * qty",
|
||||
},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "hub",
|
||||
"on": "hub_id = hub.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
},
|
||||
"fact_b": {
|
||||
"name": "fact_b",
|
||||
"table": "public.fact_b",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "hub_id", "type": "number"},
|
||||
{"name": "val", "type": "number"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "hub",
|
||||
"on": "hub_id = hub.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(fact_a.total)", "sum(fact_b.val)"],
|
||||
"dimensions": ["hub.segment"],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
assert "_agg" in result.sql
|
||||
assert "fact_a.price" in result.sql
|
||||
assert "fact_a.qty" in result.sql
|
||||
|
||||
|
||||
class TestComputedColumnModel:
|
||||
def test_source_column_with_expr(self):
|
||||
col = SourceColumn(
|
||||
name="net_price", type="number", expr="price * (1 - discount)"
|
||||
)
|
||||
assert col.expr == "price * (1 - discount)"
|
||||
|
||||
def test_source_column_without_expr(self):
|
||||
col = SourceColumn(name="price", type="number")
|
||||
assert col.expr is None
|
||||
|
||||
def test_source_column_expr_in_yaml_roundtrip(self):
|
||||
engine = make_engine(
|
||||
{
|
||||
"t": {
|
||||
"name": "t",
|
||||
"table": "public.t",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "a", "type": "number"},
|
||||
{"name": "b", "type": "number"},
|
||||
{
|
||||
"name": "c",
|
||||
"type": "number",
|
||||
"expr": "a + b",
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
src = engine.sources["t"]
|
||||
c_col = next(c for c in src.columns if c.name == "c")
|
||||
assert c_col.expr == "a + b"
|
||||
|
||||
|
||||
def test_bigquery_computed_column_with_timestamp_add(make_engine_factory):
|
||||
"""Computed column authored with BigQuery-native TIMESTAMP_ADD must survive."""
|
||||
source = {
|
||||
"name": "events",
|
||||
"table": "events",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "event_at", "type": "time"},
|
||||
{"name": "tz_offset", "type": "number"},
|
||||
{
|
||||
"name": "local_hour",
|
||||
"type": "time",
|
||||
"expr": "TIMESTAMP_ADD(event_at, INTERVAL tz_offset HOUR)",
|
||||
},
|
||||
],
|
||||
"measures": [{"name": "cnt", "expr": "count(*)"}],
|
||||
}
|
||||
engine = make_engine_factory({"events": source}, dialect="bigquery")
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["events.cnt"],
|
||||
"dimensions": ["events.local_hour"],
|
||||
"filters": [],
|
||||
}
|
||||
)
|
||||
assert "TIMESTAMP_ADD" in result.sql.upper()
|
||||
assert "HOUR" in result.sql.upper()
|
||||
288
python/klo-sl/tests/test_corner_case_regressions.py
Normal file
288
python/klo-sl/tests/test_corner_case_regressions.py
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from conftest import assert_valid_sql, make_engine
|
||||
|
||||
|
||||
def _duplicate_predefined_sources() -> dict[str, dict]:
|
||||
return {
|
||||
"customers": {
|
||||
"name": "customers",
|
||||
"table": "public.customers",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "segment", "type": "string"},
|
||||
],
|
||||
},
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "customer_id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
"measures": [{"name": "revenue", "expr": "sum(amount)"}],
|
||||
},
|
||||
"refunds": {
|
||||
"name": "refunds",
|
||||
"table": "public.refunds",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "customer_id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
"measures": [{"name": "revenue", "expr": "sum(amount)"}],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _include_empty_sources() -> dict[str, dict]:
|
||||
return {
|
||||
"customers": {
|
||||
"name": "customers",
|
||||
"table": "public.customers",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "segment", "type": "string"},
|
||||
],
|
||||
},
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "customer_id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _alias_measure_sources() -> dict[str, dict]:
|
||||
return {
|
||||
"customers": {
|
||||
"name": "customers",
|
||||
"table": "public.customers",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "lifetime_value", "type": "number"},
|
||||
],
|
||||
"measures": [{"name": "total_ltv", "expr": "sum(lifetime_value)"}],
|
||||
},
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "billing_customer_id", "type": "number"},
|
||||
{"name": "status", "type": "string"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "billing_customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
"alias": "billing_customer",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_duplicate_predefined_names_stay_distinct_in_derived_measure():
|
||||
engine = make_engine(_duplicate_predefined_sources())
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": [
|
||||
"orders.revenue",
|
||||
"refunds.revenue",
|
||||
{"expr": "orders.revenue - refunds.revenue", "name": "net"},
|
||||
],
|
||||
"dimensions": ["customers.segment"],
|
||||
}
|
||||
)
|
||||
|
||||
assert result.resolved_plan.has_fan_out
|
||||
assert "orders_agg.orders_revenue" in result.sql
|
||||
assert "refunds_agg.refunds_revenue" in result.sql
|
||||
assert "revenue - revenue" not in result.sql
|
||||
assert_valid_sql(result.sql)
|
||||
|
||||
|
||||
def test_duplicate_predefined_names_expand_having_filters_in_locality_mode():
|
||||
engine = make_engine(_duplicate_predefined_sources())
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["orders.revenue", "refunds.revenue"],
|
||||
"dimensions": ["customers.segment"],
|
||||
"filters": ["orders.revenue > 100"],
|
||||
}
|
||||
)
|
||||
|
||||
# In multi-CTE mode, HAVING refs are wrapped in COALESCE for FULL JOIN NULL safety
|
||||
assert "WHERE COALESCE(orders_agg.orders_revenue, 0) > 100" in result.sql
|
||||
assert_valid_sql(result.sql)
|
||||
|
||||
|
||||
def test_include_empty_anchors_the_dimension_side():
|
||||
engine = make_engine(_include_empty_sources())
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["customers.segment"],
|
||||
"include_empty": True,
|
||||
}
|
||||
)
|
||||
|
||||
assert result.resolved_plan.anchor_source == "customers"
|
||||
assert "FROM public.customers AS customers" in result.sql
|
||||
assert "LEFT JOIN public.orders AS orders" in result.sql
|
||||
assert_valid_sql(result.sql)
|
||||
|
||||
|
||||
def test_cross_grain_measures_on_same_chain_use_aggregate_locality():
|
||||
engine = make_engine(
|
||||
{
|
||||
"customers": {
|
||||
"name": "customers",
|
||||
"table": "public.customers",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "segment", "type": "string"},
|
||||
{"name": "credit_limit", "type": "number"},
|
||||
],
|
||||
},
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "customer_id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)", "sum(customers.credit_limit)"],
|
||||
"dimensions": ["customers.segment"],
|
||||
}
|
||||
)
|
||||
|
||||
assert result.resolved_plan.has_fan_out
|
||||
assert "orders_agg" in result.sql
|
||||
assert "customers_agg" in result.sql
|
||||
assert_valid_sql(result.sql)
|
||||
|
||||
|
||||
def test_filtered_count_distinct_keeps_distinct_inside_count():
|
||||
engine = make_engine(
|
||||
{
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "customer_id", "type": "number"},
|
||||
{"name": "status", "type": "string"},
|
||||
],
|
||||
"measures": [
|
||||
{
|
||||
"name": "paid_customers",
|
||||
"expr": "count_distinct(customer_id)",
|
||||
"filter": "status = 'paid'",
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
result = engine.query(
|
||||
{"measures": ["orders.paid_customers"], "dimensions": ["orders.status"]}
|
||||
)
|
||||
|
||||
assert "COUNT(DISTINCT CASE WHEN orders.status = 'paid'" in result.sql
|
||||
assert_valid_sql(result.sql)
|
||||
|
||||
|
||||
def test_predefined_measure_via_alias_uses_real_table_and_alias_qualification():
|
||||
engine = make_engine(_alias_measure_sources())
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["billing_customer.total_ltv"],
|
||||
"dimensions": ["billing_customer.id"],
|
||||
}
|
||||
)
|
||||
|
||||
assert "FROM public.customers AS billing_customer" in result.sql
|
||||
assert "SUM(billing_customer.lifetime_value)" in result.sql
|
||||
assert_valid_sql(result.sql)
|
||||
|
||||
|
||||
def test_runtime_case_measure_gets_a_safe_auto_alias():
|
||||
engine = make_engine(
|
||||
{
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
{"name": "status", "type": "string"},
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": [
|
||||
"sum(CASE WHEN orders.status = 'paid' THEN orders.amount ELSE 0 END)"
|
||||
],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
)
|
||||
|
||||
assert (
|
||||
"sum_case_when_orders_status_paid_then_orders_amount_else_0_end" in result.sql
|
||||
)
|
||||
assert "=" not in result.resolved_plan.measures[0].name
|
||||
assert_valid_sql(result.sql)
|
||||
740
python/klo-sl/tests/test_coverage_gaps.py
Normal file
740
python/klo-sl/tests/test_coverage_gaps.py
Normal file
|
|
@ -0,0 +1,740 @@
|
|||
"""Tests targeting specific coverage gaps in planner.py, generator.py, models.py, engine.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from semantic_layer.generator import SqlGenerator
|
||||
from semantic_layer.graph import JoinGraph
|
||||
from semantic_layer.models import (
|
||||
JoinDeclaration,
|
||||
MeasureDefinition,
|
||||
SemanticQuery,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
from semantic_layer.planner import QueryPlanner
|
||||
|
||||
from conftest import assert_valid_sql, make_engine
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_planner(sources: dict[str, SourceDefinition]) -> QueryPlanner:
|
||||
graph = JoinGraph(sources)
|
||||
graph.build()
|
||||
return QueryPlanner(sources, graph)
|
||||
|
||||
|
||||
def _plan_and_generate(sources: dict[str, SourceDefinition], query_dict: dict) -> str:
|
||||
planner = _make_planner(sources)
|
||||
generator = SqlGenerator(dialect="postgres")
|
||||
query = SemanticQuery(**query_dict)
|
||||
plan = planner.plan(query)
|
||||
sql = generator.generate(plan, sources)
|
||||
assert_valid_sql(sql)
|
||||
return sql
|
||||
|
||||
|
||||
# ── Source fixtures ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _simple_sources() -> dict[str, SourceDefinition]:
|
||||
"""orders -> customers (m2o)."""
|
||||
customers = SourceDefinition(
|
||||
name="customers",
|
||||
table="public.customers",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="segment", type="string"),
|
||||
],
|
||||
)
|
||||
orders = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
SourceColumn(name="status", type="string"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
measures=[
|
||||
MeasureDefinition(
|
||||
name="revenue", expr="sum(amount)", filter="status != 'refunded'"
|
||||
),
|
||||
MeasureDefinition(name="order_count", expr="count(id)"),
|
||||
],
|
||||
)
|
||||
return {"customers": customers, "orders": orders}
|
||||
|
||||
|
||||
def _chasm_sources() -> dict[str, SourceDefinition]:
|
||||
"""Two fact tables (orders, tickets) -> hub (customers). Classic chasm trap."""
|
||||
customers = SourceDefinition(
|
||||
name="customers",
|
||||
table="public.customers",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="segment", type="string"),
|
||||
],
|
||||
)
|
||||
orders = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
measures=[MeasureDefinition(name="revenue", expr="sum(amount)")],
|
||||
)
|
||||
tickets = SourceDefinition(
|
||||
name="tickets",
|
||||
table="public.tickets",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="priority", type="string"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
measures=[MeasureDefinition(name="ticket_count", expr="count(id)")],
|
||||
)
|
||||
return {"customers": customers, "orders": orders, "tickets": tickets}
|
||||
|
||||
|
||||
def _chain_sources_with_derived() -> dict[str, SourceDefinition]:
|
||||
"""orders -> customers -> tiers (m2o chain) with derived measures."""
|
||||
tiers = SourceDefinition(
|
||||
name="tiers",
|
||||
table="public.tiers",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="level", type="string"),
|
||||
],
|
||||
)
|
||||
customers = SourceDefinition(
|
||||
name="customers",
|
||||
table="public.customers",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="tier_id", type="number"),
|
||||
SourceColumn(name="segment", type="string"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="tiers", on="tier_id = tiers.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
)
|
||||
orders = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
SourceColumn(name="status", type="string"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
measures=[
|
||||
MeasureDefinition(
|
||||
name="revenue", expr="sum(amount)", filter="status != 'refunded'"
|
||||
),
|
||||
MeasureDefinition(name="order_count", expr="count(id)"),
|
||||
MeasureDefinition(name="avg_order", expr="revenue / order_count"),
|
||||
],
|
||||
)
|
||||
return {"tiers": tiers, "customers": customers, "orders": orders}
|
||||
|
||||
|
||||
# ── Planner: nested aggregation (lines 432-440) ─────────────────────
|
||||
|
||||
|
||||
class TestNestedAggregation:
|
||||
def test_nested_aggregation_raises(self):
|
||||
"""avg(sum(orders.amount)) should be rejected."""
|
||||
sources = _simple_sources()
|
||||
planner = _make_planner(sources)
|
||||
with pytest.raises(ValueError, match="Nested aggregation is not supported"):
|
||||
planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["avg(sum(orders.amount))"],
|
||||
dimensions=["orders.status"],
|
||||
)
|
||||
)
|
||||
|
||||
def test_nested_max_count_raises(self):
|
||||
"""max(count(orders.id)) should be rejected."""
|
||||
sources = _simple_sources()
|
||||
planner = _make_planner(sources)
|
||||
with pytest.raises(ValueError, match="Nested aggregation is not supported"):
|
||||
planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["max(count(orders.id))"],
|
||||
dimensions=["orders.status"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Planner: OR filter mixing (lines 810-833) ───────────────────────
|
||||
|
||||
|
||||
class TestOrFilterMixing:
|
||||
def test_or_mixing_agg_and_nonagg_raises(self):
|
||||
"""OR that mixes aggregate and non-aggregate conditions should raise."""
|
||||
sources = _simple_sources()
|
||||
planner = _make_planner(sources)
|
||||
with pytest.raises(ValueError, match="mixes aggregate and non-aggregate"):
|
||||
planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(orders.amount)"],
|
||||
dimensions=["orders.status"],
|
||||
filters=["orders.amount > 100 OR sum(orders.amount) > 5000"],
|
||||
)
|
||||
)
|
||||
|
||||
def test_or_pure_where_ok(self):
|
||||
"""OR with all non-aggregate conditions should be fine."""
|
||||
sources = _simple_sources()
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"filters": ["orders.amount > 100 OR orders.amount < 10"],
|
||||
},
|
||||
)
|
||||
assert "OR" in sql.upper()
|
||||
|
||||
def test_or_pure_having_ok(self):
|
||||
"""OR with all aggregate conditions should be fine."""
|
||||
sources = _simple_sources()
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"filters": ["sum(orders.amount) > 1000 OR count(orders.id) > 5"],
|
||||
},
|
||||
)
|
||||
assert "HAVING" in sql.upper()
|
||||
|
||||
|
||||
# ── Planner: empty source refs (line 62) ─────────────────────────────
|
||||
|
||||
|
||||
class TestEmptySourceRef:
|
||||
def test_no_source_refs_raises(self):
|
||||
"""Query that references no sources should raise."""
|
||||
sources = _simple_sources()
|
||||
planner = _make_planner(sources)
|
||||
with pytest.raises(ValueError, match="does not reference any source"):
|
||||
planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(1)"],
|
||||
dimensions=[],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Planner: predefined measure dependency chains (lines 189-194, 237, 281-282) ──
|
||||
|
||||
|
||||
class TestPredefinedMeasureDeps:
|
||||
def test_derived_measure_resolves_dependencies(self):
|
||||
"""avg_order depends on revenue and order_count — both should appear in plan."""
|
||||
sources = _chain_sources_with_derived()
|
||||
planner = _make_planner(sources)
|
||||
plan = planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["orders.avg_order"],
|
||||
dimensions=["orders.status"],
|
||||
)
|
||||
)
|
||||
measure_names = {m.name for m in plan.measures}
|
||||
assert "avg_order" in measure_names
|
||||
assert "revenue" in measure_names
|
||||
assert "order_count" in measure_names
|
||||
|
||||
def test_derived_measure_generates_valid_sql(self):
|
||||
"""Derived measures should produce valid SQL."""
|
||||
sources = _chain_sources_with_derived()
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["orders.avg_order"],
|
||||
"dimensions": ["customers.segment"],
|
||||
},
|
||||
)
|
||||
assert "GROUP BY" in sql.upper()
|
||||
|
||||
|
||||
# ── Planner: fan-out with one_to_many to dimension sources (lines 595-643) ──
|
||||
|
||||
|
||||
class TestFanOutEdgeCases:
|
||||
def test_single_source_fan_out_to_dimension(self):
|
||||
"""Measure source with one_to_many to dimension should trigger fan-out."""
|
||||
hub = SourceDefinition(
|
||||
name="hub",
|
||||
table="public.hub",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="name", type="string"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="detail", on="id = detail.hub_id", relationship="one_to_many"
|
||||
)
|
||||
],
|
||||
)
|
||||
detail = SourceDefinition(
|
||||
name="detail",
|
||||
table="public.detail",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="hub_id", type="number"),
|
||||
SourceColumn(name="category", type="string"),
|
||||
],
|
||||
)
|
||||
sources = {"hub": hub, "detail": detail}
|
||||
planner = _make_planner(sources)
|
||||
plan = planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(hub.id)"],
|
||||
dimensions=["detail.category"],
|
||||
)
|
||||
)
|
||||
assert plan.has_fan_out
|
||||
|
||||
def test_merged_groups_fan_out_to_dimension(self):
|
||||
"""Two measure sources on the same m2o chain, but with o2m to dimension source."""
|
||||
dim = SourceDefinition(
|
||||
name="dim",
|
||||
table="public.dim",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="label", type="string"),
|
||||
],
|
||||
)
|
||||
parent = SourceDefinition(
|
||||
name="parent",
|
||||
table="public.parent",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="val", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(to="dim", on="id = dim.id", relationship="one_to_many")
|
||||
],
|
||||
)
|
||||
child = SourceDefinition(
|
||||
name="child",
|
||||
table="public.child",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="parent_id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="parent", on="parent_id = parent.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
)
|
||||
sources = {"dim": dim, "parent": parent, "child": child}
|
||||
planner = _make_planner(sources)
|
||||
plan = planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(child.amount)"],
|
||||
dimensions=["dim.label"],
|
||||
)
|
||||
)
|
||||
assert plan.has_fan_out
|
||||
|
||||
def test_filter_fan_out_one_to_many_raises(self):
|
||||
"""Filter on source reachable only via one_to_many from measure source should raise."""
|
||||
parent = SourceDefinition(
|
||||
name="parent",
|
||||
table="public.parent",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="val", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="child", on="id = child.parent_id", relationship="one_to_many"
|
||||
)
|
||||
],
|
||||
)
|
||||
child = SourceDefinition(
|
||||
name="child",
|
||||
table="public.child",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="parent_id", type="number"),
|
||||
SourceColumn(name="category", type="string"),
|
||||
],
|
||||
)
|
||||
sources = {"parent": parent, "child": child}
|
||||
planner = _make_planner(sources)
|
||||
with pytest.raises(ValueError, match="one_to_many join"):
|
||||
planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(parent.val)"],
|
||||
dimensions=[],
|
||||
filters=["child.category = 'A'"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Generator: NULL dimension in multi-CTE (lines 385-388) ──────────
|
||||
|
||||
|
||||
class TestNullDimensionInCTE:
|
||||
def test_dimension_not_in_any_cte_gets_null(self):
|
||||
"""When a dimension is from a source not reachable by any CTE, generate NULL."""
|
||||
# Use a 3-fact chasm topology where one dimension is only reachable by one fact
|
||||
hub = SourceDefinition(
|
||||
name="hub",
|
||||
table="public.hub",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="name", type="string"),
|
||||
],
|
||||
)
|
||||
fact_a = SourceDefinition(
|
||||
name="fact_a",
|
||||
table="public.fact_a",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="hub_id", type="number"),
|
||||
SourceColumn(name="val", type="number"),
|
||||
SourceColumn(name="extra", type="string"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="hub", on="hub_id = hub.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
)
|
||||
fact_b = SourceDefinition(
|
||||
name="fact_b",
|
||||
table="public.fact_b",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="hub_id", type="number"),
|
||||
SourceColumn(name="val", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="hub", on="hub_id = hub.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
)
|
||||
sources = {"hub": hub, "fact_a": fact_a, "fact_b": fact_b}
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["sum(fact_a.val)", "sum(fact_b.val)"],
|
||||
"dimensions": ["hub.name"],
|
||||
},
|
||||
)
|
||||
# Should produce aggregate locality CTEs with FULL JOIN
|
||||
assert "FULL" in sql.upper() or "WITH" in sql.upper()
|
||||
|
||||
|
||||
# ── Generator: CTE alias collision (lines 202-206) ──────────────────
|
||||
|
||||
|
||||
class TestCTEAliasCollision:
|
||||
def test_alias_collision_resolved(self):
|
||||
"""When a source name matches a potential CTE alias, suffix should be used."""
|
||||
# Create a source named "orders_agg" to collide with the CTE alias
|
||||
orders_agg = SourceDefinition(
|
||||
name="orders_agg",
|
||||
table="public.orders_agg",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="segment", type="string"),
|
||||
],
|
||||
)
|
||||
orders = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="orders_agg",
|
||||
on="customer_id = orders_agg.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
tickets = SourceDefinition(
|
||||
name="tickets",
|
||||
table="public.tickets",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="customer_id", type="number"),
|
||||
SourceColumn(name="priority", type="string"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="orders_agg",
|
||||
on="customer_id = orders_agg.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
sources = {"orders_agg": orders_agg, "orders": orders, "tickets": tickets}
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["sum(orders.amount)", "count(tickets.id)"],
|
||||
"dimensions": ["orders_agg.segment"],
|
||||
},
|
||||
)
|
||||
# Should still produce valid SQL even with the collision
|
||||
assert_valid_sql(sql)
|
||||
|
||||
|
||||
# ── Models: negative limit (line 95) ────────────────────────────────
|
||||
|
||||
|
||||
class TestNegativeLimit:
|
||||
def test_negative_limit_raises(self):
|
||||
with pytest.raises(ValidationError, match="limit"):
|
||||
SemanticQuery(
|
||||
measures=["sum(orders.amount)"],
|
||||
limit=-1,
|
||||
)
|
||||
|
||||
def test_zero_limit_allowed(self):
|
||||
q = SemanticQuery(measures=["sum(orders.amount)"], limit=0)
|
||||
assert q.limit == 0
|
||||
|
||||
|
||||
# ── Engine: suggest with missing sources (lines 100-106, 127) ────────
|
||||
|
||||
|
||||
class TestEngineSuggest:
|
||||
def test_suggest_with_missing_source(self):
|
||||
"""Suggest should return suggestions for missing sources."""
|
||||
engine = make_engine(
|
||||
{
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
result = engine.suggest(
|
||||
{
|
||||
"measures": ["sum(unknown_source.val)"],
|
||||
"dimensions": ["orders.id"],
|
||||
}
|
||||
)
|
||||
assert not result["success"]
|
||||
assert any(
|
||||
"missing" in s["description"].lower()
|
||||
or "unknown_source" in s["description"]
|
||||
for s in result.get("suggestions", [])
|
||||
)
|
||||
|
||||
def test_suggest_with_dict_measure_and_dimension(self):
|
||||
"""Suggest handles dict-format measures and dimensions in failure path."""
|
||||
engine = make_engine(
|
||||
{
|
||||
"orders": {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
],
|
||||
},
|
||||
}
|
||||
)
|
||||
# Use a nested aggregate to trigger a planner error that hits the dict-handling code
|
||||
result = engine.suggest(
|
||||
{
|
||||
"measures": [{"expr": "avg(sum(missing.val))", "name": "total"}],
|
||||
"dimensions": [{"field": "missing.category"}],
|
||||
}
|
||||
)
|
||||
assert not result["success"]
|
||||
|
||||
|
||||
# ── Planner: order_by resolution formats (lines 113-116) ────────────
|
||||
|
||||
|
||||
class TestOrderByResolution:
|
||||
def test_order_by_as_dict(self):
|
||||
sources = _simple_sources()
|
||||
planner = _make_planner(sources)
|
||||
plan = planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(orders.amount)"],
|
||||
dimensions=["orders.status"],
|
||||
order_by=[{"field": "orders.status", "direction": "desc"}],
|
||||
)
|
||||
)
|
||||
assert len(plan.order_by) == 1
|
||||
assert plan.order_by[0].direction == "desc"
|
||||
|
||||
def test_order_by_as_string(self):
|
||||
sources = _simple_sources()
|
||||
planner = _make_planner(sources)
|
||||
plan = planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(orders.amount)"],
|
||||
dimensions=["orders.status"],
|
||||
order_by=["orders.status"],
|
||||
)
|
||||
)
|
||||
assert len(plan.order_by) == 1
|
||||
|
||||
|
||||
# ── Planner: measure with no source refs (line 343) ─────────────────
|
||||
|
||||
|
||||
class TestMeasureNoSourceRef:
|
||||
def test_bare_column_no_aggregate_raises(self):
|
||||
"""A measure like 'orders.nonexistent' that isn't predefined should raise."""
|
||||
sources = _simple_sources()
|
||||
planner = _make_planner(sources)
|
||||
with pytest.raises(
|
||||
ValueError, match="does not reference any source|not a pre-defined measure"
|
||||
):
|
||||
planner.plan(
|
||||
SemanticQuery(
|
||||
measures=["sum(1)"],
|
||||
dimensions=["orders.status"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Generator: custom aggregate parsing (lines 614-617) ─────────────
|
||||
|
||||
|
||||
class TestCustomAggregates:
|
||||
def test_count_distinct_generates_valid_sql(self):
|
||||
sources = _simple_sources()
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["count(distinct orders.id)"],
|
||||
"dimensions": ["orders.status"],
|
||||
},
|
||||
)
|
||||
upper = sql.upper()
|
||||
assert "COUNT(DISTINCT" in upper or "COUNT (DISTINCT" in upper
|
||||
|
||||
|
||||
# ── Generator: qualified predefined expressions via multi-hop joins (lines 925-931) ──
|
||||
|
||||
|
||||
class TestQualifiedPredefinedExpr:
|
||||
def test_predefined_filter_with_joined_column(self):
|
||||
"""Predefined measure with a filter referencing a column from a joined table."""
|
||||
sources = _chain_sources_with_derived()
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["orders.revenue"],
|
||||
"dimensions": ["tiers.level"],
|
||||
},
|
||||
)
|
||||
assert_valid_sql(sql)
|
||||
assert "CASE WHEN" in sql.upper()
|
||||
|
||||
|
||||
# ── End-to-end: chasm trap with aggregate locality ───────────────────
|
||||
|
||||
|
||||
class TestChasmTrapEndToEnd:
|
||||
def test_two_fact_tables_produce_valid_sql(self):
|
||||
sources = _chasm_sources()
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["sum(orders.amount)", "count(tickets.id)"],
|
||||
"dimensions": ["customers.segment"],
|
||||
},
|
||||
)
|
||||
upper = sql.upper()
|
||||
assert "WITH" in upper
|
||||
assert "FULL" in upper or "JOIN" in upper
|
||||
|
||||
def test_chasm_with_filter_on_hub(self):
|
||||
sources = _chasm_sources()
|
||||
sql = _plan_and_generate(
|
||||
sources,
|
||||
{
|
||||
"measures": ["sum(orders.amount)", "count(tickets.id)"],
|
||||
"dimensions": ["customers.segment"],
|
||||
"filters": ["customers.segment = 'enterprise'"],
|
||||
},
|
||||
)
|
||||
assert "enterprise" in sql
|
||||
assert_valid_sql(sql)
|
||||
220
python/klo-sl/tests/test_duplicate_check.py
Normal file
220
python/klo-sl/tests/test_duplicate_check.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Tests for semantic_layer.duplicate_check.validate_measure_duplicates."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from semantic_layer.duplicate_check import validate_measure_duplicates
|
||||
from semantic_layer.models import (
|
||||
MeasureDefinition,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
|
||||
|
||||
def _make_source(name: str, measures: list[MeasureDefinition]) -> SourceDefinition:
|
||||
return SourceDefinition(
|
||||
name=name,
|
||||
table=f"public.{name}",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
measures=measures,
|
||||
)
|
||||
|
||||
|
||||
def test_same_expr_different_filter_is_flagged() -> None:
|
||||
"""The replay-trimmed case: count(*) twice, one with is_active filter."""
|
||||
source = _make_source(
|
||||
"fct_subscriptions",
|
||||
[
|
||||
MeasureDefinition(
|
||||
name="active_subscription_count",
|
||||
expr="count(*)",
|
||||
filter="is_active = true",
|
||||
),
|
||||
MeasureDefinition(
|
||||
name="new_subscription_count",
|
||||
expr="count(*)",
|
||||
),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_subscriptions": source})
|
||||
assert len(errors) == 1
|
||||
assert "new_subscription_count" in errors[0]
|
||||
assert "active_subscription_count" in errors[0]
|
||||
assert "differs only by `filter`" in errors[0]
|
||||
|
||||
|
||||
def test_same_expr_same_filter_is_flagged() -> None:
|
||||
"""Two measures with identical expr and filter — flagged as duplicate pair."""
|
||||
source = _make_source(
|
||||
"fct_orders",
|
||||
[
|
||||
MeasureDefinition(
|
||||
name="order_count_a", expr="count(*)", filter="is_paid = true"
|
||||
),
|
||||
MeasureDefinition(
|
||||
name="order_count_b", expr="count(*)", filter="is_paid = true"
|
||||
),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_orders": source})
|
||||
assert len(errors) == 1
|
||||
assert "same expression and filter" in errors[0]
|
||||
|
||||
|
||||
def test_different_expr_is_not_flagged() -> None:
|
||||
"""count(*) vs sum(amount) on same source — legitimately distinct measures."""
|
||||
source = _make_source(
|
||||
"fct_orders",
|
||||
[
|
||||
MeasureDefinition(name="order_count", expr="count(*)"),
|
||||
MeasureDefinition(name="total_revenue", expr="sum(amount)"),
|
||||
MeasureDefinition(name="avg_revenue", expr="avg(amount)"),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_orders": source})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_measures_on_different_sources_not_compared() -> None:
|
||||
"""Same expr on two different sources is not a duplicate."""
|
||||
a = _make_source("fct_a", [MeasureDefinition(name="total", expr="count(*)")])
|
||||
b = _make_source("fct_b", [MeasureDefinition(name="total", expr="count(*)")])
|
||||
errors = validate_measure_duplicates({"fct_a": a, "fct_b": b})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_whitespace_and_case_are_normalized() -> None:
|
||||
"""COUNT(*) and count(*) and count( * ) all compare equal."""
|
||||
source = _make_source(
|
||||
"fct_orders",
|
||||
[
|
||||
MeasureDefinition(name="a", expr="count(*)"),
|
||||
MeasureDefinition(name="b", expr="COUNT(*)"),
|
||||
MeasureDefinition(name="c", expr=" count( * ) "),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_orders": source})
|
||||
# Three measures pairwise — should yield 3 errors (a vs b, a vs c, b vs c)
|
||||
assert len(errors) == 3
|
||||
|
||||
|
||||
def test_unparseable_expr_is_skipped_not_errored() -> None:
|
||||
"""A measure whose expr can't be parsed is ignored — don't block commit."""
|
||||
source = _make_source(
|
||||
"fct_orders",
|
||||
[
|
||||
MeasureDefinition(name="bad", expr="!!! not SQL !!!"),
|
||||
MeasureDefinition(name="good", expr="count(*)"),
|
||||
],
|
||||
)
|
||||
# Should not raise, should not flag — the parser validator will catch the bad one elsewhere
|
||||
errors = validate_measure_duplicates({"fct_orders": source})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_non_commutative_args_not_treated_as_equivalent() -> None:
|
||||
"""safe_divide(a, b) is NOT equivalent to safe_divide(b, a)."""
|
||||
source = _make_source(
|
||||
"fct_orders",
|
||||
[
|
||||
MeasureDefinition(
|
||||
name="ratio_ab", expr="safe_divide(count(*), sum(amount))"
|
||||
),
|
||||
MeasureDefinition(
|
||||
name="ratio_ba", expr="safe_divide(sum(amount), count(*))"
|
||||
),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_orders": source})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_single_measure_source_no_comparison() -> None:
|
||||
source = _make_source(
|
||||
"fct_orders", [MeasureDefinition(name="total", expr="count(*)")]
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_orders": source})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_same_expr_different_segments_is_not_flagged() -> None:
|
||||
"""Two measures with same expr but different named segments are by-design distinct."""
|
||||
source = _make_source(
|
||||
"fct_subscriptions",
|
||||
[
|
||||
MeasureDefinition(
|
||||
name="active_count", expr="count(*)", segments=["active"]
|
||||
),
|
||||
MeasureDefinition(
|
||||
name="inactive_count", expr="count(*)", segments=["inactive"]
|
||||
),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_subscriptions": source})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_same_expr_same_segments_is_flagged() -> None:
|
||||
"""Same expr + same segment set = a true duplicate."""
|
||||
source = _make_source(
|
||||
"fct_subscriptions",
|
||||
[
|
||||
MeasureDefinition(name="a_count", expr="count(*)", segments=["active"]),
|
||||
MeasureDefinition(name="b_count", expr="count(*)", segments=["active"]),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_subscriptions": source})
|
||||
assert len(errors) == 1
|
||||
assert "same expression and filter" in errors[0]
|
||||
|
||||
|
||||
def test_segment_difference_with_filter_difference_not_flagged() -> None:
|
||||
"""Segments differ → distinct measures even if filter also differs."""
|
||||
source = _make_source(
|
||||
"fct_subscriptions",
|
||||
[
|
||||
MeasureDefinition(
|
||||
name="m1",
|
||||
expr="count(*)",
|
||||
segments=["active"],
|
||||
filter="protocol = 'TRT'",
|
||||
),
|
||||
MeasureDefinition(name="m2", expr="count(*)", segments=["inactive"]),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_subscriptions": source})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_bigquery_native_exprs_compared_correctly():
|
||||
"""Two measures with identical BigQuery-native exprs must be flagged as duplicates."""
|
||||
from semantic_layer.duplicate_check import validate_measure_duplicates
|
||||
from semantic_layer.models import (
|
||||
MeasureDefinition,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
|
||||
source = SourceDefinition(
|
||||
name="fct_orders",
|
||||
table="fct_orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
],
|
||||
measures=[
|
||||
MeasureDefinition(
|
||||
name="safe_ratio_a",
|
||||
expr="SAFE_DIVIDE(sum(amount), count(*))",
|
||||
),
|
||||
MeasureDefinition(
|
||||
name="safe_ratio_b",
|
||||
expr="SAFE_DIVIDE(sum(amount), count(*))",
|
||||
),
|
||||
],
|
||||
)
|
||||
errors = validate_measure_duplicates({"fct_orders": source}, dialect="bigquery")
|
||||
assert any("safe_ratio_a" in e and "safe_ratio_b" in e for e in errors), (
|
||||
f"Duplicate detection missed identical BigQuery-native exprs: {errors}"
|
||||
)
|
||||
1380
python/klo-sl/tests/test_engine.py
Normal file
1380
python/klo-sl/tests/test_engine.py
Normal file
File diff suppressed because it is too large
Load diff
2302
python/klo-sl/tests/test_generator.py
Normal file
2302
python/klo-sl/tests/test_generator.py
Normal file
File diff suppressed because it is too large
Load diff
731
python/klo-sl/tests/test_graph.py
Normal file
731
python/klo-sl/tests/test_graph.py
Normal file
|
|
@ -0,0 +1,731 @@
|
|||
import pytest
|
||||
|
||||
from semantic_layer.graph import JoinGraph
|
||||
from semantic_layer.models import SourceDefinition, SourceColumn, JoinDeclaration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph(ecommerce_sources):
|
||||
g = JoinGraph(ecommerce_sources)
|
||||
g.build()
|
||||
return g
|
||||
|
||||
|
||||
class TestJoinGraphBuild:
|
||||
def test_all_sources_in_adjacency(self, graph, ecommerce_sources):
|
||||
assert set(graph.adjacency.keys()) == set(ecommerce_sources.keys())
|
||||
|
||||
def test_bidirectional_edges(self, graph):
|
||||
# orders declares join to customers → both directions exist
|
||||
orders_edges = graph.adjacency["orders"]
|
||||
assert any(e.to_source == "customers" for e in orders_edges)
|
||||
|
||||
customers_edges = graph.adjacency["customers"]
|
||||
assert any(e.to_source == "orders" for e in customers_edges)
|
||||
|
||||
def test_relationship_inversion(self, graph):
|
||||
# orders → customers is many_to_one
|
||||
fwd = next(e for e in graph.adjacency["orders"] if e.to_source == "customers")
|
||||
assert fwd.relationship == "many_to_one"
|
||||
|
||||
# customers → orders is one_to_many (reverse)
|
||||
rev = next(e for e in graph.adjacency["customers"] if e.to_source == "orders")
|
||||
assert rev.relationship == "one_to_many"
|
||||
|
||||
def test_on_parsing(self, graph):
|
||||
fwd = next(e for e in graph.adjacency["orders"] if e.to_source == "customers")
|
||||
assert fwd.from_column == "customer_id"
|
||||
assert fwd.to_column == "id"
|
||||
|
||||
|
||||
class TestFindPath:
|
||||
def test_direct_join(self, graph):
|
||||
path = graph.find_path("orders", "customers")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 1
|
||||
assert path.edges[0].from_source == "orders"
|
||||
assert path.edges[0].to_source == "customers"
|
||||
assert not path.has_one_to_many
|
||||
|
||||
def test_two_hop_m2o(self, graph):
|
||||
# orders → customers → regions (all m2o)
|
||||
path = graph.find_path("orders", "regions")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 2
|
||||
assert path.source_names == ["orders", "customers", "regions"]
|
||||
assert not path.has_one_to_many
|
||||
|
||||
def test_reverse_path_flagged(self, graph):
|
||||
# regions → customers (o2m) → orders (o2m)
|
||||
path = graph.find_path("regions", "orders")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 2
|
||||
assert path.has_one_to_many
|
||||
|
||||
def test_through_bridge(self, graph):
|
||||
# orders → order_items is reverse (o2m), order_items → products is m2o
|
||||
# But shortest may be: orders ← order_items → products
|
||||
path = graph.find_path("orders", "products")
|
||||
assert path is not None
|
||||
assert "order_items" in path.source_names
|
||||
|
||||
def test_churn_risk_to_regions(self, graph):
|
||||
path = graph.find_path("churn_risk", "regions")
|
||||
assert path is not None
|
||||
assert "customers" in path.source_names
|
||||
|
||||
def test_same_source(self, graph):
|
||||
path = graph.find_path("orders", "orders")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 0
|
||||
assert not path.has_one_to_many
|
||||
|
||||
def test_source_names_property(self, graph):
|
||||
path = graph.find_path("orders", "regions")
|
||||
assert path.source_names == ["orders", "customers", "regions"]
|
||||
|
||||
def test_empty_path_source_names(self, graph):
|
||||
path = graph.find_path("orders", "orders")
|
||||
assert path.source_names == []
|
||||
|
||||
|
||||
class TestResolveJoinTree:
|
||||
def test_single_source(self, graph):
|
||||
tree = graph.resolve_join_tree({"orders"})
|
||||
assert tree.sources == {"orders"}
|
||||
assert tree.edges == []
|
||||
|
||||
def test_two_sources(self, graph):
|
||||
tree = graph.resolve_join_tree({"orders", "customers"})
|
||||
assert "orders" in tree.sources
|
||||
assert "customers" in tree.sources
|
||||
assert len(tree.edges) >= 1
|
||||
|
||||
def test_three_sources_via_customers(self, graph):
|
||||
tree = graph.resolve_join_tree({"churn_risk", "regions", "orders"})
|
||||
assert "customers" in tree.sources # intermediate node added
|
||||
assert len(tree.sources) >= 4
|
||||
|
||||
def test_disconnected_raises(self):
|
||||
from semantic_layer.models import SourceDefinition, SourceColumn
|
||||
|
||||
src_a = SourceDefinition(
|
||||
name="a",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
src_b = SourceDefinition(
|
||||
name="b",
|
||||
table="t2",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
g = JoinGraph({"a": src_a, "b": src_b})
|
||||
g.build()
|
||||
with pytest.raises(ValueError, match="No join path"):
|
||||
g.resolve_join_tree({"a", "b"})
|
||||
|
||||
|
||||
class TestOneToOneRelationship:
|
||||
def test_one_to_one_no_fan_out(self):
|
||||
"""one_to_one joins should not flag has_one_to_many."""
|
||||
from semantic_layer.models import (
|
||||
SourceDefinition,
|
||||
SourceColumn,
|
||||
JoinDeclaration,
|
||||
)
|
||||
|
||||
users = SourceDefinition(
|
||||
name="users",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
profiles = SourceDefinition(
|
||||
name="profiles",
|
||||
table="t2",
|
||||
grain=["user_id"],
|
||||
columns=[SourceColumn(name="user_id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="users", on="user_id = users.id", relationship="one_to_one"
|
||||
)
|
||||
],
|
||||
)
|
||||
g = JoinGraph({"users": users, "profiles": profiles})
|
||||
g.build()
|
||||
|
||||
path = g.find_path("profiles", "users")
|
||||
assert path is not None
|
||||
assert not path.has_one_to_many
|
||||
|
||||
# Reverse should also be one_to_one
|
||||
rev_path = g.find_path("users", "profiles")
|
||||
assert rev_path is not None
|
||||
assert not rev_path.has_one_to_many
|
||||
|
||||
def test_one_to_one_inverse(self):
|
||||
"""one_to_one inverted should stay one_to_one."""
|
||||
from semantic_layer.models import (
|
||||
SourceDefinition,
|
||||
SourceColumn,
|
||||
JoinDeclaration,
|
||||
)
|
||||
|
||||
a = SourceDefinition(
|
||||
name="a",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
b = SourceDefinition(
|
||||
name="b",
|
||||
table="t2",
|
||||
grain=["a_id"],
|
||||
columns=[SourceColumn(name="a_id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(to="a", on="a_id = a.id", relationship="one_to_one")
|
||||
],
|
||||
)
|
||||
g = JoinGraph({"a": a, "b": b})
|
||||
g.build()
|
||||
|
||||
fwd = next(e for e in g.adjacency["b"] if e.to_source == "a")
|
||||
assert fwd.relationship == "one_to_one"
|
||||
rev = next(e for e in g.adjacency["a"] if e.to_source == "b")
|
||||
assert rev.relationship == "one_to_one"
|
||||
|
||||
|
||||
class TestMultipleJoinsFromSource:
|
||||
def test_order_items_two_joins(self, graph):
|
||||
"""order_items has joins to both orders and products."""
|
||||
oi_edges = graph.adjacency["order_items"]
|
||||
targets = {e.to_source for e in oi_edges}
|
||||
assert "orders" in targets
|
||||
assert "products" in targets
|
||||
|
||||
def test_path_through_bridge(self, graph):
|
||||
"""Can find path from orders to products through order_items."""
|
||||
path = graph.find_path("orders", "products")
|
||||
assert path is not None
|
||||
assert "order_items" in path.source_names
|
||||
|
||||
|
||||
class TestResolveJoinTreeRoot:
|
||||
def test_root_is_respected(self, graph):
|
||||
"""When root is specified, it should be the anchor of the tree."""
|
||||
tree = graph.resolve_join_tree({"orders", "regions"}, root="orders")
|
||||
assert "orders" in tree.sources
|
||||
assert "regions" in tree.sources
|
||||
assert "customers" in tree.sources # intermediate
|
||||
|
||||
def test_root_not_in_sources_uses_default(self, graph):
|
||||
"""When root is not in source_names, falls back to sorted order."""
|
||||
tree = graph.resolve_join_tree({"orders", "customers"}, root="nonexistent")
|
||||
assert "orders" in tree.sources
|
||||
assert "customers" in tree.sources
|
||||
|
||||
|
||||
class TestFindComponents:
|
||||
def test_connected_graph(self, graph):
|
||||
components = graph.find_components()
|
||||
assert len(components) == 1
|
||||
assert components[0] == set(graph.adjacency.keys())
|
||||
|
||||
def test_disconnected_graph(self):
|
||||
from semantic_layer.models import SourceDefinition, SourceColumn
|
||||
|
||||
src_a = SourceDefinition(
|
||||
name="a",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
src_b = SourceDefinition(
|
||||
name="b",
|
||||
table="t2",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
g = JoinGraph({"a": src_a, "b": src_b})
|
||||
g.build()
|
||||
components = g.find_components()
|
||||
assert len(components) == 2
|
||||
assert {frozenset(c) for c in components} == {
|
||||
frozenset({"a"}),
|
||||
frozenset({"b"}),
|
||||
}
|
||||
|
||||
|
||||
# ── From test_edge_cases.py ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGraphEdgeCases:
|
||||
def test_self_referencing_join(self):
|
||||
emp_with_join = SourceDefinition(
|
||||
name="employees",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="manager_id", type="number"),
|
||||
SourceColumn(name="salary", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="employees",
|
||||
on="manager_id = employees.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
sources = {"employees": emp_with_join}
|
||||
graph = JoinGraph(sources)
|
||||
graph.build()
|
||||
path = graph.find_path("employees", "employees")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 0
|
||||
|
||||
def test_no_sources(self):
|
||||
graph = JoinGraph({})
|
||||
graph.build()
|
||||
components = graph.find_components()
|
||||
assert components == []
|
||||
|
||||
def test_single_source_no_joins(self):
|
||||
src = SourceDefinition(
|
||||
name="a",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
graph = JoinGraph({"a": src})
|
||||
graph.build()
|
||||
assert graph.find_path("a", "a") is not None
|
||||
assert graph.find_path("a", "nonexistent") is None
|
||||
|
||||
def test_two_disconnected_sources(self):
|
||||
a = SourceDefinition(
|
||||
name="a",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
b = SourceDefinition(
|
||||
name="b",
|
||||
table="t2",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
graph = JoinGraph({"a": a, "b": b})
|
||||
graph.build()
|
||||
assert graph.find_path("a", "b") is None
|
||||
|
||||
def test_on_clause_with_spaces(self):
|
||||
g = JoinGraph({})
|
||||
result = g._parse_on(" customer_id = customers.id ", "customers")
|
||||
assert result == ("customer_id", "id")
|
||||
|
||||
def test_on_clause_without_prefix(self):
|
||||
g = JoinGraph({})
|
||||
result = g._parse_on("customer_id = id", "customers")
|
||||
assert result == ("customer_id", "id")
|
||||
|
||||
def test_on_clause_invalid(self):
|
||||
g = JoinGraph({})
|
||||
with pytest.raises(ValueError, match="Invalid join condition"):
|
||||
g._parse_on("customer_id", "customers")
|
||||
|
||||
def test_on_clause_three_parts(self):
|
||||
g = JoinGraph({})
|
||||
with pytest.raises(ValueError, match="Invalid join condition"):
|
||||
g._parse_on("a = b = c", "target")
|
||||
|
||||
def test_composite_join_key(self):
|
||||
"""Composite join: 'a = t.x AND b = t.y' → comma-separated columns."""
|
||||
g = JoinGraph({})
|
||||
from_col, to_col = g._parse_on(
|
||||
"product_id = inventory.product_id AND warehouse_id = inventory.warehouse_id",
|
||||
"inventory",
|
||||
)
|
||||
assert from_col == "product_id,warehouse_id"
|
||||
assert to_col == "product_id,warehouse_id"
|
||||
|
||||
def test_composite_join_key_with_source_prefix(self):
|
||||
"""Composite join with source prefix on left side."""
|
||||
g = JoinGraph({})
|
||||
from_col, to_col = g._parse_on(
|
||||
"items.product_id = inventory.product_id AND items.warehouse_id = inventory.warehouse_id",
|
||||
"inventory",
|
||||
)
|
||||
assert from_col == "product_id,warehouse_id"
|
||||
assert to_col == "product_id,warehouse_id"
|
||||
|
||||
def test_composite_join_generates_correct_sql(self):
|
||||
"""End-to-end: composite join keys produce multi-condition ON clause."""
|
||||
items = SourceDefinition(
|
||||
name="items",
|
||||
table="public.items",
|
||||
grain=["order_id", "product_id"],
|
||||
columns=[
|
||||
SourceColumn(name="order_id", type="number"),
|
||||
SourceColumn(name="product_id", type="number"),
|
||||
SourceColumn(name="warehouse_id", type="number"),
|
||||
SourceColumn(name="qty", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="inventory",
|
||||
on="product_id = inventory.product_id AND warehouse_id = inventory.warehouse_id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
inv = SourceDefinition(
|
||||
name="inventory",
|
||||
table="public.inventory",
|
||||
grain=["product_id", "warehouse_id"],
|
||||
columns=[
|
||||
SourceColumn(name="product_id", type="number"),
|
||||
SourceColumn(name="warehouse_id", type="number"),
|
||||
SourceColumn(name="stock", type="number"),
|
||||
],
|
||||
)
|
||||
graph = JoinGraph({"items": items, "inventory": inv})
|
||||
graph.build()
|
||||
path = graph.find_path("items", "inventory")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 1
|
||||
assert path.edges[0].from_column == "product_id,warehouse_id"
|
||||
assert path.edges[0].to_column == "product_id,warehouse_id"
|
||||
|
||||
def test_resolve_join_tree_empty_set(self):
|
||||
graph = JoinGraph({})
|
||||
graph.build()
|
||||
tree = graph.resolve_join_tree(set())
|
||||
assert tree.sources == set()
|
||||
assert tree.edges == []
|
||||
|
||||
|
||||
# ── From test_brainstorm_cases.py ────────────────────────────────────
|
||||
|
||||
|
||||
class TestJoinTreeReusesIntermediates:
|
||||
def test_resolve_join_tree_reuses_intermediate_sources(self):
|
||||
a = SourceDefinition(
|
||||
name="a",
|
||||
table="public.a",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(to="z", on="z_id = z.id", relationship="many_to_one")
|
||||
],
|
||||
)
|
||||
z = SourceDefinition(
|
||||
name="z",
|
||||
table="public.z",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="m_id", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(to="m", on="m_id = m.id", relationship="many_to_one")
|
||||
],
|
||||
)
|
||||
m = SourceDefinition(
|
||||
name="m",
|
||||
table="public.m",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
|
||||
graph = JoinGraph({"a": a, "z": z, "m": m})
|
||||
graph.build()
|
||||
|
||||
tree = graph.resolve_join_tree({"a", "m", "z"}, root="a")
|
||||
|
||||
assert tree.sources == {"a", "z", "m"}
|
||||
assert len(tree.edges) == 2
|
||||
assert {(edge.from_source, edge.to_source) for edge in tree.edges} == {
|
||||
("a", "z"),
|
||||
("z", "m"),
|
||||
}
|
||||
|
||||
|
||||
class TestDijkstraEdgeWeightPreference:
|
||||
"""LIMIT 2: Dijkstra prefers safe (m2o) paths over one_to_many paths."""
|
||||
|
||||
def test_dijkstra_prefers_safe_path(self):
|
||||
"""1-hop o2m path vs 2-hop all-m2o path: Dijkstra should pick the 2-hop m2o path."""
|
||||
# A --o2m--> C (direct, 1-hop, but unsafe)
|
||||
# A --m2o--> B --m2o--> C (2-hop, all safe)
|
||||
a = SourceDefinition(
|
||||
name="a",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(to="c", on="c_id = c.id", relationship="one_to_many"),
|
||||
JoinDeclaration(to="b", on="b_id = b.id", relationship="many_to_one"),
|
||||
],
|
||||
)
|
||||
b = SourceDefinition(
|
||||
name="b",
|
||||
table="t2",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="c_id", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(to="c", on="c_id = c.id", relationship="many_to_one"),
|
||||
],
|
||||
)
|
||||
c = SourceDefinition(
|
||||
name="c",
|
||||
table="t3",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="c_id", type="number"),
|
||||
],
|
||||
)
|
||||
g = JoinGraph({"a": a, "b": b, "c": c})
|
||||
g.build()
|
||||
|
||||
path = g.find_path("a", "c")
|
||||
assert path is not None
|
||||
# Should pick the 2-hop safe path (a -> b -> c) over the 1-hop o2m (a -> c)
|
||||
assert len(path.edges) == 2
|
||||
assert path.source_names == ["a", "b", "c"]
|
||||
assert not path.has_one_to_many
|
||||
|
||||
def test_dijkstra_uses_unsafe_when_only_option(self):
|
||||
"""When only an o2m path exists, it should still be returned."""
|
||||
a = SourceDefinition(
|
||||
name="a",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(to="b", on="b_id = b.id", relationship="one_to_many"),
|
||||
],
|
||||
)
|
||||
b = SourceDefinition(
|
||||
name="b",
|
||||
table="t2",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
g = JoinGraph({"a": a, "b": b})
|
||||
g.build()
|
||||
|
||||
path = g.find_path("a", "b")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 1
|
||||
assert path.has_one_to_many
|
||||
|
||||
|
||||
class TestAmbiguousPathDetection:
|
||||
"""Tests for 12.1 fix: diamond graph ambiguity detection."""
|
||||
|
||||
@staticmethod
|
||||
def _diamond_sources():
|
||||
"""Diamond: A →(m2o) B →(m2o) D, A →(m2o) C →(m2o) D. Two equal-cost paths."""
|
||||
return {
|
||||
"a": SourceDefinition(
|
||||
name="a",
|
||||
table="t_a",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="b", on="b_id = b.id", relationship="many_to_one"
|
||||
),
|
||||
JoinDeclaration(
|
||||
to="c", on="c_id = c.id", relationship="many_to_one"
|
||||
),
|
||||
],
|
||||
),
|
||||
"b": SourceDefinition(
|
||||
name="b",
|
||||
table="t_b",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="d", on="d_id = d.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
),
|
||||
"c": SourceDefinition(
|
||||
name="c",
|
||||
table="t_c",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="d", on="d_id = d.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
),
|
||||
"d": SourceDefinition(
|
||||
name="d",
|
||||
table="t_d",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
),
|
||||
}
|
||||
|
||||
def test_diamond_graph_is_ambiguous(self):
|
||||
g = JoinGraph(self._diamond_sources())
|
||||
g.build()
|
||||
path = g.find_path("a", "d")
|
||||
assert path is not None
|
||||
assert path.is_ambiguous is True
|
||||
|
||||
def test_linear_graph_not_ambiguous(self):
|
||||
"""A → B → C: single path, no ambiguity."""
|
||||
sources = {
|
||||
"a": SourceDefinition(
|
||||
name="a",
|
||||
table="t_a",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="b", on="b_id = b.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
),
|
||||
"b": SourceDefinition(
|
||||
name="b",
|
||||
table="t_b",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="c", on="c_id = c.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
),
|
||||
"c": SourceDefinition(
|
||||
name="c",
|
||||
table="t_c",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
),
|
||||
}
|
||||
g = JoinGraph(sources)
|
||||
g.build()
|
||||
path = g.find_path("a", "c")
|
||||
assert path is not None
|
||||
assert path.is_ambiguous is False
|
||||
|
||||
def test_different_cost_paths_not_ambiguous(self):
|
||||
"""A →(m2o) B →(m2o) D and A →(o2m) C →(m2o) D: costs differ."""
|
||||
sources = {
|
||||
"a": SourceDefinition(
|
||||
name="a",
|
||||
table="t_a",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="b", on="b_id = b.id", relationship="many_to_one"
|
||||
),
|
||||
JoinDeclaration(
|
||||
to="c", on="id = c.a_id", relationship="one_to_many"
|
||||
),
|
||||
],
|
||||
),
|
||||
"b": SourceDefinition(
|
||||
name="b",
|
||||
table="t_b",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="d", on="d_id = d.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
),
|
||||
"c": SourceDefinition(
|
||||
name="c",
|
||||
table="t_c",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="a_id", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="d", on="d_id = d.id", relationship="many_to_one"
|
||||
)
|
||||
],
|
||||
),
|
||||
"d": SourceDefinition(
|
||||
name="d",
|
||||
table="t_d",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
),
|
||||
}
|
||||
g = JoinGraph(sources)
|
||||
g.build()
|
||||
path = g.find_path("a", "d")
|
||||
assert path is not None
|
||||
# Safe path (cost 2) vs unsafe path (cost 11) — not ambiguous
|
||||
assert path.is_ambiguous is False
|
||||
assert path.has_one_to_many is False
|
||||
|
||||
def test_ambiguous_path_warning_in_resolve_join_tree(self, caplog):
|
||||
"""resolve_join_tree logs a warning for ambiguous paths."""
|
||||
import logging
|
||||
|
||||
g = JoinGraph(self._diamond_sources())
|
||||
g.build()
|
||||
with caplog.at_level(logging.WARNING, logger="semantic_layer.graph"):
|
||||
g.resolve_join_tree({"a", "d"}, root="a")
|
||||
assert any("Ambiguous join path" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_bigquery_native_on_clause_extracts_column_pair():
|
||||
"""Join on: with BigQuery-specific casts must parse and yield column pairs."""
|
||||
orders = SourceDefinition(
|
||||
name="orders",
|
||||
table="orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="user_id", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="users",
|
||||
on="user_id = SAFE_CAST(users.id AS INT64)",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
users = SourceDefinition(
|
||||
name="users",
|
||||
table="users",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
graph = JoinGraph({"orders": orders, "users": users}, dialect="bigquery")
|
||||
graph.build()
|
||||
# The graph must have recorded the compatibility edge
|
||||
orders_edges = graph.adjacency.get("orders", [])
|
||||
assert any(e.to_source == "users" for e in orders_edges), (
|
||||
f"orders → users edge missing after BigQuery-native on: parse:\n{orders_edges}"
|
||||
)
|
||||
|
||||
|
||||
def test_joingraph_dialect_defaults_to_postgres():
|
||||
"""Default keeps existing test ergonomics unchanged."""
|
||||
g = JoinGraph({})
|
||||
assert g.dialect == "postgres"
|
||||
171
python/klo-sl/tests/test_loader.py
Normal file
171
python/klo-sl/tests/test_loader.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
import pytest
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
import yaml
|
||||
|
||||
from semantic_layer.loader import SourceLoader
|
||||
from semantic_layer.models import SourceDefinition
|
||||
|
||||
SOURCES_DIR = Path(__file__).parent.parent / "sources" / "ecommerce"
|
||||
|
||||
|
||||
class TestSourceLoader:
|
||||
def test_load_all_ecommerce(self, ecommerce_sources):
|
||||
assert len(ecommerce_sources) == 6
|
||||
assert set(ecommerce_sources.keys()) == {
|
||||
"customers",
|
||||
"orders",
|
||||
"regions",
|
||||
"products",
|
||||
"order_items",
|
||||
"churn_risk",
|
||||
}
|
||||
|
||||
def test_orders_source(self, ecommerce_sources):
|
||||
orders = ecommerce_sources["orders"]
|
||||
assert orders.is_table_source
|
||||
assert orders.table == "public.orders"
|
||||
assert orders.grain == ["id"]
|
||||
assert len(orders.columns) == 6
|
||||
assert len(orders.measures) == 5
|
||||
assert len(orders.joins) == 1
|
||||
assert orders.joins[0].to == "customers"
|
||||
assert orders.joins[0].relationship == "many_to_one"
|
||||
|
||||
def test_churn_risk_sql_source(self, ecommerce_sources):
|
||||
churn = ecommerce_sources["churn_risk"]
|
||||
assert churn.is_sql_source
|
||||
assert churn.sql is not None
|
||||
assert "calculate_churn_score" in churn.sql
|
||||
assert churn.grain == ["customer_id"]
|
||||
assert len(churn.measures) == 1
|
||||
assert churn.measures[0].name == "avg_risk"
|
||||
|
||||
def test_regions_no_joins(self, ecommerce_sources):
|
||||
regions = ecommerce_sources["regions"]
|
||||
assert regions.joins == []
|
||||
assert regions.measures == []
|
||||
|
||||
def test_order_items_bridge(self, ecommerce_sources):
|
||||
oi = ecommerce_sources["order_items"]
|
||||
assert len(oi.joins) == 2
|
||||
targets = {j.to for j in oi.joins}
|
||||
assert targets == {"orders", "products"}
|
||||
|
||||
def test_revenue_measure_has_filter(self, ecommerce_sources):
|
||||
orders = ecommerce_sources["orders"]
|
||||
revenue = next(m for m in orders.measures if m.name == "revenue")
|
||||
assert revenue.filter == "status != 'refunded'"
|
||||
assert revenue.expr == "sum(amount)"
|
||||
|
||||
def test_load_single_file(self):
|
||||
loader = SourceLoader(SOURCES_DIR)
|
||||
src = loader.load_file(SOURCES_DIR / "regions.yaml")
|
||||
assert src.name == "regions"
|
||||
assert isinstance(src, SourceDefinition)
|
||||
|
||||
def test_invalid_join_target(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
data = {
|
||||
"name": "bad_source",
|
||||
"table": "t",
|
||||
"grain": ["id"],
|
||||
"columns": [{"name": "id", "type": "number"}],
|
||||
"joins": [
|
||||
{
|
||||
"to": "nonexistent",
|
||||
"on": "id = nonexistent.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
}
|
||||
path = Path(tmpdir) / "bad.yaml"
|
||||
with open(path, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
loader = SourceLoader(tmpdir)
|
||||
with pytest.raises(ValueError, match="nonexistent"):
|
||||
loader.load_all()
|
||||
|
||||
def test_duplicate_source_name(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
data = {
|
||||
"name": "dupe",
|
||||
"table": "t",
|
||||
"grain": ["id"],
|
||||
"columns": [{"name": "id", "type": "number"}],
|
||||
}
|
||||
for fname in ["a.yaml", "b.yaml"]:
|
||||
with open(Path(tmpdir) / fname, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
loader = SourceLoader(tmpdir)
|
||||
with pytest.raises(ValueError, match="Duplicate source name"):
|
||||
loader.load_all()
|
||||
|
||||
def test_source_description_loads(self, ecommerce_sources):
|
||||
churn = ecommerce_sources["churn_risk"]
|
||||
assert churn.description is not None
|
||||
assert "churn" in churn.description.lower()
|
||||
|
||||
def test_column_role_loads(self, ecommerce_sources):
|
||||
orders = ecommerce_sources["orders"]
|
||||
time_col = next(c for c in orders.columns if c.name == "created_at")
|
||||
assert time_col.role == "time"
|
||||
|
||||
def test_source_without_description(self, ecommerce_sources):
|
||||
regions = ecommerce_sources["regions"]
|
||||
assert regions.description is None
|
||||
|
||||
|
||||
# ── From test_edge_cases.py ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLoaderEdgeCases:
|
||||
def test_empty_directory(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
loader = SourceLoader(tmpdir)
|
||||
sources = loader.load_all()
|
||||
assert sources == {}
|
||||
|
||||
def test_non_yaml_files_ignored(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "readme.txt").write_text("not a yaml file")
|
||||
loader = SourceLoader(tmpdir)
|
||||
sources = loader.load_all()
|
||||
assert sources == {}
|
||||
|
||||
def test_yaml_with_extra_fields(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
data = {
|
||||
"name": "test",
|
||||
"table": "t",
|
||||
"grain": ["id"],
|
||||
"columns": [{"name": "id", "type": "number"}],
|
||||
"unknown_field": "should be rejected",
|
||||
}
|
||||
with open(Path(tmpdir) / "test.yaml", "w") as f:
|
||||
yaml.dump(data, f)
|
||||
loader = SourceLoader(tmpdir)
|
||||
try:
|
||||
sources = loader.load_all()
|
||||
assert "test" in sources
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_subdirectory_sources(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
subdir = Path(tmpdir) / "sub"
|
||||
subdir.mkdir()
|
||||
data = {
|
||||
"name": "nested",
|
||||
"table": "t",
|
||||
"grain": ["id"],
|
||||
"columns": [{"name": "id", "type": "number"}],
|
||||
}
|
||||
with open(subdir / "nested.yaml", "w") as f:
|
||||
yaml.dump(data, f)
|
||||
loader = SourceLoader(tmpdir)
|
||||
sources = loader.load_all()
|
||||
assert "nested" in sources
|
||||
619
python/klo-sl/tests/test_manifest.py
Normal file
619
python/klo-sl/tests/test_manifest.py
Normal file
|
|
@ -0,0 +1,619 @@
|
|||
"""Tests for manifest models, projection, overlay validation, and two-tier loading."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from semantic_layer.loader import SourceLoader
|
||||
from semantic_layer.manifest import (
|
||||
ManifestColumn,
|
||||
ManifestEntry,
|
||||
ManifestJoin,
|
||||
map_column_type,
|
||||
project_manifest_entry,
|
||||
validate_overlay,
|
||||
)
|
||||
from semantic_layer.models import ColumnRole
|
||||
|
||||
|
||||
# ── Type Mapping Tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMapColumnType:
|
||||
def test_map_column_type_numbers(self):
|
||||
number_types = [
|
||||
"integer",
|
||||
"bigint",
|
||||
"smallint",
|
||||
"numeric",
|
||||
"decimal",
|
||||
"float",
|
||||
"double",
|
||||
"real",
|
||||
"int",
|
||||
"int2",
|
||||
"int4",
|
||||
"int8",
|
||||
"float4",
|
||||
"float8",
|
||||
"double precision",
|
||||
"number",
|
||||
"tinyint",
|
||||
"mediumint",
|
||||
]
|
||||
for db_type in number_types:
|
||||
assert map_column_type(db_type) == "number", (
|
||||
f"{db_type} should map to 'number'"
|
||||
)
|
||||
|
||||
def test_map_column_type_time(self):
|
||||
time_types = [
|
||||
"timestamp",
|
||||
"timestamptz",
|
||||
"timestamp with time zone",
|
||||
"timestamp without time zone",
|
||||
"TIMESTAMP_NTZ",
|
||||
"TIMESTAMP_LTZ",
|
||||
"TIMESTAMP_TZ",
|
||||
"datetime",
|
||||
"date",
|
||||
"time",
|
||||
"timetz",
|
||||
]
|
||||
for db_type in time_types:
|
||||
assert map_column_type(db_type) == "time", f"{db_type} should map to 'time'"
|
||||
|
||||
def test_map_column_type_boolean(self):
|
||||
for db_type in ["boolean", "bool"]:
|
||||
assert map_column_type(db_type) == "boolean", (
|
||||
f"{db_type} should map to 'boolean'"
|
||||
)
|
||||
|
||||
def test_map_column_type_string_fallback(self):
|
||||
string_types = ["varchar", "text", "char", "unknown", "jsonb", "xml"]
|
||||
for db_type in string_types:
|
||||
assert map_column_type(db_type) == "string", (
|
||||
f"{db_type} should map to 'string'"
|
||||
)
|
||||
|
||||
def test_map_column_type_strips_precision(self):
|
||||
assert map_column_type("numeric(10,2)") == "number"
|
||||
assert map_column_type("varchar(255)") == "string"
|
||||
assert map_column_type("decimal(18,4)") == "number"
|
||||
assert map_column_type("timestamp(6)") == "time"
|
||||
assert map_column_type("char(1)") == "string"
|
||||
|
||||
|
||||
# ── Manifest Projection Tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestProjectManifestEntry:
|
||||
@pytest.fixture()
|
||||
def orders_entry(self) -> ManifestEntry:
|
||||
return ManifestEntry(
|
||||
table="public.orders",
|
||||
description="Customer orders",
|
||||
columns=[
|
||||
ManifestColumn(name="id", type="integer", pk=True),
|
||||
ManifestColumn(name="customer_id", type="integer"),
|
||||
ManifestColumn(name="total", type="numeric"),
|
||||
ManifestColumn(name="status", type="varchar"),
|
||||
ManifestColumn(name="created_at", type="timestamp"),
|
||||
],
|
||||
joins=[
|
||||
ManifestJoin(
|
||||
to="customers",
|
||||
on="orders.customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
source="formal",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def test_project_manifest_entry_basic(self, orders_entry: ManifestEntry):
|
||||
src = project_manifest_entry("orders", orders_entry)
|
||||
assert src.name == "orders"
|
||||
assert src.table == "public.orders"
|
||||
assert src.description == "Customer orders"
|
||||
assert len(src.columns) == 5
|
||||
assert src.measures == []
|
||||
col_names = [c.name for c in src.columns]
|
||||
assert col_names == ["id", "customer_id", "total", "status", "created_at"]
|
||||
|
||||
def test_project_manifest_entry_type_mapping(self, orders_entry: ManifestEntry):
|
||||
src = project_manifest_entry("orders", orders_entry)
|
||||
col_types = {c.name: c.type for c in src.columns}
|
||||
assert col_types["id"] == "number"
|
||||
assert col_types["customer_id"] == "number"
|
||||
assert col_types["total"] == "number"
|
||||
assert col_types["status"] == "string"
|
||||
assert col_types["created_at"] == "time"
|
||||
|
||||
def test_project_manifest_entry_grain_from_pk(self, orders_entry: ManifestEntry):
|
||||
src = project_manifest_entry("orders", orders_entry)
|
||||
assert src.grain == ["id"]
|
||||
|
||||
def test_project_manifest_entry_grain_all_columns_no_pk(self):
|
||||
entry = ManifestEntry(
|
||||
table="public.events",
|
||||
columns=[
|
||||
ManifestColumn(name="user_id", type="integer"),
|
||||
ManifestColumn(name="event_type", type="varchar"),
|
||||
ManifestColumn(name="ts", type="timestamp"),
|
||||
],
|
||||
)
|
||||
src = project_manifest_entry("events", entry)
|
||||
assert src.grain == ["user_id", "event_type", "ts"]
|
||||
|
||||
def test_project_manifest_entry_joins_stripped(self, orders_entry: ManifestEntry):
|
||||
src = project_manifest_entry("orders", orders_entry)
|
||||
assert len(src.joins) == 1
|
||||
join = src.joins[0]
|
||||
assert join.to == "customers"
|
||||
assert join.on == "orders.customer_id = customers.id"
|
||||
assert join.relationship == "many_to_one"
|
||||
assert not hasattr(join, "source") or getattr(join, "source", None) is None
|
||||
|
||||
def test_project_manifest_entry_time_role(self, orders_entry: ManifestEntry):
|
||||
src = project_manifest_entry("orders", orders_entry)
|
||||
time_cols = [c for c in src.columns if c.role == ColumnRole.TIME]
|
||||
assert len(time_cols) == 1
|
||||
assert time_cols[0].name == "created_at"
|
||||
non_time = [c for c in src.columns if c.role == ColumnRole.DEFAULT]
|
||||
assert len(non_time) == 4
|
||||
|
||||
def test_project_manifest_entry_preserves_dbt_metadata(self):
|
||||
entry = ManifestEntry(
|
||||
table="public.orders",
|
||||
columns=[
|
||||
ManifestColumn(
|
||||
name="status",
|
||||
type="varchar",
|
||||
constraints={"dbt": {"not_null": True}},
|
||||
enum_values={"dbt": ["placed", "shipped"]},
|
||||
tests={"dbt": [{"name": "accepted_values", "package": "dbt"}]},
|
||||
)
|
||||
],
|
||||
tags={"dbt": ["mart"]},
|
||||
freshness={"dbt": {"loaded_at_field": "updated_at"}},
|
||||
)
|
||||
|
||||
src = project_manifest_entry("orders", entry)
|
||||
|
||||
assert src.columns[0].constraints is not None
|
||||
assert src.columns[0].constraints["dbt"].not_null is True
|
||||
assert src.columns[0].enum_values == {"dbt": ["placed", "shipped"]}
|
||||
assert src.columns[0].tests is not None
|
||||
assert src.columns[0].tests.model_dump(mode="python", exclude_none=True) == {
|
||||
"dbt": [{"name": "accepted_values", "package": "dbt"}]
|
||||
}
|
||||
assert src.tags == {"dbt": ["mart"]}
|
||||
assert src.freshness is not None
|
||||
assert src.freshness["dbt"].loaded_at_field == "updated_at"
|
||||
|
||||
|
||||
# ── Overlay Validation Tests ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestValidateOverlay:
|
||||
def test_validate_overlay_valid(self):
|
||||
data = {
|
||||
"name": "orders",
|
||||
"description": "Revenue-bearing orders",
|
||||
"grain": ["id"],
|
||||
"measures": [{"name": "revenue", "expr": "sum(total)"}],
|
||||
"columns": [
|
||||
{"name": "is_high_value", "expr": "total > 1000", "type": "boolean"}
|
||||
],
|
||||
"exclude_columns": ["status"],
|
||||
}
|
||||
errors = validate_overlay(data)
|
||||
assert errors == []
|
||||
|
||||
def test_validate_overlay_rejects_table(self):
|
||||
data = {"name": "orders", "table": "public.orders"}
|
||||
errors = validate_overlay(data)
|
||||
assert len(errors) == 1
|
||||
assert "table" in errors[0].lower()
|
||||
|
||||
def test_validate_overlay_rejects_sql(self):
|
||||
data = {"name": "orders", "sql": "SELECT * FROM orders"}
|
||||
errors = validate_overlay(data)
|
||||
assert len(errors) == 1
|
||||
assert "sql" in errors[0].lower()
|
||||
|
||||
def test_validate_overlay_rejects_type_without_expr(self):
|
||||
data = {
|
||||
"name": "orders",
|
||||
"columns": [{"name": "status", "type": "string"}],
|
||||
}
|
||||
errors = validate_overlay(data)
|
||||
assert len(errors) == 1
|
||||
assert "type" in errors[0].lower()
|
||||
assert "expr" in errors[0].lower()
|
||||
|
||||
def test_validate_overlay_allows_type_with_expr(self):
|
||||
data = {
|
||||
"name": "orders",
|
||||
"columns": [{"name": "is_big", "type": "boolean", "expr": "total > 1000"}],
|
||||
}
|
||||
errors = validate_overlay(data)
|
||||
assert errors == []
|
||||
|
||||
|
||||
# ── Two-Tier Loading Tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
def _write_yaml(path: Path, data: dict | list) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
yaml.dump(data, f, default_flow_style=False)
|
||||
|
||||
|
||||
def _manifest_tables() -> dict:
|
||||
"""Manifest shard with orders + customers tables."""
|
||||
return {
|
||||
"tables": {
|
||||
"orders": {
|
||||
"table": "public.orders",
|
||||
"description": "Customer orders",
|
||||
"columns": [
|
||||
{"name": "id", "type": "integer", "pk": True},
|
||||
{"name": "customer_id", "type": "integer"},
|
||||
{"name": "total", "type": "numeric"},
|
||||
{"name": "status", "type": "varchar"},
|
||||
{"name": "created_at", "type": "timestamp"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "orders.customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
"source": "formal",
|
||||
},
|
||||
],
|
||||
},
|
||||
"customers": {
|
||||
"table": "public.customers",
|
||||
"description": "Customer accounts",
|
||||
"columns": [
|
||||
{"name": "id", "type": "integer", "pk": True},
|
||||
{"name": "name", "type": "varchar"},
|
||||
],
|
||||
"joins": [
|
||||
{
|
||||
"to": "orders",
|
||||
"on": "customers.id = orders.customer_id",
|
||||
"relationship": "one_to_many",
|
||||
"source": "formal",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestTwoTierLoading:
|
||||
def test_load_manifest_shard(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
assert "orders" in sources
|
||||
assert "customers" in sources
|
||||
assert sources["orders"].table == "public.orders"
|
||||
assert sources["orders"].grain == ["id"]
|
||||
assert sources["customers"].table == "public.customers"
|
||||
|
||||
def test_load_standalone_source(self, tmp_path: Path):
|
||||
standalone = {
|
||||
"name": "regions",
|
||||
"table": "public.regions",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "name", "type": "string"},
|
||||
],
|
||||
}
|
||||
_write_yaml(tmp_path / "regions.yaml", standalone)
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
assert "regions" in sources
|
||||
assert sources["regions"].table == "public.regions"
|
||||
assert sources["regions"].is_table_source
|
||||
|
||||
def test_overlay_descriptions_do_not_promote_base_description_to_user_source(
|
||||
self, tmp_path: Path
|
||||
):
|
||||
standalone = {
|
||||
"name": "regions",
|
||||
"description": "Standalone description",
|
||||
"table": "public.regions",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
],
|
||||
}
|
||||
_write_yaml(tmp_path / "a_regions.yaml", standalone)
|
||||
|
||||
overlay = {"name": "regions", "descriptions": {"dbt": "dbt description"}}
|
||||
_write_yaml(tmp_path / "z_regions_overlay.yaml", overlay)
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
assert sources["regions"].description == "dbt description"
|
||||
|
||||
def test_load_sql_source(self, tmp_path: Path):
|
||||
sql_source = {
|
||||
"name": "active_users",
|
||||
"sql": "SELECT id, email FROM users WHERE active = true",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "email", "type": "string"},
|
||||
],
|
||||
}
|
||||
_write_yaml(tmp_path / "active_users.yaml", sql_source)
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
assert "active_users" in sources
|
||||
assert sources["active_users"].is_sql_source
|
||||
assert "SELECT" in sources["active_users"].sql
|
||||
|
||||
def test_load_overlay_composition(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"description": "Revenue-bearing orders",
|
||||
"grain": ["id"],
|
||||
"measures": [{"name": "revenue", "expr": "sum(total)"}],
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
|
||||
# Customers overlay (empty, just name match) to avoid cross-ref error
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
orders = sources["orders"]
|
||||
assert orders.table == "public.orders"
|
||||
assert orders.description == "Revenue-bearing orders"
|
||||
assert len(orders.measures) == 1
|
||||
assert orders.measures[0].name == "revenue"
|
||||
|
||||
def test_overlay_description_override(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {"name": "orders", "description": "Overridden description"}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
assert sources["orders"].description == "Overridden description"
|
||||
|
||||
def test_overlay_descriptions_map_preserves_higher_priority_manifest_description(
|
||||
self, tmp_path: Path
|
||||
):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"descriptions": {
|
||||
"db": "DB description",
|
||||
"dbt": "dbt description",
|
||||
},
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
assert sources["orders"].description == "Customer orders"
|
||||
|
||||
def test_overlay_descriptions_map_overrides_lower_priority_db_description(
|
||||
self, tmp_path: Path
|
||||
):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(
|
||||
schema_dir / "public.yaml",
|
||||
{
|
||||
"tables": {
|
||||
"orders": {
|
||||
"table": "public.orders",
|
||||
"descriptions": {"db": "DB description"},
|
||||
"columns": [{"name": "id", "type": "integer", "pk": True}],
|
||||
},
|
||||
"customers": {
|
||||
"table": "public.customers",
|
||||
"columns": [{"name": "id", "type": "integer", "pk": True}],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"descriptions": {
|
||||
"dbt": "dbt description",
|
||||
},
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
assert sources["orders"].description == "dbt description"
|
||||
|
||||
def test_overlay_exclude_columns(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {"name": "orders", "exclude_columns": ["status"]}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
col_names = [c.name for c in sources["orders"].columns]
|
||||
assert "status" not in col_names
|
||||
assert "id" in col_names
|
||||
assert "total" in col_names
|
||||
|
||||
def test_overlay_computed_columns_appended(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"columns": [
|
||||
{"name": "is_high_value", "expr": "total > 1000", "type": "boolean"},
|
||||
],
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
col_names = [c.name for c in sources["orders"].columns]
|
||||
assert "is_high_value" in col_names
|
||||
# Original columns still present
|
||||
assert "id" in col_names
|
||||
assert "total" in col_names
|
||||
# Computed column is at end
|
||||
hv = next(c for c in sources["orders"].columns if c.name == "is_high_value")
|
||||
assert hv.expr == "total > 1000"
|
||||
assert hv.type == "boolean"
|
||||
|
||||
def test_overlay_measures_set(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"measures": [
|
||||
{"name": "revenue", "expr": "sum(total)"},
|
||||
{"name": "order_count", "expr": "count(id)"},
|
||||
],
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
assert len(sources["orders"].measures) == 2
|
||||
measure_names = {m.name for m in sources["orders"].measures}
|
||||
assert measure_names == {"revenue", "order_count"}
|
||||
|
||||
def test_overlay_grain_override(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {"name": "orders", "grain": ["id", "customer_id"]}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
assert sources["orders"].grain == ["id", "customer_id"]
|
||||
|
||||
def test_overlay_join_union_and_dedupe(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
# Add a "regions" standalone so the join target exists
|
||||
_write_yaml(
|
||||
tmp_path / "regions.yaml",
|
||||
{
|
||||
"name": "regions",
|
||||
"table": "public.regions",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "name", "type": "string"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"joins": [
|
||||
# Duplicate of manifest join (should be deduped)
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "orders.customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
},
|
||||
# New join
|
||||
{
|
||||
"to": "regions",
|
||||
"on": "orders.region_id = regions.id",
|
||||
"relationship": "many_to_one",
|
||||
},
|
||||
],
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
joins = sources["orders"].joins
|
||||
# Manifest had 1 join to customers, overlay adds 1 new (regions), duplicate deduped
|
||||
assert len(joins) == 2
|
||||
join_targets = [j.to for j in joins]
|
||||
assert "customers" in join_targets
|
||||
assert "regions" in join_targets
|
||||
|
||||
def test_overlay_disable_joins(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"disable_joins": ["orders.customer_id = customers.id"],
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
|
||||
# Customers still needs to exist since the customers manifest entry has
|
||||
# a join back to orders that is NOT disabled
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
sources = loader.load_all()
|
||||
|
||||
assert len(sources["orders"].joins) == 0
|
||||
|
||||
def test_overlay_rejects_invalid(self, tmp_path: Path):
|
||||
schema_dir = tmp_path / "_schema"
|
||||
_write_yaml(schema_dir / "public.yaml", _manifest_tables())
|
||||
|
||||
# An overlay with a column that has type but no expr is invalid
|
||||
overlay = {
|
||||
"name": "orders",
|
||||
"columns": [{"name": "status", "type": "string"}],
|
||||
}
|
||||
_write_yaml(tmp_path / "orders.yaml", overlay)
|
||||
_write_yaml(tmp_path / "customers.yaml", {"name": "customers"})
|
||||
|
||||
loader = SourceLoader(tmp_path)
|
||||
with pytest.raises(ValueError, match="Invalid overlay"):
|
||||
loader.load_all()
|
||||
373
python/klo-sl/tests/test_models.py
Normal file
373
python/klo-sl/tests/test_models.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from semantic_layer.models import (
|
||||
ColumnRole,
|
||||
ColumnVisibility,
|
||||
ColumnDbtConstraints,
|
||||
DefaultTimeDimensionDbt,
|
||||
FreshnessDbt,
|
||||
MeasureGroup,
|
||||
Provenance,
|
||||
QueryResult,
|
||||
ResolvedColumn,
|
||||
ResolvedMeasure,
|
||||
ResolvedPlan,
|
||||
SemanticQuery,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
|
||||
|
||||
class TestSourceColumn:
|
||||
def test_defaults(self):
|
||||
col = SourceColumn(name="id", type="number")
|
||||
assert col.visibility == ColumnVisibility.PUBLIC
|
||||
assert col.role == ColumnRole.DEFAULT
|
||||
assert col.description is None
|
||||
|
||||
def test_all_fields(self):
|
||||
col = SourceColumn(
|
||||
name="id", type="number", visibility="hidden", role="time", description="PK"
|
||||
)
|
||||
assert col.visibility == ColumnVisibility.HIDDEN
|
||||
assert col.role == ColumnRole.TIME
|
||||
|
||||
def test_invalid_type(self):
|
||||
with pytest.raises(ValidationError):
|
||||
SourceColumn(name="id", type="integer")
|
||||
|
||||
|
||||
class TestSourceDefinition:
|
||||
def test_table_source(self):
|
||||
src = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
assert src.table == "public.orders"
|
||||
assert src.sql is None
|
||||
assert src.is_table_source
|
||||
assert not src.is_sql_source
|
||||
|
||||
def test_sql_source(self):
|
||||
src = SourceDefinition(
|
||||
name="churn",
|
||||
sql="SELECT * FROM x",
|
||||
grain=["customer_id"],
|
||||
columns=[SourceColumn(name="customer_id", type="number")],
|
||||
)
|
||||
assert src.sql == "SELECT * FROM x"
|
||||
assert src.table is None
|
||||
assert src.is_sql_source
|
||||
assert not src.is_table_source
|
||||
|
||||
def test_table_and_sql_mutually_exclusive(self):
|
||||
with pytest.raises(ValidationError, match="mutually exclusive"):
|
||||
SourceDefinition(
|
||||
name="bad",
|
||||
table="t",
|
||||
sql="SELECT 1",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
|
||||
def test_empty_grain_rejected(self):
|
||||
with pytest.raises(ValidationError, match="grain must be non-empty"):
|
||||
SourceDefinition(
|
||||
name="bad",
|
||||
table="t",
|
||||
grain=[],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
|
||||
def test_measures_and_joins(self):
|
||||
src = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
joins=[
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "cid = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
measures=[{"name": "revenue", "expr": "sum(amount)"}],
|
||||
)
|
||||
assert len(src.joins) == 1
|
||||
assert src.joins[0].to == "customers"
|
||||
assert len(src.measures) == 1
|
||||
assert src.measures[0].name == "revenue"
|
||||
|
||||
def test_default_time_dimension_optional_and_dump(self):
|
||||
minimal = SourceDefinition(
|
||||
name="orders",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
assert minimal.default_time_dimension is None
|
||||
|
||||
src = SourceDefinition(
|
||||
name="orders",
|
||||
table="t",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
default_time_dimension=DefaultTimeDimensionDbt(dbt="order_date"),
|
||||
)
|
||||
dumped = src.model_dump(mode="python", exclude_none=True)
|
||||
assert dumped["default_time_dimension"] == {"dbt": "order_date"}
|
||||
|
||||
round_tripped = SourceDefinition.model_validate(dumped)
|
||||
assert round_tripped.default_time_dimension == DefaultTimeDimensionDbt(
|
||||
dbt="order_date"
|
||||
)
|
||||
|
||||
def test_dbt_structural_metadata_round_trips(self):
|
||||
src = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(
|
||||
name="status",
|
||||
type="string",
|
||||
constraints={"dbt": {"not_null": True, "unique": True}},
|
||||
enum_values={"dbt": ["placed", "shipped"]},
|
||||
tests={
|
||||
"dbt": [{"name": "accepted_values", "package": "dbt"}],
|
||||
"dbt_by_package": {"dbt": ["accepted_values"]},
|
||||
},
|
||||
)
|
||||
],
|
||||
tags={"dbt": ["mart", "finance"]},
|
||||
freshness={
|
||||
"dbt": {
|
||||
"loaded_at_field": "updated_at",
|
||||
"raw": {"warn_after": {"count": 12, "period": "hour"}},
|
||||
}
|
||||
},
|
||||
default_time_dimension=DefaultTimeDimensionDbt(dbt="updated_at"),
|
||||
)
|
||||
|
||||
assert src.columns[0].constraints == {
|
||||
"dbt": ColumnDbtConstraints(not_null=True, unique=True)
|
||||
}
|
||||
assert src.columns[0].enum_values == {"dbt": ["placed", "shipped"]}
|
||||
assert src.columns[0].tests is not None
|
||||
assert src.columns[0].tests.model_dump(mode="python", exclude_none=True) == {
|
||||
"dbt": [{"name": "accepted_values", "package": "dbt"}],
|
||||
"dbt_by_package": {"dbt": ["accepted_values"]},
|
||||
}
|
||||
assert src.tags == {"dbt": ["mart", "finance"]}
|
||||
assert src.freshness == {
|
||||
"dbt": FreshnessDbt(
|
||||
loaded_at_field="updated_at",
|
||||
raw={"warn_after": {"count": 12, "period": "hour"}},
|
||||
)
|
||||
}
|
||||
|
||||
dumped = src.model_dump(mode="python", exclude_none=True)
|
||||
round_tripped = SourceDefinition.model_validate(dumped)
|
||||
assert round_tripped.columns[0].constraints == src.columns[0].constraints
|
||||
assert round_tripped.columns[0].enum_values == src.columns[0].enum_values
|
||||
assert round_tripped.columns[0].tests == src.columns[0].tests
|
||||
assert round_tripped.tags == src.tags
|
||||
assert round_tripped.freshness == src.freshness
|
||||
|
||||
|
||||
class TestSemanticQuery:
|
||||
def test_minimal(self):
|
||||
q = SemanticQuery(measures=["sum(orders.amount)"])
|
||||
assert q.dimensions == []
|
||||
assert q.filters == []
|
||||
assert q.limit == 1000
|
||||
|
||||
def test_mixed_measures(self):
|
||||
q = SemanticQuery(
|
||||
measures=[
|
||||
"orders.revenue",
|
||||
{"expr": "sum(orders.amount)", "name": "total"},
|
||||
]
|
||||
)
|
||||
assert isinstance(q.measures[0], str)
|
||||
assert isinstance(q.measures[1], dict)
|
||||
|
||||
def test_with_dimensions(self):
|
||||
q = SemanticQuery(
|
||||
measures=["sum(orders.amount)"],
|
||||
dimensions=[
|
||||
"orders.status",
|
||||
{"field": "orders.created_at", "granularity": "month"},
|
||||
],
|
||||
)
|
||||
assert len(q.dimensions) == 2
|
||||
|
||||
|
||||
class TestResolvedModels:
|
||||
def test_resolved_column(self):
|
||||
col = ResolvedColumn(
|
||||
name="revenue", provenance=Provenance.VERIFIED, expr="sum(amount)"
|
||||
)
|
||||
assert col.provenance == Provenance.VERIFIED
|
||||
|
||||
def test_resolved_measure(self):
|
||||
m = ResolvedMeasure(name="revenue", expr="sum(amount)", source_name="orders")
|
||||
assert m.provenance == Provenance.COMPOSED
|
||||
assert not m.is_derived
|
||||
|
||||
def test_measure_group(self):
|
||||
m = ResolvedMeasure(name="rev", expr="sum(amount)", source_name="orders")
|
||||
g = MeasureGroup(source_name="orders", measures=[m])
|
||||
assert g.source_name == "orders"
|
||||
|
||||
def test_resolved_plan(self):
|
||||
plan = ResolvedPlan(
|
||||
sources_used=["orders"],
|
||||
join_paths=[],
|
||||
anchor_grain=["id"],
|
||||
fan_out_description="none",
|
||||
aggregate_locality=[],
|
||||
where_filters=[],
|
||||
having_filters=[],
|
||||
columns=[ResolvedColumn(name="revenue", provenance=Provenance.COMPOSED)],
|
||||
)
|
||||
assert plan.has_fan_out is False
|
||||
assert plan.measure_groups == []
|
||||
|
||||
def test_query_result(self):
|
||||
plan = ResolvedPlan(
|
||||
sources_used=["orders"],
|
||||
join_paths=[],
|
||||
anchor_grain=["id"],
|
||||
fan_out_description="none",
|
||||
aggregate_locality=[],
|
||||
where_filters=[],
|
||||
having_filters=[],
|
||||
columns=[],
|
||||
)
|
||||
result = QueryResult(
|
||||
resolved_plan=plan, sql="SELECT 1", dialect="postgres", columns=[]
|
||||
)
|
||||
assert result.dialect == "postgres"
|
||||
|
||||
|
||||
class TestJoinDeclaration:
|
||||
def test_with_alias(self):
|
||||
from semantic_layer.models import JoinDeclaration
|
||||
|
||||
j = JoinDeclaration(
|
||||
to="customers",
|
||||
on="billing_customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
alias="billing_customer",
|
||||
)
|
||||
assert j.alias == "billing_customer"
|
||||
assert j.to == "customers"
|
||||
|
||||
def test_without_alias(self):
|
||||
from semantic_layer.models import JoinDeclaration
|
||||
|
||||
j = JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
assert j.alias is None
|
||||
|
||||
|
||||
class TestMeasureDefinition:
|
||||
def test_with_filter_and_description(self):
|
||||
from semantic_layer.models import MeasureDefinition
|
||||
|
||||
m = MeasureDefinition(
|
||||
name="revenue",
|
||||
expr="sum(amount)",
|
||||
filter="status != 'refunded'",
|
||||
description="Net revenue excluding refunds",
|
||||
)
|
||||
assert m.filter == "status != 'refunded'"
|
||||
assert m.description == "Net revenue excluding refunds"
|
||||
|
||||
def test_minimal(self):
|
||||
from semantic_layer.models import MeasureDefinition
|
||||
|
||||
m = MeasureDefinition(name="total", expr="count(id)")
|
||||
assert m.filter is None
|
||||
assert m.description is None
|
||||
|
||||
|
||||
class TestSemanticQueryExtended:
|
||||
def test_include_empty_default(self):
|
||||
q = SemanticQuery(measures=["sum(orders.amount)"])
|
||||
assert q.include_empty is True
|
||||
|
||||
def test_include_empty_false(self):
|
||||
q = SemanticQuery(measures=["sum(orders.amount)"], include_empty=False)
|
||||
assert q.include_empty is False
|
||||
|
||||
def test_with_order_by(self):
|
||||
q = SemanticQuery(
|
||||
measures=["sum(orders.amount)"],
|
||||
order_by=[{"field": "orders.amount", "direction": "desc"}],
|
||||
)
|
||||
assert len(q.order_by) == 1
|
||||
assert q.order_by[0]["direction"] == "desc"
|
||||
|
||||
def test_custom_limit(self):
|
||||
q = SemanticQuery(measures=["sum(orders.amount)"], limit=50)
|
||||
assert q.limit == 50
|
||||
|
||||
|
||||
# ── From test_edge_cases.py ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestModelEdgeCases:
|
||||
def test_semantic_query_empty_measures(self):
|
||||
q = SemanticQuery(measures=[])
|
||||
assert q.measures == []
|
||||
|
||||
def test_semantic_query_defaults(self):
|
||||
q = SemanticQuery(measures=["sum(x.y)"])
|
||||
assert q.dimensions == []
|
||||
assert q.filters == []
|
||||
assert q.order_by == []
|
||||
assert q.limit == 1000
|
||||
assert q.include_empty is True
|
||||
|
||||
def test_semantic_query_with_order_by(self):
|
||||
q = SemanticQuery(
|
||||
measures=["sum(orders.amount)"],
|
||||
order_by=[{"field": "orders.status", "direction": "desc"}],
|
||||
)
|
||||
assert len(q.order_by) == 1
|
||||
|
||||
def test_table_and_sql_mutually_exclusive(self):
|
||||
with pytest.raises(ValidationError, match="mutually exclusive"):
|
||||
SourceDefinition(
|
||||
name="bad",
|
||||
table="t",
|
||||
sql="SELECT 1",
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
|
||||
def test_empty_grain_rejected(self):
|
||||
with pytest.raises(ValidationError, match="grain must be non-empty"):
|
||||
SourceDefinition(
|
||||
name="bad",
|
||||
table="t",
|
||||
grain=[],
|
||||
columns=[SourceColumn(name="id", type="number")],
|
||||
)
|
||||
|
||||
def test_measure_definition_with_filter(self):
|
||||
from semantic_layer.models import MeasureDefinition
|
||||
|
||||
m = MeasureDefinition(
|
||||
name="rev", expr="sum(amount)", filter="status != 'refunded'"
|
||||
)
|
||||
assert m.filter == "status != 'refunded'"
|
||||
279
python/klo-sl/tests/test_parser.py
Normal file
279
python/klo-sl/tests/test_parser.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
from semantic_layer.parser import ExpressionParser
|
||||
|
||||
|
||||
parser = ExpressionParser()
|
||||
|
||||
|
||||
class TestAggregateDetection:
|
||||
def test_sum(self):
|
||||
r = parser.parse("sum(orders.amount)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "sum"
|
||||
|
||||
def test_avg(self):
|
||||
r = parser.parse("avg(score)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "avg"
|
||||
|
||||
def test_count(self):
|
||||
r = parser.parse("count(orders.id)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "count"
|
||||
|
||||
def test_count_distinct(self):
|
||||
r = parser.parse("count_distinct(orders.customer_id)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "count_distinct"
|
||||
|
||||
def test_non_aggregate(self):
|
||||
r = parser.parse("orders.revenue")
|
||||
assert not r.is_aggregate
|
||||
assert r.aggregate_function is None
|
||||
|
||||
def test_multiple_aggregates(self):
|
||||
r = parser.parse("sum(orders.amount) / count(orders.id)")
|
||||
assert r.is_aggregate
|
||||
# first aggregate found
|
||||
assert r.aggregate_function == "sum"
|
||||
|
||||
def test_aggregate_in_scalar_subquery_not_aggregate(self):
|
||||
# `col = (SELECT MAX(col) FROM t)` is a plain column predicate, not HAVING-bound
|
||||
r = parser.parse("orders.created_at = (SELECT MAX(created_at) FROM orders)")
|
||||
assert not r.is_aggregate
|
||||
assert r.aggregate_function is None
|
||||
|
||||
def test_aggregate_in_in_subquery_not_aggregate(self):
|
||||
r = parser.parse("orders.id IN (SELECT COUNT(id) FROM orders)")
|
||||
assert not r.is_aggregate
|
||||
|
||||
def test_custom_agg_in_subquery_not_aggregate(self):
|
||||
r = parser.parse(
|
||||
"orders.customer_id = (SELECT count_distinct(customer_id) FROM orders)"
|
||||
)
|
||||
assert not r.is_aggregate
|
||||
|
||||
def test_outer_aggregate_with_inner_subquery_still_aggregate(self):
|
||||
# Outer SUM on a plain column, even if subquery appears elsewhere
|
||||
r = parser.parse("sum(orders.amount) > (SELECT AVG(amount) FROM orders)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "sum"
|
||||
|
||||
|
||||
class TestSourceRefs:
|
||||
def test_single_ref(self):
|
||||
r = parser.parse("sum(orders.amount)")
|
||||
assert r.source_refs == {"orders"}
|
||||
assert r.column_refs == {"orders.amount"}
|
||||
|
||||
def test_multiple_refs(self):
|
||||
r = parser.parse("sum(orders.revenue) / count(customers.id)")
|
||||
assert r.source_refs == {"orders", "customers"}
|
||||
assert r.column_refs == {"orders.revenue", "customers.id"}
|
||||
|
||||
def test_pre_defined_ref(self):
|
||||
r = parser.parse("orders.revenue")
|
||||
assert r.source_refs == {"orders"}
|
||||
assert r.column_refs == {"orders.revenue"}
|
||||
|
||||
def test_no_refs(self):
|
||||
r = parser.parse("total_rev - total_cost")
|
||||
assert r.source_refs == set()
|
||||
assert r.column_refs == set()
|
||||
|
||||
def test_mixed_refs(self):
|
||||
r = parser.parse("sum(orders.amount) + churn_risk.score")
|
||||
assert r.source_refs == {"orders", "churn_risk"}
|
||||
|
||||
|
||||
class TestDerivedMeasures:
|
||||
def test_depends_on_known_measures(self):
|
||||
r = parser.parse(
|
||||
"total_rev - total_cost",
|
||||
known_measure_names={"total_rev", "total_cost"},
|
||||
)
|
||||
assert r.depends_on_measures == {"total_rev", "total_cost"}
|
||||
assert not r.is_aggregate
|
||||
|
||||
def test_no_false_positives(self):
|
||||
# "sum" should not be detected as a measure dependency
|
||||
r = parser.parse(
|
||||
"sum(orders.amount)",
|
||||
known_measure_names={"sum"},
|
||||
)
|
||||
assert r.depends_on_measures == set()
|
||||
|
||||
def test_mixed_ref_and_derived(self):
|
||||
r = parser.parse(
|
||||
"total_rev / count(orders.id)",
|
||||
known_measure_names={"total_rev"},
|
||||
)
|
||||
assert r.depends_on_measures == {"total_rev"}
|
||||
assert r.is_aggregate
|
||||
|
||||
def test_empty_known_measures(self):
|
||||
r = parser.parse("total_rev - total_cost")
|
||||
assert r.depends_on_measures == set()
|
||||
|
||||
|
||||
class TestExtractSourceRefs:
|
||||
def test_basic(self):
|
||||
refs = parser.extract_source_refs("sum(orders.amount)")
|
||||
assert refs == {"orders"}
|
||||
|
||||
def test_multiple(self):
|
||||
refs = parser.extract_source_refs("orders.amount + customers.score")
|
||||
assert refs == {"orders", "customers"}
|
||||
|
||||
def test_no_refs(self):
|
||||
refs = parser.extract_source_refs("count(*)")
|
||||
assert refs == set()
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_percentile(self):
|
||||
r = parser.parse("percentile(churn_risk.score, 0.9)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "percentile"
|
||||
assert r.source_refs == {"churn_risk"}
|
||||
|
||||
def test_string_literal_not_detected(self):
|
||||
# "status != 'refunded'" — 'refunded' should not be a source ref
|
||||
r = parser.parse("status != 'refunded'")
|
||||
assert r.source_refs == set()
|
||||
|
||||
def test_complex_expression(self):
|
||||
r = parser.parse("sum(orders.amount) / count(orders.id) * 100")
|
||||
assert r.is_aggregate
|
||||
assert r.source_refs == {"orders"}
|
||||
assert r.column_refs == {"orders.amount", "orders.id"}
|
||||
|
||||
|
||||
class TestAdditionalAggregates:
|
||||
def test_min(self):
|
||||
r = parser.parse("min(orders.amount)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "min"
|
||||
|
||||
def test_max(self):
|
||||
r = parser.parse("max(orders.amount)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "max"
|
||||
|
||||
def test_median(self):
|
||||
r = parser.parse("median(orders.amount)")
|
||||
assert r.is_aggregate
|
||||
assert r.aggregate_function == "median"
|
||||
|
||||
def test_nested_function_not_aggregate(self):
|
||||
"""abs() is not an aggregate function, but sum() wrapping it is."""
|
||||
r = parser.parse("sum(orders.amount)")
|
||||
assert r.is_aggregate
|
||||
assert r.source_refs == {"orders"}
|
||||
|
||||
def test_comparison_operators(self):
|
||||
"""Filter-like expression with comparison."""
|
||||
r = parser.parse("orders.status = 'completed'")
|
||||
assert not r.is_aggregate
|
||||
assert r.source_refs == {"orders"}
|
||||
|
||||
def test_multiple_source_column_refs(self):
|
||||
"""Expression referencing columns from 3 different sources."""
|
||||
r = parser.parse(
|
||||
"sum(orders.amount) + count(customers.id) - avg(tickets.score)"
|
||||
)
|
||||
assert r.is_aggregate
|
||||
assert r.source_refs == {"orders", "customers", "tickets"}
|
||||
|
||||
|
||||
# ── From test_edge_cases.py: TestExpressionParserEdgeCases ───────────
|
||||
|
||||
|
||||
class TestExpressionParserEdgeCases:
|
||||
def test_empty_string(self):
|
||||
result = parser.parse("")
|
||||
assert result.source_refs == set()
|
||||
assert result.column_refs == set()
|
||||
assert not result.is_aggregate
|
||||
|
||||
def test_count_star(self):
|
||||
result = parser.parse("count(*)")
|
||||
assert result.is_aggregate
|
||||
assert result.aggregate_function == "count"
|
||||
assert result.source_refs == set()
|
||||
|
||||
def test_multiple_aggregate_functions(self):
|
||||
result = parser.parse("sum(orders.amount) + avg(orders.cost)")
|
||||
assert result.is_aggregate
|
||||
assert result.aggregate_function == "sum"
|
||||
assert result.source_refs == {"orders"}
|
||||
|
||||
def test_nested_function_not_aggregate(self):
|
||||
result = parser.parse("lower(orders.status)")
|
||||
assert not result.is_aggregate
|
||||
assert result.source_refs == {"orders"}
|
||||
|
||||
def test_source_ref_in_string_literal(self):
|
||||
result = parser.parse("'orders.amount'")
|
||||
assert "orders" not in result.source_refs
|
||||
assert len(result.column_refs) == 0
|
||||
|
||||
def test_underscore_names(self):
|
||||
result = parser.parse("sum(order_items.unit_price)")
|
||||
assert "order_items" in result.source_refs
|
||||
assert "order_items.unit_price" in result.column_refs
|
||||
|
||||
def test_extract_source_refs_multi(self):
|
||||
refs = parser.extract_source_refs("orders.amount + customers.score")
|
||||
assert refs == {"orders", "customers"}
|
||||
|
||||
|
||||
class TestReservedWordHandling:
|
||||
"""LIMIT 4: Reserved SQL keywords as source or column names."""
|
||||
|
||||
def test_reserved_word_source_name(self):
|
||||
"""Parse 'sum(where.value)' where 'where' is a source name."""
|
||||
r = parser.parse("sum(where.value)")
|
||||
assert r.source_refs == {"where"}
|
||||
assert r.column_refs == {"where.value"}
|
||||
assert r.is_aggregate
|
||||
|
||||
def test_reserved_word_column_name(self):
|
||||
"""Parse 'select.from' where both are reserved words."""
|
||||
r = parser.parse("select.from")
|
||||
assert r.source_refs == {"select"}
|
||||
assert r.column_refs == {"select.from"}
|
||||
|
||||
def test_reserved_word_in_extract_source_refs(self):
|
||||
"""extract_source_refs should handle reserved words in expressions."""
|
||||
refs = parser.extract_source_refs("where.value > 10")
|
||||
assert refs == {"where"}
|
||||
|
||||
|
||||
def test_extract_source_refs_bigquery_native():
|
||||
"""BigQuery-native filter must not drop source refs due to mis-parse."""
|
||||
from semantic_layer.parser import ExpressionParser
|
||||
|
||||
parser = ExpressionParser(dialect="bigquery")
|
||||
refs = parser.extract_source_refs(
|
||||
"SAFE_DIVIDE(orders.revenue, customers.count) > 0"
|
||||
)
|
||||
assert refs == {"orders", "customers"}
|
||||
|
||||
|
||||
def test_expression_parser_dialect_defaults_to_postgres():
|
||||
"""Constructor default is postgres — keeps existing tests working."""
|
||||
from semantic_layer.parser import ExpressionParser
|
||||
|
||||
parser = ExpressionParser()
|
||||
assert parser.dialect == "postgres"
|
||||
|
||||
|
||||
def test_extract_source_refs_postgres_baseline():
|
||||
"""Postgres-dialect parser continues to work on postgres syntax."""
|
||||
from semantic_layer.parser import ExpressionParser
|
||||
|
||||
parser = ExpressionParser(dialect="postgres")
|
||||
refs = parser.extract_source_refs(
|
||||
"orders.created_at >= current_date - interval '30 days'"
|
||||
)
|
||||
assert refs == {"orders"}
|
||||
1509
python/klo-sl/tests/test_planner.py
Normal file
1509
python/klo-sl/tests/test_planner.py
Normal file
File diff suppressed because it is too large
Load diff
293
python/klo-sl/tests/test_segments.py
Normal file
293
python/klo-sl/tests/test_segments.py
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
"""Tests for named segments — reusable boolean predicates on a source.
|
||||
|
||||
Segments are AND-ed into the measure's effective filter via the same CASE WHEN
|
||||
pathway used by `measure.filter`. They never become a global WHERE clause.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import assert_valid_sql, make_engine
|
||||
|
||||
|
||||
def _orders_source(**overrides):
|
||||
base = {
|
||||
"name": "orders",
|
||||
"table": "public.orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "amount", "type": "number"},
|
||||
{"name": "is_paid", "type": "boolean"},
|
||||
{"name": "is_refunded", "type": "string"},
|
||||
{"name": "customer_id", "type": "number"},
|
||||
],
|
||||
"segments": [
|
||||
{
|
||||
"name": "paid_non_refunded",
|
||||
"expr": "is_paid = true and is_refunded = '0'",
|
||||
"description": "Settled, not reversed.",
|
||||
},
|
||||
],
|
||||
"measures": [
|
||||
{
|
||||
"name": "total_revenue",
|
||||
"expr": "sum(amount)",
|
||||
"segments": ["paid_non_refunded"],
|
||||
},
|
||||
],
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def _customers_source(**overrides):
|
||||
base = {
|
||||
"name": "customers",
|
||||
"table": "public.customers",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "is_vip", "type": "boolean"},
|
||||
],
|
||||
"measures": [
|
||||
{"name": "customer_count", "expr": "count(distinct id)"},
|
||||
],
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
# ── Composition + golden SQL shape ───────────────────────────────────
|
||||
|
||||
|
||||
class TestSegmentComposition:
|
||||
def test_measure_segment_lands_in_case_when_wrap(self):
|
||||
engine = make_engine({"orders": _orders_source()})
|
||||
result = engine.query({"measures": ["orders.total_revenue"]})
|
||||
assert_valid_sql(result.sql)
|
||||
sql_upper = result.sql.upper()
|
||||
# Filter must be inside CASE WHEN (the measure-filter pathway)
|
||||
assert "CASE WHEN" in sql_upper
|
||||
assert "is_paid" in result.sql.lower()
|
||||
assert "is_refunded" in result.sql.lower()
|
||||
# Should NOT show up as a global WHERE
|
||||
# (a WHERE clause may exist for other reasons — assert no segment expr in it)
|
||||
# Easiest: assert WHERE doesn't contain the segment's exact predicate.
|
||||
# Split before/after first WHERE keyword if any.
|
||||
assert "WHERE IS_PAID" not in sql_upper.replace(" = ", " = ")
|
||||
|
||||
def test_measure_filter_and_segment_both_applied(self):
|
||||
src = _orders_source()
|
||||
src["measures"][0]["filter"] = "amount > 0"
|
||||
engine = make_engine({"orders": src})
|
||||
result = engine.query({"measures": ["orders.total_revenue"]})
|
||||
assert_valid_sql(result.sql)
|
||||
sql_lower = result.sql.lower()
|
||||
# Both predicates appear inside the measure's CASE WHEN wrap
|
||||
assert "amount > 0" in sql_lower
|
||||
assert "is_paid" in sql_lower
|
||||
assert "is_refunded" in sql_lower
|
||||
# AND composition: ensure both halves are joined
|
||||
assert " and " in sql_lower
|
||||
|
||||
def test_query_time_segment_applies_to_measure(self):
|
||||
# Measure has no measure-bound segment; segment is applied at query time.
|
||||
src = _orders_source()
|
||||
src["measures"] = [{"name": "raw_revenue", "expr": "sum(amount)"}]
|
||||
engine = make_engine({"orders": src})
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["orders.raw_revenue"],
|
||||
"segments": ["orders.paid_non_refunded"],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
sql_lower = result.sql.lower()
|
||||
assert "case when" in sql_lower
|
||||
assert "is_paid" in sql_lower
|
||||
assert "is_refunded" in sql_lower
|
||||
|
||||
def test_measure_and_query_segments_compose(self):
|
||||
# Measure has paid_non_refunded; query adds 'high_value'.
|
||||
src = _orders_source()
|
||||
src["segments"].append(
|
||||
{"name": "high_value", "expr": "amount >= 100"},
|
||||
)
|
||||
engine = make_engine({"orders": src})
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["orders.total_revenue"],
|
||||
"segments": ["orders.high_value"],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
sql_lower = result.sql.lower()
|
||||
# All three predicates present
|
||||
assert "is_paid" in sql_lower
|
||||
assert "is_refunded" in sql_lower
|
||||
assert "amount >= 100" in sql_lower
|
||||
|
||||
|
||||
# ── Multi-source query: scope is per-measure, not global ─────────────
|
||||
|
||||
|
||||
class TestSegmentMultiSourceScope:
|
||||
def test_segment_does_not_apply_to_other_source_measures(self):
|
||||
# Query touches both orders and customers; segment is on orders only.
|
||||
# Assert that the segment predicate does NOT show up in the
|
||||
# customers CTE / WHERE on customers.
|
||||
engine = make_engine(
|
||||
{
|
||||
"orders": _orders_source(
|
||||
joins=[
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
measures=[
|
||||
{"name": "raw_revenue", "expr": "sum(amount)"},
|
||||
],
|
||||
),
|
||||
"customers": _customers_source(),
|
||||
}
|
||||
)
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": [
|
||||
"orders.raw_revenue",
|
||||
"customers.customer_count",
|
||||
],
|
||||
"segments": ["orders.paid_non_refunded"],
|
||||
}
|
||||
)
|
||||
assert_valid_sql(result.sql)
|
||||
sql_lower = result.sql.lower()
|
||||
# Segment predicate appears (it landed on orders)
|
||||
assert "is_paid" in sql_lower
|
||||
# The customers measure's pre-aggregation CTE / clause must not be filtered by the segment.
|
||||
# Heuristic: find each line that references count(distinct ... id) and assert no
|
||||
# "is_paid" or "is_refunded" in the same CASE WHEN block. The simpler assertion
|
||||
# is that there's no global WHERE applying the segment.
|
||||
# We assert the segment doesn't appear inside an aggregate against the customers source.
|
||||
# Concretely: count(...customers...) should not contain is_paid/is_refunded.
|
||||
# Walk the SQL and find COUNT(DISTINCT ... ID) — that aggregate must be unfiltered.
|
||||
import re
|
||||
|
||||
count_aggs = re.findall(
|
||||
r"COUNT\s*\(\s*DISTINCT[^()]*\)", result.sql, flags=re.IGNORECASE
|
||||
)
|
||||
assert count_aggs, "expected at least one COUNT(DISTINCT ...) aggregate"
|
||||
for agg in count_aggs:
|
||||
assert "is_paid" not in agg.lower(), (
|
||||
f"customer_count aggregate must not be filtered by segment: {agg}"
|
||||
)
|
||||
|
||||
|
||||
# ── Error cases ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSegmentErrors:
|
||||
def test_unknown_bare_name_in_measure_segments(self):
|
||||
src = _orders_source()
|
||||
src["measures"][0]["segments"] = ["does_not_exist"]
|
||||
engine = make_engine({"orders": src})
|
||||
with pytest.raises(ValueError, match="unknown segment 'does_not_exist'"):
|
||||
engine.query({"measures": ["orders.total_revenue"]})
|
||||
|
||||
def test_unknown_query_time_segment_name(self):
|
||||
engine = make_engine({"orders": _orders_source()})
|
||||
with pytest.raises(ValueError, match="unknown segment 'does_not_exist'"):
|
||||
engine.query(
|
||||
{
|
||||
"measures": ["orders.total_revenue"],
|
||||
"segments": ["orders.does_not_exist"],
|
||||
}
|
||||
)
|
||||
|
||||
def test_unknown_query_time_segment_source(self):
|
||||
engine = make_engine({"orders": _orders_source()})
|
||||
with pytest.raises(ValueError, match="unknown source 'no_such_source'"):
|
||||
engine.query(
|
||||
{
|
||||
"measures": ["orders.total_revenue"],
|
||||
"segments": ["no_such_source.foo"],
|
||||
}
|
||||
)
|
||||
|
||||
def test_query_time_segment_must_be_dotted(self):
|
||||
engine = make_engine({"orders": _orders_source()})
|
||||
with pytest.raises(ValueError, match="dotted"):
|
||||
engine.query(
|
||||
{
|
||||
"measures": ["orders.total_revenue"],
|
||||
"segments": ["paid_non_refunded"], # missing source prefix
|
||||
}
|
||||
)
|
||||
|
||||
def test_no_op_query_time_segment_errors(self):
|
||||
# Segment on customers, but no customers measure in the query.
|
||||
engine = make_engine(
|
||||
{
|
||||
"orders": _orders_source(
|
||||
joins=[
|
||||
{
|
||||
"to": "customers",
|
||||
"on": "customer_id = customers.id",
|
||||
"relationship": "many_to_one",
|
||||
}
|
||||
],
|
||||
measures=[{"name": "raw_revenue", "expr": "sum(amount)"}],
|
||||
),
|
||||
"customers": _customers_source(
|
||||
segments=[{"name": "vips", "expr": "is_vip = true"}]
|
||||
),
|
||||
}
|
||||
)
|
||||
with pytest.raises(ValueError, match="no matching"):
|
||||
engine.query(
|
||||
{
|
||||
"measures": ["orders.raw_revenue"],
|
||||
"segments": ["customers.vips"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_bigquery_native_segment_referenced_by_measure(make_engine_factory):
|
||||
"""Segment authored in BigQuery dialect, referenced by a measure,
|
||||
must not degrade the segment's native syntax when composed."""
|
||||
source = {
|
||||
"name": "fct_orders",
|
||||
"table": "fct_orders",
|
||||
"grain": ["id"],
|
||||
"columns": [
|
||||
{"name": "id", "type": "number"},
|
||||
{"name": "status", "type": "string"},
|
||||
{"name": "ts", "type": "time"},
|
||||
],
|
||||
"segments": [
|
||||
{"name": "non_cancelled", "expr": "status != 'cancelled'"},
|
||||
{
|
||||
"name": "last_30",
|
||||
"expr": "ts >= timestamp(date_sub(current_date(), interval 30 day))",
|
||||
},
|
||||
],
|
||||
"measures": [
|
||||
{
|
||||
"name": "dau",
|
||||
"expr": "count(distinct id)",
|
||||
"segments": ["non_cancelled", "last_30"],
|
||||
}
|
||||
],
|
||||
}
|
||||
engine = make_engine_factory({"fct_orders": source}, dialect="bigquery")
|
||||
result = engine.query(
|
||||
{"measures": ["fct_orders.dau"], "dimensions": [], "filters": []}
|
||||
)
|
||||
sql = result.sql
|
||||
assert "INTERVAL '30'" not in sql or "DAY" in sql.upper(), (
|
||||
f"INTERVAL unit lost in segment reference:\n{sql}"
|
||||
)
|
||||
470
python/klo-sl/tests/test_snowflake.py
Normal file
470
python/klo-sl/tests/test_snowflake.py
Normal file
|
|
@ -0,0 +1,470 @@
|
|||
"""Comprehensive Snowflake dialect tests covering all major SQL generation code paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import sqlglot
|
||||
|
||||
from semantic_layer.engine import SemanticEngine
|
||||
from semantic_layer.models import SourceColumn, SourceDefinition
|
||||
|
||||
SOURCES_DIR = str(Path(__file__).parent.parent / "sources" / "ecommerce")
|
||||
|
||||
|
||||
def assert_valid_snowflake_sql(sql: str):
|
||||
"""Assert SQL parses as valid Snowflake SQL."""
|
||||
try:
|
||||
result = sqlglot.parse(sql, read="snowflake")
|
||||
assert result and all(r is not None for r in result)
|
||||
except Exception as e:
|
||||
pytest.fail(f"SQL is not valid Snowflake: {e}\n\nSQL:\n{sql}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sf_engine():
|
||||
return SemanticEngine(SOURCES_DIR, dialect="snowflake")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chasm_engine():
|
||||
"""Engine with hub + two fact tables for chasm trap / aggregate locality tests."""
|
||||
sources = {
|
||||
"hub": SourceDefinition(
|
||||
name="hub",
|
||||
table="public.hub",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="segment", type="string"),
|
||||
],
|
||||
),
|
||||
"fact_a": SourceDefinition(
|
||||
name="fact_a",
|
||||
table="public.fact_a",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="hub_id", type="number"),
|
||||
SourceColumn(name="val", type="number"),
|
||||
SourceColumn(name="created_at", type="time"),
|
||||
],
|
||||
joins=[
|
||||
{"to": "hub", "on": "hub_id = hub.id", "relationship": "many_to_one"}
|
||||
],
|
||||
),
|
||||
"fact_b": SourceDefinition(
|
||||
name="fact_b",
|
||||
table="public.fact_b",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="hub_id", type="number"),
|
||||
SourceColumn(name="val", type="number"),
|
||||
],
|
||||
joins=[
|
||||
{"to": "hub", "on": "hub_id = hub.id", "relationship": "many_to_one"}
|
||||
],
|
||||
measures=[{"name": "total_val", "expr": "sum(val)", "filter": "val > 0"}],
|
||||
),
|
||||
}
|
||||
return SemanticEngine.from_sources(sources, dialect="snowflake")
|
||||
|
||||
|
||||
# ── Basic query patterns ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeBasic:
|
||||
def test_simple_single_source(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert result.dialect == "snowflake"
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "GROUP BY" in sql
|
||||
|
||||
def test_cross_source_m2o(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["customers.segment", "regions.name"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "JOIN" in sql
|
||||
|
||||
def test_predefined_measure_with_filter(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["orders.revenue"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "CASE WHEN" in sql
|
||||
assert "<>" in sql # sqlglot transpiles != to <>
|
||||
assert "'refunded'" in sql
|
||||
|
||||
def test_derived_measures(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": [
|
||||
{"expr": "sum(orders.amount)", "name": "total_rev"},
|
||||
{"expr": "sum(orders.cost)", "name": "total_cost"},
|
||||
{"expr": "total_rev - total_cost", "name": "profit"},
|
||||
],
|
||||
"dimensions": ["customers.segment"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "profit" in sql.lower()
|
||||
assert "total_rev" in sql
|
||||
assert "total_cost" in sql
|
||||
|
||||
def test_include_empty_false(self, sf_engine):
|
||||
result_left = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["customers.segment"],
|
||||
"include_empty": True,
|
||||
}
|
||||
)
|
||||
result_inner = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["customers.segment"],
|
||||
"include_empty": False,
|
||||
}
|
||||
)
|
||||
assert_valid_snowflake_sql(result_left.sql)
|
||||
assert_valid_snowflake_sql(result_inner.sql)
|
||||
assert "LEFT JOIN" in result_left.sql.upper()
|
||||
assert "LEFT JOIN" not in result_inner.sql.upper()
|
||||
|
||||
|
||||
# ── Time granularity ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeTimeGranularity:
|
||||
@pytest.mark.parametrize("granularity", ["day", "week", "month", "quarter", "year"])
|
||||
def test_date_trunc_uppercase(self, sf_engine, granularity):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": [
|
||||
{"field": "orders.created_at", "granularity": granularity}
|
||||
],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
# Snowflake DATE_TRUNC uses uppercase granularity
|
||||
assert f"DATE_TRUNC('{granularity.upper()}'" in sql
|
||||
|
||||
|
||||
# ── Filters ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeFilters:
|
||||
def test_having_filter(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"filters": ["sum(orders.amount) > 10000"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "HAVING" in sql
|
||||
assert "10000" in sql
|
||||
|
||||
def test_where_and_having(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"filters": [
|
||||
"orders.status != 'cancelled'",
|
||||
"sum(orders.amount) > 1000",
|
||||
],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "WHERE" in sql
|
||||
assert "HAVING" in sql
|
||||
|
||||
|
||||
# ── SQL sources / CTEs ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeCTE:
|
||||
def test_sql_source_as_cte(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["avg(churn_risk.score)"],
|
||||
"dimensions": ["churn_risk.customer_type"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "WITH" in sql
|
||||
assert "churn_risk" in sql
|
||||
|
||||
def test_cross_source_with_sql_source(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["avg(churn_risk.score)"],
|
||||
"dimensions": ["regions.name"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "WITH" in sql
|
||||
assert "JOIN" in sql
|
||||
|
||||
def test_sql_source_with_datediff(self):
|
||||
"""DATEDIFF in SQL source must survive transpilation (not become AGE)."""
|
||||
sources = {
|
||||
"cohorts": SourceDefinition(
|
||||
name="cohorts",
|
||||
sql="SELECT id, DATEDIFF(WEEK, start_date, end_date) AS n FROM t",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="n", type="number"),
|
||||
],
|
||||
),
|
||||
}
|
||||
engine = SemanticEngine.from_sources(sources, dialect="snowflake")
|
||||
result = engine.query({"measures": ["sum(cohorts.n)"], "dimensions": []})
|
||||
assert_valid_snowflake_sql(result.sql)
|
||||
assert "DATEDIFF" in result.sql.upper()
|
||||
assert "AGE" not in result.sql.upper()
|
||||
|
||||
def test_sql_source_with_datediff_in_ctes(self):
|
||||
"""DATEDIFF inside inner CTEs must survive CTE promotion."""
|
||||
sources = {
|
||||
"retention": SourceDefinition(
|
||||
name="retention",
|
||||
sql=(
|
||||
"WITH spine AS ("
|
||||
" SELECT DISTINCT cohort_week,"
|
||||
" DATEDIFF(WEEK, cohort_week, period_week) AS n"
|
||||
" FROM adopters"
|
||||
") SELECT cohort_week, n, COUNT(*) AS cnt FROM spine GROUP BY 1, 2"
|
||||
),
|
||||
grain=["cohort_week", "n"],
|
||||
columns=[
|
||||
SourceColumn(name="cohort_week", type="time"),
|
||||
SourceColumn(name="n", type="number"),
|
||||
SourceColumn(name="cnt", type="number"),
|
||||
],
|
||||
),
|
||||
}
|
||||
engine = SemanticEngine.from_sources(sources, dialect="snowflake")
|
||||
result = engine.query(
|
||||
{"measures": ["sum(retention.cnt)"], "dimensions": ["retention.n"]}
|
||||
)
|
||||
assert_valid_snowflake_sql(result.sql)
|
||||
assert "DATEDIFF" in result.sql.upper()
|
||||
# Inner CTE should be promoted with prefix
|
||||
assert "retention__spine" in result.sql
|
||||
|
||||
|
||||
# ── Aggregate functions ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeAggregateFunctions:
|
||||
def test_median_percentile_cont(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": [{"expr": "median(orders.amount)", "name": "median_order"}],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "PERCENTILE_CONT" in sql
|
||||
assert "WITHIN GROUP" in sql
|
||||
|
||||
def test_percentile(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": [{"expr": "percentile(orders.amount, 0.9)", "name": "p90"}],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "PERCENTILE_CONT" in sql
|
||||
assert "0.9" in sql
|
||||
|
||||
def test_count_distinct(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["count_distinct(orders.customer_id)"],
|
||||
"dimensions": ["orders.status"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "COUNT(DISTINCT" in sql
|
||||
|
||||
|
||||
# ── Aggregate locality / chasm traps ─────────────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeAggregateLocality:
|
||||
def test_chasm_trap_full_join(self, chasm_engine):
|
||||
result = chasm_engine.query(
|
||||
{
|
||||
"measures": ["sum(fact_a.val)", "sum(fact_b.val)"],
|
||||
"dimensions": ["hub.segment"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "FULL JOIN" in sql.upper()
|
||||
assert "COALESCE" in sql.upper()
|
||||
assert "fact_a_agg" in sql
|
||||
assert "fact_b_agg" in sql
|
||||
|
||||
def test_chasm_trap_predefined_filtered_measure(self, chasm_engine):
|
||||
result = chasm_engine.query(
|
||||
{
|
||||
"measures": ["sum(fact_a.val)", "fact_b.total_val"],
|
||||
"dimensions": ["hub.segment"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "CASE WHEN" in sql
|
||||
assert "total_val" in sql
|
||||
|
||||
def test_chasm_trap_derived_measure(self, chasm_engine):
|
||||
result = chasm_engine.query(
|
||||
{
|
||||
"measures": [
|
||||
{"expr": "sum(fact_a.val)", "name": "total_a"},
|
||||
{"expr": "sum(fact_b.val)", "name": "total_b"},
|
||||
{"expr": "total_a + total_b", "name": "grand_total"},
|
||||
],
|
||||
"dimensions": ["hub.segment"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "grand_total" in sql
|
||||
assert "COALESCE" in sql.upper()
|
||||
|
||||
def test_chasm_trap_derived_ratio_nullif(self, chasm_engine):
|
||||
result = chasm_engine.query(
|
||||
{
|
||||
"measures": [
|
||||
{"expr": "sum(fact_a.val)", "name": "total_a"},
|
||||
{"expr": "sum(fact_b.val)", "name": "total_b"},
|
||||
{"expr": "total_a / total_b", "name": "ratio"},
|
||||
],
|
||||
"dimensions": ["hub.segment"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "NULLIF" in sql.upper()
|
||||
assert "ratio" in sql
|
||||
|
||||
def test_chasm_trap_having(self, chasm_engine):
|
||||
result = chasm_engine.query(
|
||||
{
|
||||
"measures": ["sum(fact_a.val)", "sum(fact_b.val)"],
|
||||
"dimensions": ["hub.segment"],
|
||||
"filters": ["sum(fact_a.val) > 100"],
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "100" in sql
|
||||
# HAVING in locality mode becomes WHERE on outer query
|
||||
assert "WHERE" in sql
|
||||
|
||||
|
||||
# ── ORDER BY + LIMIT ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeOrderByLimit:
|
||||
def test_order_by_desc_with_limit(self, sf_engine):
|
||||
result = sf_engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["orders.status"],
|
||||
"order_by": [{"field": "sum(orders.amount)", "direction": "desc"}],
|
||||
"limit": 10,
|
||||
}
|
||||
)
|
||||
sql = result.sql
|
||||
assert_valid_snowflake_sql(sql)
|
||||
assert "DESC" in sql.upper()
|
||||
assert "LIMIT 10" in sql
|
||||
|
||||
|
||||
# ── Snowflake reserved words as identifiers ──────────────────────────
|
||||
|
||||
|
||||
class TestSnowflakeReservedWords:
|
||||
"""Snowflake-specific reserved words (sample, qualify) must be quoted."""
|
||||
|
||||
@pytest.mark.parametrize("source_name", ["sample", "qualify"])
|
||||
def test_snowflake_reserved_word_as_source_name(self, source_name):
|
||||
sources = {
|
||||
source_name: SourceDefinition(
|
||||
name=source_name,
|
||||
table=f"public.{source_name}s",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="val", type="number"),
|
||||
],
|
||||
),
|
||||
}
|
||||
engine = SemanticEngine.from_sources(sources, dialect="snowflake")
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": [f"sum({source_name}.val)"],
|
||||
"dimensions": [],
|
||||
}
|
||||
)
|
||||
assert_valid_snowflake_sql(result.sql)
|
||||
assert "SUM" in result.sql.upper()
|
||||
|
||||
@pytest.mark.parametrize("col_name", ["sample", "qualify"])
|
||||
def test_snowflake_reserved_word_as_column_name(self, col_name):
|
||||
sources = {
|
||||
"orders": SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name=col_name, type="number"),
|
||||
],
|
||||
),
|
||||
}
|
||||
engine = SemanticEngine.from_sources(sources, dialect="snowflake")
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": [f"sum(orders.{col_name})"],
|
||||
"dimensions": [],
|
||||
}
|
||||
)
|
||||
assert_valid_snowflake_sql(result.sql)
|
||||
assert "SUM" in result.sql.upper()
|
||||
296
python/klo-sl/tests/test_sql_join_coverage.py
Normal file
296
python/klo-sl/tests/test_sql_join_coverage.py
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from semantic_layer.engine import SemanticEngine
|
||||
from semantic_layer.models import (
|
||||
JoinDeclaration,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
from semantic_layer.sql_table_extractor import (
|
||||
extract_table_refs,
|
||||
normalize_table,
|
||||
ref_matches_source_table,
|
||||
)
|
||||
|
||||
|
||||
def _table_src(
|
||||
name: str, table: str, columns: list[str] | None = None
|
||||
) -> SourceDefinition:
|
||||
cols = columns or ["id"]
|
||||
return SourceDefinition(
|
||||
name=name,
|
||||
table=table,
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name=c, type="number") for c in cols],
|
||||
)
|
||||
|
||||
|
||||
def _sql_src(
|
||||
name: str,
|
||||
sql: str,
|
||||
columns: list[str] | None = None,
|
||||
joins: list[JoinDeclaration] | None = None,
|
||||
) -> SourceDefinition:
|
||||
cols = columns or ["id"]
|
||||
return SourceDefinition(
|
||||
name=name,
|
||||
sql=sql,
|
||||
grain=["id"],
|
||||
columns=[SourceColumn(name=c, type="number") for c in cols],
|
||||
joins=joins or [],
|
||||
)
|
||||
|
||||
|
||||
class TestExtractTableRefs:
|
||||
def test_simple_select(self):
|
||||
refs = extract_table_refs("select id from analytics.marts.listings")
|
||||
assert refs == [("analytics", "marts", "listings")]
|
||||
|
||||
def test_join_clause(self):
|
||||
sql = """
|
||||
select l.id from analytics.marts.listings l
|
||||
join analytics.marts.accounts a on l.account_id = a.id
|
||||
"""
|
||||
assert extract_table_refs(sql) == [
|
||||
("analytics", "marts", "listings"),
|
||||
("analytics", "marts", "accounts"),
|
||||
]
|
||||
|
||||
def test_cte_alias_skipped(self):
|
||||
sql = """
|
||||
with d as (select id from staging.shipments)
|
||||
select * from d join staging.items_shipments i on d.id = i.shipment_id
|
||||
"""
|
||||
# `d` is a CTE — must not appear. `staging.shipments` and
|
||||
# `staging.items_shipments` both should.
|
||||
refs = extract_table_refs(sql)
|
||||
assert ("staging", "shipments") in refs
|
||||
assert ("staging", "items_shipments") in refs
|
||||
assert all(ref != ("d",) for ref in refs)
|
||||
|
||||
def test_dedup(self):
|
||||
sql = """
|
||||
select * from analytics.marts.listings l1
|
||||
join analytics.marts.listings l2 on l1.id = l2.id
|
||||
"""
|
||||
assert extract_table_refs(sql) == [("analytics", "marts", "listings")]
|
||||
|
||||
def test_unparseable_returns_empty(self):
|
||||
assert extract_table_refs("not valid sql !!!") == []
|
||||
|
||||
|
||||
class TestRefMatching:
|
||||
def test_normalize_strips_quotes_and_lowercases(self):
|
||||
assert normalize_table('"ANALYTICS"."MARTS"."LISTINGS"') == (
|
||||
"analytics",
|
||||
"marts",
|
||||
"listings",
|
||||
)
|
||||
|
||||
def test_full_match(self):
|
||||
assert ref_matches_source_table(
|
||||
("analytics", "marts", "listings"), "ANALYTICS.MARTS.LISTINGS"
|
||||
)
|
||||
|
||||
def test_two_part_suffix_matches_three_part_table(self):
|
||||
assert ref_matches_source_table(
|
||||
("marts", "listings"), "ANALYTICS.MARTS.LISTINGS"
|
||||
)
|
||||
|
||||
def test_bare_name_matches_three_part_table(self):
|
||||
assert ref_matches_source_table(("listings",), "ANALYTICS.MARTS.LISTINGS")
|
||||
|
||||
def test_db_mismatch_blocks_match(self):
|
||||
assert not ref_matches_source_table(
|
||||
("staging", "listings"), "ANALYTICS.MARTS.LISTINGS"
|
||||
)
|
||||
|
||||
def test_longer_ref_does_not_match_shorter_table(self):
|
||||
assert not ref_matches_source_table(
|
||||
("analytics", "marts", "listings"), "marts.listings"
|
||||
)
|
||||
|
||||
|
||||
class TestSqlJoinCoverage:
|
||||
def _build_engine(
|
||||
self,
|
||||
listings_table: str = "ANALYTICS.MARTS.LISTINGS",
|
||||
accounts_table: str = "ANALYTICS.MARTS.ACCOUNTS",
|
||||
new_source_sql: str | None = None,
|
||||
new_source_joins: list[JoinDeclaration] | None = None,
|
||||
) -> SemanticEngine:
|
||||
listings = _table_src("LISTINGS", listings_table)
|
||||
accounts = _table_src("ACCOUNTS", accounts_table)
|
||||
sources = {"LISTINGS": listings, "ACCOUNTS": accounts}
|
||||
if new_source_sql is not None:
|
||||
sources["my_source"] = _sql_src(
|
||||
"my_source",
|
||||
sql=new_source_sql,
|
||||
joins=new_source_joins,
|
||||
)
|
||||
return SemanticEngine.from_sources(sources)
|
||||
|
||||
def test_coverage_gap_emitted_as_error(self):
|
||||
sql = """
|
||||
select l.id, a.name
|
||||
from ANALYTICS.MARTS.LISTINGS l
|
||||
join ANALYTICS.MARTS.ACCOUNTS a on l.account_id = a.id
|
||||
"""
|
||||
engine = self._build_engine(new_source_sql=sql, new_source_joins=[])
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
assert not report.valid
|
||||
coverage_errors = [e for e in report.errors if "my_source" in e]
|
||||
assert any("LISTINGS" in e and "ACCOUNTS" in e for e in coverage_errors), (
|
||||
f"Expected coverage error mentioning LISTINGS and ACCOUNTS, got: {report.errors}"
|
||||
)
|
||||
|
||||
def test_declared_join_satisfies_coverage(self):
|
||||
sql = """
|
||||
select l.id, a.name
|
||||
from ANALYTICS.MARTS.LISTINGS l
|
||||
join ANALYTICS.MARTS.ACCOUNTS a on l.account_id = a.id
|
||||
"""
|
||||
joins = [
|
||||
JoinDeclaration(
|
||||
to="LISTINGS",
|
||||
on="my_source.listing_id = LISTINGS.id",
|
||||
relationship="many_to_one",
|
||||
),
|
||||
JoinDeclaration(
|
||||
to="ACCOUNTS",
|
||||
on="my_source.account_id = ACCOUNTS.id",
|
||||
relationship="many_to_one",
|
||||
),
|
||||
]
|
||||
engine = self._build_engine(new_source_sql=sql, new_source_joins=joins)
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
coverage_errors = [
|
||||
e for e in report.errors if "my_source" in e and "joins[]" in e
|
||||
]
|
||||
assert coverage_errors == []
|
||||
|
||||
def test_partial_coverage_lists_only_missing(self):
|
||||
sql = """
|
||||
select l.id, a.name
|
||||
from ANALYTICS.MARTS.LISTINGS l
|
||||
join ANALYTICS.MARTS.ACCOUNTS a on l.account_id = a.id
|
||||
"""
|
||||
joins = [
|
||||
JoinDeclaration(
|
||||
to="LISTINGS",
|
||||
on="my_source.listing_id = LISTINGS.id",
|
||||
relationship="many_to_one",
|
||||
),
|
||||
]
|
||||
engine = self._build_engine(new_source_sql=sql, new_source_joins=joins)
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
coverage_errors = [
|
||||
e for e in report.errors if "my_source" in e and "ACCOUNTS" in e
|
||||
]
|
||||
assert coverage_errors, f"Expected ACCOUNTS gap, got: {report.errors}"
|
||||
assert all("LISTINGS]" not in e for e in coverage_errors), (
|
||||
f"LISTINGS should be satisfied: {report.errors}"
|
||||
)
|
||||
|
||||
def test_unmapped_table_does_not_trigger_coverage_error(self):
|
||||
# SQL references staging.foo which has no manifest entry — the
|
||||
# check is silent. (The agent is still expected to write a wiki
|
||||
# note, but that's outside the validator's scope.)
|
||||
sql = "select id from staging.foo"
|
||||
engine = self._build_engine(new_source_sql=sql)
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
assert not any("my_source" in e and "joins[]" in e for e in report.errors), (
|
||||
f"Unmapped table must not be flagged: {report.errors}"
|
||||
)
|
||||
|
||||
def test_quoted_identifiers_match(self):
|
||||
sql = (
|
||||
'select * from "ANALYTICS"."MARTS"."LISTINGS" l '
|
||||
'join "ANALYTICS"."MARTS"."ACCOUNTS" a on l.account_id = a.id'
|
||||
)
|
||||
engine = self._build_engine(new_source_sql=sql, new_source_joins=[])
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
assert any(
|
||||
"my_source" in e and "LISTINGS" in e and "ACCOUNTS" in e
|
||||
for e in report.errors
|
||||
), f"Quoted identifiers should match: {report.errors}"
|
||||
|
||||
def test_cte_self_reference_not_flagged(self):
|
||||
sql = """
|
||||
with d as (select id from ANALYTICS.MARTS.LISTINGS)
|
||||
select * from d
|
||||
"""
|
||||
# LISTINGS is referenced inside the CTE — that still counts and
|
||||
# must be flagged (the manifest entry exists). `d` itself must
|
||||
# NOT be flagged as missing.
|
||||
engine = self._build_engine(new_source_sql=sql, new_source_joins=[])
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
coverage_errors = [e for e in report.errors if "my_source" in e]
|
||||
assert any("LISTINGS" in e for e in coverage_errors)
|
||||
assert not any("'d'" in e or " d " in e for e in coverage_errors), (
|
||||
f"CTE alias 'd' must not be flagged: {coverage_errors}"
|
||||
)
|
||||
|
||||
def test_two_part_suffix_match(self):
|
||||
# Source's SQL references `MARTS.LISTINGS` (2-part) — should match
|
||||
# the 3-part manifest entry `ANALYTICS.MARTS.LISTINGS`.
|
||||
sql = "select id from MARTS.LISTINGS"
|
||||
engine = self._build_engine(new_source_sql=sql, new_source_joins=[])
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
assert any("my_source" in e and "LISTINGS" in e for e in report.errors), (
|
||||
f"Two-part suffix should match: {report.errors}"
|
||||
)
|
||||
|
||||
def test_not_recently_touched_means_no_check(self):
|
||||
# Same buggy SQL as above, but the source isn't in
|
||||
# `recently_touched` — coverage check skipped.
|
||||
sql = """
|
||||
select l.id from ANALYTICS.MARTS.LISTINGS l
|
||||
join ANALYTICS.MARTS.ACCOUNTS a on l.account_id = a.id
|
||||
"""
|
||||
engine = self._build_engine(new_source_sql=sql, new_source_joins=[])
|
||||
|
||||
report = engine.validate(recently_touched=None)
|
||||
|
||||
coverage_errors = [
|
||||
e for e in report.errors if "my_source" in e and "joins[]" in e
|
||||
]
|
||||
assert coverage_errors == []
|
||||
|
||||
def test_table_only_source_skipped(self):
|
||||
# A source with `table:` (no SQL) cannot be coverage-checked.
|
||||
listings = _table_src("LISTINGS", "ANALYTICS.MARTS.LISTINGS")
|
||||
bare = _table_src("bare", "public.bare", columns=["id"])
|
||||
engine = SemanticEngine.from_sources({"LISTINGS": listings, "bare": bare})
|
||||
|
||||
report = engine.validate(recently_touched={"bare"})
|
||||
|
||||
assert not any("bare" in e and "joins[]" in e for e in report.errors), (
|
||||
f"Table-only source must not be flagged: {report.errors}"
|
||||
)
|
||||
|
||||
def test_self_reference_not_flagged(self):
|
||||
# If `my_source` somehow names its own table in the manifest, we
|
||||
# shouldn't flag itself.
|
||||
my_source = _sql_src("my_source", sql="select id from public.my_source")
|
||||
# Not realistic for SQL sources, but make sure self-refs are
|
||||
# filtered defensively.
|
||||
engine = SemanticEngine.from_sources({"my_source": my_source})
|
||||
|
||||
report = engine.validate(recently_touched={"my_source"})
|
||||
|
||||
assert not any("my_source" in e and "joins[]" in e for e in report.errors)
|
||||
77
python/klo-sl/tests/test_table_identifier_parser.py
Normal file
77
python/klo-sl/tests/test_table_identifier_parser.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
from semantic_layer.table_identifier_parser import (
|
||||
ParseTableIdentifierItem,
|
||||
parse_table_identifier_batch,
|
||||
parse_table_identifier_one,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_table_identifier_supported_dialects_and_aliases() -> None:
|
||||
response = parse_table_identifier_batch(
|
||||
[
|
||||
ParseTableIdentifierItem(
|
||||
key="pg",
|
||||
sql_table_name="public.orders AS o",
|
||||
dialect="postgres",
|
||||
),
|
||||
ParseTableIdentifierItem(
|
||||
key="bq",
|
||||
sql_table_name="analytics.orders",
|
||||
dialect="bigquery",
|
||||
),
|
||||
ParseTableIdentifierItem(
|
||||
key="sf",
|
||||
sql_table_name="RAW.PUBLIC.ORDERS",
|
||||
dialect="snowflake",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert response["pg"].ok is True
|
||||
assert response["pg"].schema_ == "public"
|
||||
assert response["pg"].name == "orders"
|
||||
assert response["pg"].canonical_table == "public.orders"
|
||||
assert response["bq"].ok is True
|
||||
assert response["bq"].schema_ == "analytics"
|
||||
assert response["bq"].name == "orders"
|
||||
assert response["sf"].ok is True
|
||||
assert response["sf"].catalog == "RAW"
|
||||
assert response["sf"].schema_ == "PUBLIC"
|
||||
assert response["sf"].name == "ORDERS"
|
||||
|
||||
|
||||
def test_parse_table_identifier_rejects_non_physical_inputs() -> None:
|
||||
assert (
|
||||
parse_table_identifier_one("${orders.SQL_TABLE_NAME}", "postgres").reason
|
||||
== "looker_template_unresolved"
|
||||
)
|
||||
assert (
|
||||
parse_table_identifier_one("(select * from public.orders)", "postgres").reason
|
||||
== "derived_table_not_supported"
|
||||
)
|
||||
assert (
|
||||
parse_table_identifier_one(
|
||||
"public.orders join public.users on true", "postgres"
|
||||
).reason
|
||||
== "multiple_table_references"
|
||||
)
|
||||
assert (
|
||||
parse_table_identifier_one("public.orders", "not-a-dialect").reason
|
||||
== "unsupported_dialect"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_table_identifier_preserves_batch_keys() -> None:
|
||||
response = parse_table_identifier_batch(
|
||||
[
|
||||
ParseTableIdentifierItem(
|
||||
key="z", sql_table_name="public.z", dialect="postgres"
|
||||
),
|
||||
ParseTableIdentifierItem(
|
||||
key="a", sql_table_name="public.a", dialect="postgres"
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert list(response) == ["z", "a"]
|
||||
assert response["z"].name == "z"
|
||||
assert response["a"].name == "a"
|
||||
360
python/klo-sl/tests/test_tpch.py
Normal file
360
python/klo-sl/tests/test_tpch.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""TPC-H schema tests: loading, graph, planning, and SQL execution against DuckDB."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_layer.engine import SemanticEngine
|
||||
from semantic_layer.graph import JoinGraph
|
||||
from semantic_layer.loader import SourceLoader
|
||||
from semantic_layer.models import SourceDefinition
|
||||
|
||||
TPCH_DIR = Path(__file__).parent.parent / "sources" / "tpch"
|
||||
TPCH_TABLES = [
|
||||
"region",
|
||||
"nation",
|
||||
"supplier",
|
||||
"customer",
|
||||
"part",
|
||||
"partsupp",
|
||||
"orders",
|
||||
"lineitem",
|
||||
]
|
||||
|
||||
try:
|
||||
import duckdb
|
||||
|
||||
HAS_DUCKDB = True
|
||||
except ImportError:
|
||||
HAS_DUCKDB = False
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sources() -> dict[str, SourceDefinition]:
|
||||
return SourceLoader(TPCH_DIR).load_all()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def graph(sources: dict[str, SourceDefinition]) -> JoinGraph:
|
||||
g = JoinGraph(sources)
|
||||
g.build()
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def engine() -> SemanticEngine:
|
||||
return SemanticEngine(str(TPCH_DIR), dialect="duckdb")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tpch_conn():
|
||||
if not HAS_DUCKDB:
|
||||
pytest.skip("duckdb not installed")
|
||||
conn = duckdb.connect()
|
||||
conn.execute("INSTALL tpch; LOAD tpch")
|
||||
conn.execute("CALL dbgen(sf=0.01)")
|
||||
conn.execute("CREATE SCHEMA IF NOT EXISTS public")
|
||||
for t in TPCH_TABLES:
|
||||
conn.execute(f"CREATE VIEW public.{t} AS SELECT * FROM main.{t}")
|
||||
return conn
|
||||
|
||||
|
||||
# ── Loader Tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTpchLoader:
|
||||
def test_all_sources_loaded(self, sources):
|
||||
assert set(sources.keys()) == set(TPCH_TABLES)
|
||||
|
||||
def test_lineitem_columns(self, sources):
|
||||
li = sources["lineitem"]
|
||||
col_names = {c.name for c in li.columns}
|
||||
assert "l_orderkey" in col_names
|
||||
assert "l_extendedprice" in col_names
|
||||
assert "l_shipdate" in col_names
|
||||
assert len(li.columns) == 16
|
||||
|
||||
def test_lineitem_composite_grain(self, sources):
|
||||
assert sources["lineitem"].grain == ["l_orderkey", "l_linenumber"]
|
||||
|
||||
def test_partsupp_composite_grain(self, sources):
|
||||
assert sources["partsupp"].grain == ["ps_partkey", "ps_suppkey"]
|
||||
|
||||
def test_lineitem_measures(self, sources):
|
||||
measure_names = {m.name for m in sources["lineitem"].measures}
|
||||
assert "revenue" in measure_names
|
||||
assert "returned_revenue" in measure_names
|
||||
assert "charge" in measure_names
|
||||
assert len(sources["lineitem"].measures) == 8
|
||||
|
||||
def test_returned_revenue_has_filter(self, sources):
|
||||
m = next(
|
||||
m for m in sources["lineitem"].measures if m.name == "returned_revenue"
|
||||
)
|
||||
assert m.filter == "l_returnflag = 'R'"
|
||||
|
||||
def test_lineitem_joins(self, sources):
|
||||
join_targets = {j.to for j in sources["lineitem"].joins}
|
||||
assert join_targets == {"orders", "part", "supplier"}
|
||||
|
||||
def test_region_is_leaf(self, sources):
|
||||
assert sources["region"].joins == []
|
||||
assert sources["region"].measures == []
|
||||
|
||||
def test_orders_measures(self, sources):
|
||||
measure_names = {m.name for m in sources["orders"].measures}
|
||||
assert measure_names == {"order_count", "total_price", "avg_order_value"}
|
||||
|
||||
|
||||
# ── Graph Tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTpchGraph:
|
||||
def test_all_sources_in_graph(self, graph):
|
||||
assert set(graph.adjacency.keys()) >= set(TPCH_TABLES)
|
||||
|
||||
def test_lineitem_to_region_path(self, graph):
|
||||
"""Shortest path: lineitem → supplier → nation → region (3 hops)."""
|
||||
path = graph.find_path("lineitem", "region")
|
||||
assert path is not None
|
||||
source_chain = [path.edges[0].from_source] + [e.to_source for e in path.edges]
|
||||
assert "lineitem" in source_chain
|
||||
assert "region" in source_chain
|
||||
assert len(path.edges) == 3
|
||||
|
||||
def test_lineitem_to_part_direct(self, graph):
|
||||
path = graph.find_path("lineitem", "part")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 1
|
||||
|
||||
def test_part_to_supplier_via_lineitem(self, graph):
|
||||
"""Shortest path: part → lineitem → supplier (2 hops, shorter than via partsupp)."""
|
||||
path = graph.find_path("part", "supplier")
|
||||
assert path is not None
|
||||
assert len(path.edges) == 2
|
||||
|
||||
def test_partsupp_bridges_part_and_supplier(self, graph):
|
||||
"""partsupp has direct edges to both part and supplier."""
|
||||
path_to_part = graph.find_path("partsupp", "part")
|
||||
path_to_supplier = graph.find_path("partsupp", "supplier")
|
||||
assert path_to_part is not None and len(path_to_part.edges) == 1
|
||||
assert path_to_supplier is not None and len(path_to_supplier.edges) == 1
|
||||
|
||||
def test_graph_is_single_component(self, graph):
|
||||
components = graph.find_components()
|
||||
assert len(components) == 1
|
||||
|
||||
|
||||
# ── Plan-only Tests (no DuckDB needed) ───────────────────────────────
|
||||
|
||||
|
||||
class TestTpchPlanning:
|
||||
def test_q1_plan(self, engine):
|
||||
plan = engine.plan_only(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": ["lineitem.l_returnflag", "lineitem.l_linestatus"],
|
||||
}
|
||||
)
|
||||
assert plan.anchor_source == "lineitem"
|
||||
assert len(plan.sources_used) == 1
|
||||
|
||||
def test_q5_plan_multi_hop(self, engine):
|
||||
plan = engine.plan_only(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": ["nation.n_name"],
|
||||
"filters": ["region.r_name = 'ASIA'"],
|
||||
}
|
||||
)
|
||||
assert "lineitem" in plan.sources_used
|
||||
assert "nation" in plan.sources_used
|
||||
assert "region" in plan.sources_used
|
||||
|
||||
def test_filtered_measure_plan(self, engine):
|
||||
plan = engine.plan_only(
|
||||
{
|
||||
"measures": ["lineitem.returned_revenue"],
|
||||
"dimensions": ["customer.c_name"],
|
||||
}
|
||||
)
|
||||
assert any(m.filter for m in plan.measures)
|
||||
|
||||
def test_time_granularity_plan(self, engine):
|
||||
plan = engine.plan_only(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": [{"field": "orders.o_orderdate", "granularity": "month"}],
|
||||
}
|
||||
)
|
||||
col_names = [c.name for c in plan.columns]
|
||||
# Column may be named "o_orderdate" with granularity metadata
|
||||
assert "o_orderdate" in col_names
|
||||
dim_col = next(c for c in plan.columns if c.name == "o_orderdate")
|
||||
assert dim_col.granularity == "month"
|
||||
|
||||
def test_suggest_valid_query(self, engine):
|
||||
result = engine.suggest(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": ["lineitem.l_returnflag"],
|
||||
}
|
||||
)
|
||||
assert result["success"] is True
|
||||
|
||||
def test_suggest_missing_source(self, engine):
|
||||
result = engine.suggest(
|
||||
{
|
||||
"measures": ["sum(lineitem.l_quantity)"],
|
||||
"dimensions": ["nonexistent.col"],
|
||||
}
|
||||
)
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
# ── Execution Tests (require DuckDB) ────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_DUCKDB, reason="duckdb not installed")
|
||||
class TestTpchExecution:
|
||||
def test_q1_pricing_summary(self, tpch_conn, engine):
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": [
|
||||
"lineitem.revenue",
|
||||
"lineitem.total_quantity",
|
||||
"lineitem.avg_discount",
|
||||
"lineitem.line_count",
|
||||
],
|
||||
"dimensions": ["lineitem.l_returnflag", "lineitem.l_linestatus"],
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) > 0
|
||||
# TPC-H has exactly 4 combinations: A/F, N/F, N/O, R/F
|
||||
assert len(rows) <= 4
|
||||
|
||||
def test_q5_revenue_by_nation_asia(self, tpch_conn, engine):
|
||||
"""4-hop join with filter: lineitem→supplier→nation→region."""
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": ["nation.n_name"],
|
||||
"filters": ["region.r_name = 'ASIA'"],
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) > 0
|
||||
# ASIA has 5 nations
|
||||
assert len(rows) <= 5
|
||||
|
||||
def test_q3_revenue_by_month(self, tpch_conn, engine):
|
||||
"""DATE_TRUNC + multi-table filter."""
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": [{"field": "orders.o_orderdate", "granularity": "month"}],
|
||||
"filters": ["customer.c_mktsegment = 'BUILDING'"],
|
||||
"limit": 12,
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) > 0
|
||||
assert len(rows) <= 12
|
||||
|
||||
def test_q10_returned_revenue(self, tpch_conn, engine):
|
||||
"""Filtered measure with CASE WHEN."""
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["lineitem.returned_revenue"],
|
||||
"dimensions": ["customer.c_name"],
|
||||
"limit": 10,
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) > 0
|
||||
assert len(rows) <= 10
|
||||
|
||||
def test_order_count(self, tpch_conn, engine):
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["orders.order_count"],
|
||||
"dimensions": ["orders.o_orderstatus"],
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) > 0
|
||||
# Sum of counts should equal total orders at SF=0.01
|
||||
total = sum(r[1] for r in rows)
|
||||
assert total == 15000 # SF=0.01 → 15000 orders
|
||||
|
||||
def test_supply_cost_by_nation(self, tpch_conn, engine):
|
||||
"""Bridge table path: partsupp → supplier → nation."""
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["partsupp.total_supply_cost"],
|
||||
"dimensions": ["nation.n_name"],
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) == 25 # 25 nations
|
||||
|
||||
def test_avg_order_value(self, tpch_conn, engine):
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["orders.avg_order_value"],
|
||||
"dimensions": ["customer.c_mktsegment"],
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) == 5 # 5 market segments
|
||||
# avg values should be positive
|
||||
for row in rows:
|
||||
assert row[1] > 0
|
||||
|
||||
def test_lineitem_charge(self, tpch_conn, engine):
|
||||
"""Complex expression: sum(price * (1 - discount) * (1 + tax))."""
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["lineitem.charge"],
|
||||
"dimensions": ["lineitem.l_returnflag"],
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) > 0
|
||||
for row in rows:
|
||||
assert row[1] > 0
|
||||
|
||||
def test_order_by_desc(self, tpch_conn, engine):
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": ["nation.n_name"],
|
||||
"order_by": [{"field": "lineitem.revenue", "direction": "desc"}],
|
||||
"limit": 5,
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) == 5
|
||||
# Revenue should be descending
|
||||
revenues = [r[1] for r in rows]
|
||||
assert revenues == sorted(revenues, reverse=True)
|
||||
|
||||
def test_multiple_filters(self, tpch_conn, engine):
|
||||
result = engine.query(
|
||||
{
|
||||
"measures": ["lineitem.revenue"],
|
||||
"dimensions": ["orders.o_orderpriority"],
|
||||
"filters": [
|
||||
"customer.c_mktsegment = 'BUILDING'",
|
||||
"nation.n_name = 'FRANCE'",
|
||||
],
|
||||
}
|
||||
)
|
||||
rows = tpch_conn.execute(result.sql).fetchall()
|
||||
assert len(rows) > 0
|
||||
299
python/klo-sl/tests/test_validator.py
Normal file
299
python/klo-sl/tests/test_validator.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_layer.engine import SemanticEngine
|
||||
from semantic_layer.models import (
|
||||
JoinDeclaration,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
|
||||
|
||||
def _src(
|
||||
name: str,
|
||||
columns: list[str] | None = None,
|
||||
grain: list[str] | None = None,
|
||||
joins: list[JoinDeclaration] | None = None,
|
||||
) -> SourceDefinition:
|
||||
"""Minimal-boilerplate source factory for validator tests."""
|
||||
columns = columns or ["id"]
|
||||
grain = grain or ["id"]
|
||||
return SourceDefinition(
|
||||
name=name,
|
||||
table=f"public.{name}",
|
||||
grain=grain,
|
||||
columns=[SourceColumn(name=c, type="number") for c in columns],
|
||||
joins=joins or [],
|
||||
)
|
||||
|
||||
|
||||
class TestValidatorValid:
|
||||
def test_valid_connected_model(self):
|
||||
orders = _src(
|
||||
"orders",
|
||||
columns=["id", "customer_id"],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
customers = _src("customers")
|
||||
engine = SemanticEngine.from_sources({"orders": orders, "customers": customers})
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert report.valid
|
||||
assert report.errors == []
|
||||
assert report.warnings == []
|
||||
|
||||
|
||||
class TestOrphanJoinTarget:
|
||||
def test_orphan_join_target_is_error(self):
|
||||
orders = _src(
|
||||
"orders",
|
||||
columns=["id", "customer_id"],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
# `customers` deliberately not defined
|
||||
engine = SemanticEngine.from_sources({"orders": orders})
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert not report.valid
|
||||
assert any(
|
||||
"orders" in e and "customers" in e and "not defined" in e
|
||||
for e in report.errors
|
||||
)
|
||||
|
||||
def test_query_with_orphan_target_raises_before_sql(self):
|
||||
"""Query path must reject orphan targets, not silently emit SQL
|
||||
that references the undefined table name (which could read a real
|
||||
unmodeled table sharing that name)."""
|
||||
orders = _src(
|
||||
"orders",
|
||||
columns=["id", "amount", "customer_id"],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
engine = SemanticEngine.from_sources({"orders": orders})
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
engine.query(
|
||||
{
|
||||
"measures": ["sum(orders.amount)"],
|
||||
"dimensions": ["customers.id"],
|
||||
}
|
||||
)
|
||||
msg = str(exc.value)
|
||||
assert "orders" in msg
|
||||
assert "customers" in msg
|
||||
assert "not defined" in msg
|
||||
|
||||
|
||||
class TestInvalidGrain:
|
||||
def test_grain_column_missing_from_columns(self):
|
||||
bad = _src(
|
||||
"bad",
|
||||
columns=["id"],
|
||||
grain=["nonexistent_col"],
|
||||
)
|
||||
engine = SemanticEngine.from_sources({"bad": bad})
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert not report.valid
|
||||
assert any("bad" in e and "nonexistent_col" in e for e in report.errors)
|
||||
|
||||
|
||||
class TestDisconnectedComponents:
|
||||
def test_two_components_produce_warning_not_error(self):
|
||||
a = _src("a")
|
||||
b = _src("b")
|
||||
engine = SemanticEngine.from_sources({"a": a, "b": b})
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert report.valid
|
||||
assert report.errors == []
|
||||
assert len(report.warnings) >= 1
|
||||
disconnection = next(
|
||||
(w for w in report.warnings if "disconnected components" in w), None
|
||||
)
|
||||
assert disconnection is not None
|
||||
assert "2 disconnected components" in disconnection
|
||||
assert "Component 1" in disconnection
|
||||
assert "Component 2" in disconnection
|
||||
|
||||
def test_aliases_do_not_create_false_disconnection(self):
|
||||
"""Two aliases of the same base source must count as one component
|
||||
with the base, not as separate islands."""
|
||||
orders = SourceDefinition(
|
||||
name="orders",
|
||||
table="public.orders",
|
||||
grain=["id"],
|
||||
columns=[
|
||||
SourceColumn(name="id", type="number"),
|
||||
SourceColumn(name="amount", type="number"),
|
||||
SourceColumn(name="billing_customer_id", type="number"),
|
||||
SourceColumn(name="shipping_customer_id", type="number"),
|
||||
],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
alias="billing_customer",
|
||||
on="billing_customer_id = billing_customer.id",
|
||||
relationship="many_to_one",
|
||||
),
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
alias="shipping_customer",
|
||||
on="shipping_customer_id = shipping_customer.id",
|
||||
relationship="many_to_one",
|
||||
),
|
||||
],
|
||||
)
|
||||
customers = _src("customers", columns=["id", "segment"])
|
||||
engine = SemanticEngine.from_sources({"orders": orders, "customers": customers})
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert report.valid
|
||||
assert not any("disconnected components" in w for w in report.warnings)
|
||||
|
||||
def test_large_component_is_truncated(self):
|
||||
many = {f"s{i}": _src(f"s{i}") for i in range(10)}
|
||||
# Join them sequentially so they form one big component
|
||||
for i in range(9):
|
||||
many[f"s{i}"].joins.append(
|
||||
JoinDeclaration(
|
||||
to=f"s{i + 1}",
|
||||
on=f"id = s{i + 1}.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
)
|
||||
many["island"] = _src("island")
|
||||
engine = SemanticEngine.from_sources(many)
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
disconnection = next(
|
||||
w for w in report.warnings if "disconnected components" in w
|
||||
)
|
||||
assert "(10 sources)" in disconnection
|
||||
assert "... (+8 more)" in disconnection
|
||||
assert "(1 sources): island" in disconnection
|
||||
|
||||
def test_singleton_component_warning_names_recently_touched_source(self):
|
||||
orders = _src(
|
||||
"orders",
|
||||
columns=["id", "customer_id"],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
customers = _src("customers")
|
||||
lonely_source = _src("lonely_source")
|
||||
engine = SemanticEngine.from_sources(
|
||||
{
|
||||
"orders": orders,
|
||||
"customers": customers,
|
||||
"lonely_source": lonely_source,
|
||||
}
|
||||
)
|
||||
|
||||
report = engine.validate(recently_touched={"lonely_source"})
|
||||
|
||||
assert report.per_source_warnings["lonely_source"]
|
||||
msg = report.per_source_warnings["lonely_source"][0]
|
||||
assert "lonely_source" in msg
|
||||
assert "singleton" in msg.lower() or "no joins" in msg.lower()
|
||||
|
||||
def test_no_per_source_warning_for_connected_recently_touched_source(self):
|
||||
orders = _src(
|
||||
"orders",
|
||||
columns=["id", "customer_id"],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="customers",
|
||||
on="customer_id = customers.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
customers = _src("customers")
|
||||
engine = SemanticEngine.from_sources({"orders": orders, "customers": customers})
|
||||
|
||||
report = engine.validate(recently_touched={"orders"})
|
||||
|
||||
assert report.per_source_warnings.get("orders", []) == []
|
||||
|
||||
def test_recently_touched_default_none_preserves_existing_behavior(self):
|
||||
lonely = _src("lonely")
|
||||
other = _src("other")
|
||||
engine = SemanticEngine.from_sources({"lonely": lonely, "other": other})
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert any("disconnected components" in w for w in report.warnings)
|
||||
assert report.per_source_warnings == {}
|
||||
|
||||
|
||||
class TestEcommerceSmoke:
|
||||
def test_ecommerce_fixtures_validate_cleanly(self, ecommerce_sources):
|
||||
engine = SemanticEngine.from_sources(ecommerce_sources)
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert report.valid, f"Expected clean report, got errors: {report.errors}"
|
||||
assert report.warnings == [], f"Expected no warnings, got: {report.warnings}"
|
||||
|
||||
|
||||
class TestMultipleIssuesCollected:
|
||||
def test_errors_and_warnings_coexist(self):
|
||||
bad_grain = _src("bad_grain", columns=["id"], grain=["missing"])
|
||||
orphan_target = _src(
|
||||
"with_orphan",
|
||||
columns=["id", "fk"],
|
||||
joins=[
|
||||
JoinDeclaration(
|
||||
to="doesnt_exist",
|
||||
on="fk = doesnt_exist.id",
|
||||
relationship="many_to_one",
|
||||
)
|
||||
],
|
||||
)
|
||||
isolated = _src("isolated")
|
||||
engine = SemanticEngine.from_sources(
|
||||
{
|
||||
"bad_grain": bad_grain,
|
||||
"with_orphan": orphan_target,
|
||||
"isolated": isolated,
|
||||
}
|
||||
)
|
||||
|
||||
report = engine.validate()
|
||||
|
||||
assert not report.valid
|
||||
assert len(report.errors) >= 2
|
||||
assert any("missing" in e for e in report.errors)
|
||||
assert any("doesnt_exist" in e for e in report.errors)
|
||||
assert len(report.warnings) >= 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue