mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
285 lines
9.8 KiB
Python
285 lines
9.8 KiB
Python
from __future__ import annotations
|
|
|
|
import heapq
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
|
|
from semantic_layer.models import SourceDefinition
|
|
|
|
# DIALECT CONVENTION:
|
|
# YAML-authored join `on:` clauses may contain dialect-specific casts
|
|
# (e.g. BigQuery `SAFE_CAST(x AS INT64)`). `_parse_on` parses them with
|
|
# `read=self.dialect` so the AST reflects the author's intent.
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Relationship of a join when it is traversed in the opposite direction.
# JoinGraph.build() uses this to derive the reverse edge of every declared
# join; one_to_one is its own inverse.
RELATIONSHIP_INVERSE = {
    "many_to_one": "one_to_many",
    "one_to_many": "many_to_one",
    "one_to_one": "one_to_one",
}
|
|
|
|
|
|
@dataclass
class JoinEdge:
    """One directed hop of the join graph, from one source to another."""

    # Name of the source the hop starts at.
    from_source: str
    # Name of the source (or join alias) the hop arrives at.
    to_source: str
    # Join key on the ``from_source`` side; composite keys are stored as a
    # comma-separated string.
    from_column: str
    # Join key on the ``to_source`` side, in the same format.
    to_column: str
    # "many_to_one", "one_to_many" or "one_to_one", as seen in this direction.
    relationship: str
    # Alias of the originating join declaration, when it declared one.
    alias: str | None = None
|
|
|
|
|
|
@dataclass
class JoinPath:
    """An ordered chain of join edges leading from one source to another."""

    # Hops in traversal order; an empty list means "already at the destination".
    edges: list[JoinEdge]
    # True when at least one hop is a one_to_many join.
    has_one_to_many: bool = False
    # True when another path of the same cost was found.
    is_ambiguous: bool = False

    @property
    def source_names(self) -> list[str]:
        """Return the sources visited along the path, start first.

        An empty path yields an empty list; otherwise the result is the
        first hop's origin followed by the destination of every hop.
        """
        hops = self.edges
        if not hops:
            return []
        return [hops[0].from_source, *(hop.to_source for hop in hops)]
|
|
|
|
|
|
@dataclass
class JoinTree:
    """The set of joins needed to connect a group of sources."""

    # Edges to apply, in the order they were merged into the tree.
    edges: list[JoinEdge] = field(default_factory=list)
    # Every source name touched by the tree, including the root.
    sources: set[str] = field(default_factory=set)
    # True when at least one edge in the tree is a one_to_many join.
    has_one_to_many: bool = False
