# Mirror of https://github.com/Kaelio/ktx.git
# Synced 2026-06-07 07:55:13 +02:00
# 254 lines, 7.3 KiB, Python
"""Generate ktx-sl YAML source definitions from database schema scan data."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel
|
|
from semantic_layer.models import (
|
|
ColumnRole,
|
|
JoinDeclaration,
|
|
MeasureDefinition,
|
|
SourceColumn,
|
|
SourceDefinition,
|
|
)
|
|
|
|
# Module-level logger; the generator uses it for skipped-link warnings and the
# final summary message.
logger = logging.getLogger(__name__)

# Case-insensitive patterns used to classify raw database type strings.
# They are applied with ``search``, so a match anywhere in the type counts.
_NUMBER_PATTERN = re.compile(
    r"int|integer|bigint|smallint|tinyint|numeric|decimal|float|double|real|number|money",
    flags=re.IGNORECASE,
)
_TIME_PATTERN = re.compile(
    r"timestamp|datetime|date|time(?!stamp)",
    flags=re.IGNORECASE,
)
_BOOLEAN_PATTERN = re.compile(r"bool|boolean|bit", flags=re.IGNORECASE)

# Column names that look like identifiers or keys. Numeric columns whose name
# matches are excluded from the generated sum/avg measures.
_ID_PATTERN = re.compile(
    r"^id$|_id$|^uuid$|_uuid$|_key$|_pk$|identifier$",
    flags=re.IGNORECASE,
)

# Accepted spellings of a link's relationship type (upper- and lower-case),
# mapped to the canonical lower-case form.
_RELATIONSHIP_MAP = {
    "MANY_TO_ONE": "many_to_one",
    "ONE_TO_MANY": "one_to_many",
    "ONE_TO_ONE": "one_to_one",
    "many_to_one": "many_to_one",
    "one_to_many": "one_to_many",
    "one_to_one": "one_to_one",
}

# Canonical relationship as seen from the opposite end of the same link.
_RELATIONSHIP_INVERSE = {
    "many_to_one": "one_to_many",
    "one_to_many": "many_to_one",
    "one_to_one": "one_to_one",
}
|
|
|
|
|
|
class ColumnInput(BaseModel):
    # One scanned database column, as supplied to the source generator.
    # Documented with comments rather than a docstring so the pydantic model
    # itself is left untouched.

    # Column name; reused verbatim as the generated source column name.
    name: str
    # Raw database type string; classified by _map_column_type.
    type: str
    # Columns flagged here form the grain of the generated source.
    primary_key: bool = False
    # Nullability flag; not read by the generator code visible in this module.
    nullable: bool = True
    # Optional column comment; becomes the generated column description.
    comment: str | None = None
|
|
|
|
|
|
class TableInput(BaseModel):
    # One scanned table together with its columns.

    # Bare table name; used as the generated source name and as the key when
    # matching links to tables.
    name: str
    # Optional qualifiers; when set they are prefixed to ``name`` (catalog
    # first, then db) to build the dotted table reference.
    catalog: str | None = None
    db: str | None = None
    # Optional table comment; becomes the generated source description.
    comment: str | None = None
    # Columns of the table, in the order they were supplied.
    columns: list[ColumnInput]
|
|
|
|
|
|
class LinkInput(BaseModel):
    # A relationship between a column of one table and a column of another.

    # Table and column on the side the link starts from.
    from_table: str
    from_column: str
    # Table and column on the side the link points to.
    to_table: str
    to_column: str
    # Relationship spelling; looked up in _RELATIONSHIP_MAP by the generator,
    # which falls back to "many_to_one" for spellings it does not know.
    relationship_type: str
|
|
|
|
|
|
class GenerateSourcesRequest(BaseModel):
    # Input payload for generate_sources / generate_sources_response.

    # Tables to turn into source definitions (one source per table).
    tables: list[TableInput]
    # Links between tables; each usable link produces joins on both ends.
    links: list[LinkInput]
    # SQL dialect name; not read by the generator code visible in this module.
    dialect: str = "postgres"
|
|
|
|
|
|
class GenerateSourcesResponse(BaseModel):
    # Output payload wrapping the generated source definitions.

    # Source definitions as plain dictionaries.
    sources: list[dict[str, Any]]
    # Number of entries in ``sources``.
    source_count: int
|
|
|
|
|
|
def _map_column_type(db_type: str) -> str:
    """Classify a raw database type string as a semantic-layer column type.

    The patterns are tried in a fixed order (boolean, then time, then number)
    and the first one found anywhere in ``db_type`` decides the result.

    Args:
        db_type: The column type as reported by the database scan.

    Returns:
        ``"boolean"``, ``"time"``, ``"number"``, or ``"string"`` when no
        pattern matches.
    """
    # Order matters: an earlier pattern wins when several would match.
    classifiers = (
        (_BOOLEAN_PATTERN, "boolean"),
        (_TIME_PATTERN, "time"),
        (_NUMBER_PATTERN, "number"),
    )
    for pattern, sl_type in classifiers:
        if pattern.search(db_type) is not None:
            return sl_type
    return "string"
|
|
|
|
|
|
def _build_table_ref(table: TableInput) -> str:
    """Build the dotted reference for a table.

    Args:
        table: Table whose ``catalog``, ``db`` and ``name`` are combined.

    Returns:
        ``catalog.db.name`` with any unset or empty qualifier left out; the
        table name is always the last component.
    """
    qualifiers = [part for part in (table.catalog, table.db) if part]
    return ".".join([*qualifiers, table.name])
|
|
|
|
|
|
def _generate_measures(
    table_name: str,
    columns: list[ColumnInput],
    pk_columns: list[str],
) -> list[MeasureDefinition]:
    """Derive default measures for a table.

    A ``record_count`` measure is emitted when the table has at least one
    primary-key column (counting the first one). Every column classified as
    ``"number"`` whose name does not look like an identifier then gets a
    ``total_<column>`` (sum) and an ``avg_<column>`` (average) measure.

    Args:
        table_name: Table name, used in the record-count description.
        columns: Columns of the table.
        pk_columns: Names of the primary-key columns, possibly empty.

    Returns:
        The measures in generation order: record count first, then the
        total/avg pair of each qualifying column in column order.
    """
    generated: list[MeasureDefinition] = []

    if pk_columns:
        generated.append(
            MeasureDefinition(
                name="record_count",
                expr=f"count({pk_columns[0]})",
                description=f"Count of {table_name} records",
            )
        )

    # (measure-name prefix, SQL aggregate, description label)
    aggregates = (
        ("total", "sum", "Sum"),
        ("avg", "avg", "Average"),
    )

    for column in columns:
        is_numeric = _map_column_type(column.type) == "number"
        # Identifier-like columns are numeric but not meaningful to aggregate.
        if not is_numeric or _ID_PATTERN.search(column.name):
            continue

        # The column comment, when present, is appended to both descriptions.
        suffix = f" \u2014 {column.comment}" if column.comment else ""
        for prefix, func, label in aggregates:
            generated.append(
                MeasureDefinition(
                    name=f"{prefix}_{column.name}",
                    expr=f"{func}({column.name})",
                    description=f"{label} of {column.name}{suffix}",
                )
            )

    return generated
|
|
|
|
|
|
def generate_sources(request: GenerateSourcesRequest) -> list[dict[str, Any]]:
    """Generate one source definition per scanned table.

    For each table this builds the grain, the typed columns, the joins implied
    by the request's links (in both directions) and the default measures, then
    dumps the resulting ``SourceDefinition`` to a dictionary with ``None``
    values excluded.

    Args:
        request: Tables and links to convert.

    Returns:
        The dumped source definitions, in the order of ``request.tables``.
    """
    # Index the links by both endpoints so each table finds its joins quickly.
    outgoing: dict[str, list[LinkInput]] = {}
    incoming: dict[str, list[LinkInput]] = {}
    for edge in request.links:
        outgoing.setdefault(edge.from_table, []).append(edge)
        incoming.setdefault(edge.to_table, []).append(edge)

    known_tables = {tbl.name for tbl in request.tables}
    results: list[dict[str, Any]] = []

    for tbl in request.tables:
        pk_columns = [col.name for col in tbl.columns if col.primary_key]

        # Grain: the primary key if any, else the first column, else "id".
        if pk_columns:
            grain = pk_columns
        elif tbl.columns:
            grain = [tbl.columns[0].name]
        else:
            grain = ["id"]

        sl_columns: list[SourceColumn] = []
        for col in tbl.columns:
            mapped_type = _map_column_type(col.type)
            if mapped_type == "time":
                column_role = ColumnRole.TIME
            else:
                column_role = ColumnRole.DEFAULT
            sl_columns.append(
                SourceColumn(
                    name=col.name,
                    type=mapped_type,
                    role=column_role,
                    description=col.comment,
                )
            )

        joins: list[JoinDeclaration] = []

        # Forward joins: links that start at this table.
        for edge in outgoing.get(tbl.name, []):
            if edge.to_table not in known_tables:
                logger.warning(
                    "Skipping link from %s.%s to %s.%s: target table not in scan",
                    edge.from_table,
                    edge.from_column,
                    edge.to_table,
                    edge.to_column,
                )
                continue

            joins.append(
                JoinDeclaration(
                    to=edge.to_table,
                    on=f"{edge.from_column} = {edge.to_table}.{edge.to_column}",
                    relationship=_RELATIONSHIP_MAP.get(
                        edge.relationship_type, "many_to_one"
                    ),
                )
            )

        # Reverse joins: links that end at this table, with the relationship
        # inverted so it reads from this table's point of view.
        for edge in incoming.get(tbl.name, []):
            if edge.from_table not in known_tables:
                logger.warning(
                    "Skipping reverse link from %s.%s to %s.%s: source table not in scan",
                    edge.from_table,
                    edge.from_column,
                    edge.to_table,
                    edge.to_column,
                )
                continue

            forward = _RELATIONSHIP_MAP.get(edge.relationship_type, "many_to_one")
            joins.append(
                JoinDeclaration(
                    to=edge.from_table,
                    on=f"{edge.to_column} = {edge.from_table}.{edge.from_column}",
                    relationship=_RELATIONSHIP_INVERSE.get(forward, "one_to_many"),
                )
            )

        # When several joins target the same table, give each an alias built
        # from the target name and the local column on the left of the "on".
        joins_per_target: dict[str, int] = {}
        for join in joins:
            joins_per_target[join.to] = joins_per_target.get(join.to, 0) + 1
        for join in joins:
            if joins_per_target[join.to] > 1:
                local_column = join.on.partition(" = ")[0].strip().lower()
                join.alias = f"{join.to}_{local_column}"

        definition = SourceDefinition(
            name=tbl.name,
            description=tbl.comment,
            table=_build_table_ref(tbl),
            grain=grain,
            columns=sl_columns,
            joins=joins,
            measures=_generate_measures(tbl.name, tbl.columns, pk_columns),
        )
        results.append(definition.model_dump(exclude_none=True))

    logger.info("Generated %d ktx-sl source definitions", len(results))
    return results
|
|
|
|
|
|
def generate_sources_response(
    request: GenerateSourcesRequest,
) -> GenerateSourcesResponse:
    """Run ``generate_sources`` and wrap the result in a response model.

    Args:
        request: Tables and links to convert.

    Returns:
        A response holding the generated sources and how many there are.
    """
    generated = generate_sources(request)
    return GenerateSourcesResponse(
        source_count=len(generated),
        sources=generated,
    )
|