|
|
|
|
|
|
class JoinGraph:
    """Join graph over a set of source definitions.

    Every join declared on a source contributes two directed edges: a
    forward edge from the declaring source to the join target (or its
    alias) and a reverse edge carrying the inverted relationship. The graph
    is populated by :meth:`build`; the query methods read ``adjacency`` and
    ``alias_map``.
    """

    # Relationships that cannot fan out rows when traversed.
    _SAFE_RELATIONSHIPS = ("many_to_one", "one_to_one")
    # Dijkstra edge weights: one_to_many hops are heavily penalised so that
    # safe paths win whenever one exists.
    _SAFE_EDGE_COST = 1
    _FANOUT_EDGE_COST = 10

    def __init__(
        self,
        sources: dict[str, SourceDefinition],
        *,
        dialect: str = "postgres",
    ):
        """Store the sources and dialect; call :meth:`build` to add edges.

        Args:
            sources: Source definitions keyed by source name.
            dialect: SQL dialect used to parse join ``on:`` clauses.
        """
        self.sources = sources
        self.dialect = dialect
        self.adjacency: dict[str, list[JoinEdge]] = {}
        # alias_name → actual source name. Created here (not only in build())
        # so find_components() is safe to call on a graph that was never built.
        self.alias_map: dict[str, str] = {}

    def build(self) -> None:
        """Populate ``adjacency`` and ``alias_map`` from the declared joins.

        The graph is rebuilt from scratch on every call, so calling
        ``build()`` twice does not duplicate edges (duplicates would be
        reported as ambiguous paths by :meth:`find_path`).

        Raises:
            KeyError: If a join declares a relationship that has no entry
                in ``RELATIONSHIP_INVERSE``.
            ValueError: If a join's ``on`` clause cannot be interpreted
                (see :meth:`_parse_on`).
        """
        # alias_name → actual source name
        self.alias_map = {}
        self.adjacency.clear()

        for name in self.sources:
            self.adjacency.setdefault(name, [])

        for source in self.sources.values():
            for join in source.joins:
                from_col, to_col = self._parse_on(join.on, join.to)
                target_name = join.alias or join.to
                # Resolve the inverse first so an unknown relationship fails
                # before any half of the edge pair has been recorded.
                inverse = RELATIONSHIP_INVERSE[join.relationship]

                if join.alias:
                    self.alias_map[join.alias] = join.to

                # Forward edge: source → alias (or target)
                fwd = JoinEdge(
                    from_source=source.name,
                    to_source=target_name,
                    from_column=from_col,
                    to_column=to_col,
                    relationship=join.relationship,
                    alias=join.alias,
                )
                self.adjacency.setdefault(target_name, [])
                self.adjacency[source.name].append(fwd)

                # Reverse edge: alias (or target) → source
                rev = JoinEdge(
                    from_source=target_name,
                    to_source=source.name,
                    from_column=to_col,
                    to_column=from_col,
                    relationship=inverse,
                    alias=join.alias,
                )
                self.adjacency[target_name].append(rev)

    def find_path(self, from_source: str, to_source: str) -> JoinPath | None:
        """Dijkstra shortest path between two sources.

        many_to_one / one_to_one hops cost 1 and one_to_many hops cost 10,
        so paths that cannot fan out are preferred.

        Also detects ambiguity: if multiple equal-cost paths exist to the
        destination, the returned ``JoinPath`` has ``is_ambiguous=True``.
        This includes alternatives that diverge and reconverge at an
        intermediate source before reaching the destination.

        Returns:
            The cheapest path, an empty path when both names are equal, or
            ``None`` when either name is unknown or no path exists.
        """
        if from_source == to_source:
            return JoinPath(edges=[], has_one_to_many=False)
        if from_source not in self.adjacency or to_source not in self.adjacency:
            return None

        # (cost, counter, current_node, path_edges); the counter breaks cost
        # ties so the heap never has to compare nodes or edge lists.
        counter = 0
        heap: list[tuple[int, int, str, list[JoinEdge]]] = [
            (0, counter, from_source, [])
        ]
        # node → final shortest cost, recorded when the node is first popped
        settled: dict[str, int] = {}
        # nodes that were reached a second time at their shortest cost, i.e.
        # nodes with more than one cheapest route from the start
        tied: set[str] = set()
        first_path: JoinPath | None = None
        first_cost: int | None = None

        while heap:
            cost, _, current, path = heapq.heappop(heap)

            # All equal-cost alternatives exhausted — stop.
            if first_cost is not None and cost > first_cost:
                break

            if current == to_source:
                if first_path is not None:
                    # A second arrival at the same cost.
                    first_path.is_ambiguous = True
                    return first_path
                first_path = JoinPath(
                    edges=path,
                    has_one_to_many=any(
                        e.relationship == "one_to_many" for e in path
                    ),
                    # Edge costs are positive, so every tie at an
                    # intermediate node has been seen before this pop.
                    is_ambiguous=any(e.to_source in tied for e in path),
                )
                if first_path.is_ambiguous:
                    return first_path
                first_cost = cost
                continue  # don't settle dest — keep looking for alternatives

            if current in settled:
                if settled[current] == cost:
                    tied.add(current)
                continue
            settled[current] = cost

            for edge in self.adjacency.get(current, []):
                if edge.to_source in settled:
                    continue
                counter += 1
                # Prefer safe (many_to_one / one_to_one) paths over one_to_many
                edge_cost = (
                    self._SAFE_EDGE_COST
                    if edge.relationship in self._SAFE_RELATIONSHIPS
                    else self._FANOUT_EDGE_COST
                )
                heapq.heappush(
                    heap, (cost + edge_cost, counter, edge.to_source, path + [edge])
                )

        return first_path

    def resolve_join_tree(
        self, source_names: set[str], root: str | None = None
    ) -> JoinTree:
        """
        Steiner tree approximation: pick root source,
        find shortest path to each other source, merge paths.

        Args:
            source_names: Sources that must be connected.
            root: Source to start from. When omitted or not part of
                ``source_names``, the alphabetically first name is used.

        Returns:
            A ``JoinTree`` with the merged edges. With zero or one source
            the tree has no edges.

        Raises:
            ValueError: If some source cannot be reached from the root.
        """
        if len(source_names) <= 1:
            # Copy so the tree never shares (and later mutates) the caller's set.
            return JoinTree(sources=set(source_names))

        if root is not None and root in source_names:
            names = [root] + sorted(source_names - {root})
        else:
            names = sorted(source_names)
            root = names[0]
        tree = JoinTree(sources={root})
        # (from_source, to_source) pairs already in tree.edges, for O(1) de-dup
        linked: set[tuple[str, str]] = set()

        for target in names[1:]:
            if target in tree.sources:
                # Already pulled in as an intermediate hop of an earlier path.
                continue
            path = self.find_path(root, target)
            if path is None:
                raise ValueError(
                    f"No join path from '{root}' to '{target}'. "
                    "These sources are not connected in the join graph."
                )
            if path.is_ambiguous:
                logger.warning(
                    "Ambiguous join path from '%s' to '%s': multiple equal-cost "
                    "paths exist. The engine picked one arbitrarily. Use join "
                    "aliases to disambiguate.",
                    root,
                    target,
                )
            for edge in path.edges:
                pair = (edge.from_source, edge.to_source)
                if pair not in linked:
                    linked.add(pair)
                    tree.edges.append(edge)
                    if edge.relationship == "one_to_many":
                        tree.has_one_to_many = True
                tree.sources.add(edge.from_source)
                tree.sources.add(edge.to_source)

        return tree

    def find_components(self) -> list[set[str]]:
        """Partition the graph into connected components.

        Returns one set per component. For an empty graph, returns []. For a
        fully connected graph, returns a single-element list. Used both for
        validation (multi-component → warning) and for suggest().

        Aliases and their base source are treated as belonging to the same
        component, since alias-scoped queries resolve back to the base table.

        Components are listed in the order their first member appears in
        ``adjacency``, so the result is stable across runs.
        """
        # Bidirectional alias↔base adjacency so the traversal treats them as one node
        alias_neighbors: dict[str, list[str]] = {}
        for alias, base in self.alias_map.items():
            alias_neighbors.setdefault(alias, []).append(base)
            alias_neighbors.setdefault(base, []).append(alias)

        components: list[set[str]] = []
        assigned: set[str] = set()
        for start in self.adjacency:
            if start in assigned:
                continue
            component: set[str] = set()
            stack = [start]
            while stack:
                node = stack.pop()
                if node in component:
                    continue
                component.add(node)
                stack.extend(
                    edge.to_source
                    for edge in self.adjacency.get(node, [])
                    if edge.to_source not in component
                )
                stack.extend(
                    neighbor
                    for neighbor in alias_neighbors.get(node, [])
                    if neighbor not in component
                )
            components.append(component)
            assigned |= component
        return components

    def _parse_on(self, on_clause: str, target_source: str) -> tuple[str, str]:
        """
        Parse join conditions into (from_columns, to_columns) using sqlglot AST.

        Single key:  "customer_id = customers.id" → ("customer_id", "id")
        Composite:   "a = t.x AND b = t.y" → ("a,b", "x,y")

        Composite keys are stored as comma-separated strings.

        The left operand of every equality is taken as the from-side key and
        the right operand as the to-side key. ``target_source`` is accepted
        for the caller's convenience but is not consulted here.

        Raises:
            ValueError: If the clause contains a chained equality
                (``a = b = c``) or no equality at all.
        """
        import sqlglot
        from sqlglot import exp as _exp
        from semantic_layer.parser import quote_reserved_identifiers

        # Extract column name, stripping any source qualifier
        def _col_name(node: _exp.Expression) -> str:
            if isinstance(node, _exp.Column):
                return node.name
            # Non-column operands (e.g. casts) are kept as rendered SQL text.
            return node.sql(dialect="postgres")

        quoted = quote_reserved_identifiers(on_clause)
        # Wrap the clause in a dummy SELECT so it parses as a join condition.
        tree = sqlglot.parse_one(
            f"SELECT 1 FROM _a JOIN _b ON {quoted}", read=self.dialect
        )

        from_cols: list[str] = []
        to_cols: list[str] = []

        for eq_node in tree.find_all(_exp.EQ):
            left = eq_node.left
            right = eq_node.right

            # Reject nested equality (e.g., "a = b = c")
            if isinstance(left, _exp.EQ) or isinstance(right, _exp.EQ):
                raise ValueError(f"Invalid join condition: '{on_clause}'")

            from_cols.append(_col_name(left))
            to_cols.append(_col_name(right))

        if not from_cols:
            raise ValueError(f"Invalid join condition: '{on_clause}'")

        return ",".join(from_cols), ",".join(to_cols)
|