mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
feat: add daemon sql batch analysis
This commit is contained in:
parent
c45d131a1f
commit
ffbbaf417a
4 changed files with 260 additions and 0 deletions
|
|
@ -48,6 +48,11 @@ from ktx_daemon.source_generation import (
|
|||
GenerateSourcesResponse,
|
||||
generate_sources_response,
|
||||
)
|
||||
from ktx_daemon.sql_analysis import (
|
||||
AnalyzeSqlBatchRequest,
|
||||
AnalyzeSqlBatchResponse,
|
||||
analyze_sql_batch_response,
|
||||
)
|
||||
from ktx_daemon.table_identifier import (
|
||||
ParseTableIdentifierBatchRequest,
|
||||
ParseTableIdentifierBatchResponse,
|
||||
|
|
@ -193,6 +198,19 @@ def create_app(
|
|||
detail=f"Table identifier parsing failed: {error}",
|
||||
) from error
|
||||
|
||||
# POST /sql/analyze-batch: run the batch SQL analysis and return one result
# per submitted item. Any exception escaping the analysis is logged with its
# traceback and reported to the client as an HTTP 500.
# NOTE(review): analyze_sql_batch_response is a synchronous call made from an
# async handler; confirm that blocking here is acceptable for this daemon.
@app.post("/sql/analyze-batch", response_model=AnalyzeSqlBatchResponse)
async def sql_analyze_batch(
    request: AnalyzeSqlBatchRequest,
) -> AnalyzeSqlBatchResponse:
    try:
        batch_response = analyze_sql_batch_response(request)
    except Exception as error:
        logger.exception("SQL batch analysis failed: %s", error)
        failure = HTTPException(
            status_code=500,
            detail=f"SQL batch analysis failed: {error}",
        )
        raise failure from error
    return batch_response
|
||||
|
||||
@app.post(
|
||||
"/semantic-layer/generate-sources", response_model=GenerateSourcesResponse
|
||||
)
|
||||
|
|
|
|||
158
python/ktx-daemon/src/ktx_daemon/sql_analysis.py
Normal file
158
python/ktx-daemon/src/ktx_daemon/sql_analysis.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Literal
|
||||
|
||||
import sqlglot
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlglot import exp
|
||||
|
||||
# Clause labels used as keys of AnalyzeSqlBatchResult.columns_by_clause.
SqlAnalysisClause = Literal[
    "select",
    "where",
    "join",
    "groupBy",
    "having",
    "orderBy",
]
|
||||
|
||||
|
||||
# One SQL statement submitted for analysis.
class AnalyzeSqlBatchItem(BaseModel):
    # Caller-chosen key; the response's results mapping is keyed by it.
    id: str
    # SQL text to parse.
    sql: str
|
||||
|
||||
|
||||
# Request body for a batch analysis: one dialect applied to every item.
class AnalyzeSqlBatchRequest(BaseModel):
    # Dialect name handed to sqlglot when parsing each item.
    dialect: str
    # Statements to analyze.
    items: list[AnalyzeSqlBatchItem]
    # Optional cap on worker processes (1-32); None lets the service decide.
    max_workers: int | None = Field(default=None, ge=1, le=32)
|
||||
|
||||
|
||||
# Analysis outcome for a single item. On a parse failure only `error` is set
# and the two collections keep their empty defaults.
class AnalyzeSqlBatchResult(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    # Table references found in the statement, de-duplicated in first-seen order.
    tables_touched: list[str] = Field(default_factory=list)
    # Column names grouped by the clause they appear in; empty clauses are omitted.
    columns_by_clause: dict[SqlAnalysisClause, list[str]] = Field(default_factory=dict)
    # Parser error message, or None when the statement parsed.
    error: str | None = None
|
||||
|
||||
|
||||
# Response body for a batch analysis.
class AnalyzeSqlBatchResponse(BaseModel):
    # Per-item results keyed by the item's `id`.
    results: dict[str, AnalyzeSqlBatchResult]
|
||||
|
||||
|
||||
def _ordered_unique(values: list[str]) -> list[str]:
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for value in values:
|
||||
if value and value not in seen:
|
||||
seen.add(value)
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
|
||||
def _table_ref(table: exp.Table) -> str:
|
||||
parts: list[str] = []
|
||||
catalog = table.args.get("catalog")
|
||||
db = table.args.get("db")
|
||||
if catalog is not None and getattr(catalog, "name", None):
|
||||
parts.append(str(catalog.name))
|
||||
if db is not None and getattr(db, "name", None):
|
||||
parts.append(str(db.name))
|
||||
if table.name:
|
||||
parts.append(str(table.name))
|
||||
return ".".join(parts)
|
||||
|
||||
|
||||
def _column_name(column: exp.Column) -> str:
|
||||
return str(column.name)
|
||||
|
||||
|
||||
def _columns_from_nodes(nodes: list[exp.Expression | None]) -> list[str]:
    """Collect the column names referenced anywhere under the given nodes.

    ``None`` entries are ignored. Names are de-duplicated and kept in
    first-seen order.
    """
    collected = [
        _column_name(column)
        for node in nodes
        if node is not None
        for column in node.find_all(exp.Column)
    ]
    return _ordered_unique(collected)
|
||||
|
||||
|
||||
def _columns_by_clause(tree: exp.Expression) -> dict[SqlAnalysisClause, list[str]]:
    """Group the column names used by *tree* under the clause they appear in.

    Only the top-level expression's own arguments are inspected: its
    ``expressions``, ``where``, the ``on`` condition of each entry in
    ``joins``, ``group``, ``having`` and ``order``. Clauses that yield no
    columns are omitted from the result.
    """
    group = tree.args.get("group")
    order = tree.args.get("order")

    # Listed in the order the keys should be inserted into the result.
    clause_nodes: list[tuple[SqlAnalysisClause, list[exp.Expression | None]]] = [
        ("select", list(tree.expressions)),
        ("where", [tree.args.get("where")]),
        ("join", [join.args.get("on") for join in tree.args.get("joins") or []]),
        ("groupBy", list(group.expressions) if group is not None else []),
        ("having", [tree.args.get("having")]),
        ("orderBy", list(order.expressions) if order is not None else []),
    ]

    columns_by_clause: dict[SqlAnalysisClause, list[str]] = {}
    for clause, nodes in clause_nodes:
        columns = _columns_from_nodes(nodes)
        if columns:
            columns_by_clause[clause] = columns
    return columns_by_clause
|
||||
|
||||
|
||||
def _analyze_one(
    item_id: str, sql: str, dialect: str
) -> tuple[str, AnalyzeSqlBatchResult]:
    """Parse one SQL statement and summarize the tables and columns it uses.

    Args:
        item_id: Caller-supplied key, returned unchanged alongside the result.
        sql: SQL text to parse.
        dialect: Dialect name passed to ``sqlglot.parse_one`` as ``read``.

    Returns:
        ``(item_id, result)``. When sqlglot raises a ``SqlglotError`` the
        result carries only the error message; otherwise it carries the
        de-duplicated table references (CTE references excluded) and the
        columns grouped by clause.
    """
    try:
        tree = sqlglot.parse_one(sql, read=dialect)
    except sqlglot.errors.SqlglotError as exc:
        return item_id, AnalyzeSqlBatchResult(error=str(exc))

    cte_names = {cte.alias_or_name.lower() for cte in tree.find_all(exp.CTE)}

    table_refs: list[str] = []
    for table in tree.find_all(exp.Table):
        table_ref = _table_ref(table)
        if not table_ref:
            continue
        # Only a bare, unqualified name can refer to a CTE. Matching on the
        # last dotted segment alone would also drop real tables such as
        # `public.orders` whenever a CTE happens to be called `orders`.
        # NOTE(review): an unqualified real table shadowed by a same-named CTE
        # is still skipped; telling them apart would need scope analysis.
        is_unqualified = table_ref == str(table.name)
        if is_unqualified and table_ref.lower() in cte_names:
            continue
        table_refs.append(table_ref)

    return item_id, AnalyzeSqlBatchResult(
        tables_touched=_ordered_unique(table_refs),
        columns_by_clause=_columns_by_clause(tree),
        error=None,
    )
|
||||
|
||||
|
||||
def _analyze_payload(payload: tuple[str, str, str]) -> tuple[str, AnalyzeSqlBatchResult]:
    """Unpack an ``(item_id, sql, dialect)`` tuple and analyze it.

    Single-argument adapter around ``_analyze_one`` so it can be used as the
    mapped function over a list of payload tuples.
    """
    (item_id, sql, dialect) = payload
    return _analyze_one(item_id=item_id, sql=sql, dialect=dialect)
|
||||
|
||||
|
||||
def _worker_count(request: AnalyzeSqlBatchRequest) -> int:
|
||||
if len(request.items) <= 1:
|
||||
return 1
|
||||
if request.max_workers is not None:
|
||||
return min(request.max_workers, len(request.items))
|
||||
return min(os.cpu_count() or 1, len(request.items), 8)
|
||||
|
||||
|
||||
def analyze_sql_batch_response(
    request: AnalyzeSqlBatchRequest,
) -> AnalyzeSqlBatchResponse:
    """Analyze every item in *request* and return the results keyed by item id.

    Items are analyzed in-process when a single worker is selected, and
    through a ``ProcessPoolExecutor`` otherwise. Parse failures are reported
    per item in the corresponding result rather than raised.
    """
    payloads = [(item.id, item.sql, request.dialect) for item in request.items]
    workers = _worker_count(request)

    if workers == 1:
        analyzed = list(map(_analyze_payload, payloads))
    else:
        with ProcessPoolExecutor(max_workers=workers) as executor:
            analyzed = list(executor.map(_analyze_payload, payloads))

    # Later entries win if two items share an id.
    return AnalyzeSqlBatchResponse(results=dict(analyzed))
|
||||
|
|
@ -280,6 +280,37 @@ def test_sql_parse_table_identifier_endpoint() -> None:
|
|||
assert body["results"]["template"]["reason"] == "looker_template_unresolved"
|
||||
|
||||
|
||||
def test_sql_analyze_batch_endpoint_returns_per_item_results() -> None:
    """A valid and an unparsable statement each get their own result entry."""
    valid_item = {
        "id": "orders",
        "sql": "select status from public.orders where created_at is not null",
    }
    broken_item = {"id": "broken", "sql": "select * from where"}
    payload = {
        "dialect": "postgres",
        "max_workers": 1,
        "items": [valid_item, broken_item],
    }

    response = TestClient(create_app()).post("/sql/analyze-batch", json=payload)

    assert response.status_code == 200
    results = response.json()["results"]

    orders = results["orders"]
    assert orders["tables_touched"] == ["public.orders"]
    assert orders["columns_by_clause"] == {
        "select": ["status"],
        "where": ["created_at"],
    }
    assert orders["error"] is None

    broken = results["broken"]
    assert broken["tables_touched"] == []
    assert broken["columns_by_clause"] == {}
    assert broken["error"] is not None
|
||||
|
||||
|
||||
def test_semantic_query_endpoint_returns_sql() -> None:
|
||||
client = TestClient(create_app())
|
||||
|
||||
|
|
|
|||
53
python/ktx-daemon/tests/test_sql_analysis.py
Normal file
53
python/ktx-daemon/tests/test_sql_analysis.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ktx_daemon.sql_analysis import (
|
||||
AnalyzeSqlBatchItem,
|
||||
AnalyzeSqlBatchRequest,
|
||||
analyze_sql_batch_response,
|
||||
)
|
||||
|
||||
|
||||
def test_analyze_sql_batch_extracts_tables_and_clause_columns() -> None:
    """A joined, filtered, grouped query reports its tables and per-clause columns."""
    query = (
        "select o.status, count(*) "
        "from public.orders o "
        "join public.customers c on o.customer_id = c.id "
        "where o.created_at >= current_date - interval '30 day' "
        "group by o.status"
    )
    request = AnalyzeSqlBatchRequest(
        dialect="postgres",
        items=[AnalyzeSqlBatchItem(id="orders_by_customer", sql=query)],
        max_workers=1,
    )

    response = analyze_sql_batch_response(request)

    result = response.results["orders_by_customer"]
    assert result.error is None
    assert result.tables_touched == ["public.orders", "public.customers"]
    assert result.columns_by_clause == {
        "select": ["status"],
        "where": ["created_at"],
        "join": ["customer_id", "id"],
        "groupBy": ["status"],
    }
|
||||
|
||||
|
||||
def test_analyze_sql_batch_returns_per_item_parse_errors() -> None:
    """An unparsable statement yields an error entry with empty collections."""
    broken_item = AnalyzeSqlBatchItem(id="broken", sql="select * from where")
    request = AnalyzeSqlBatchRequest(
        dialect="postgres",
        items=[broken_item],
        max_workers=1,
    )

    response = analyze_sql_batch_response(request)

    result = response.results["broken"]
    assert result.tables_touched == []
    assert result.columns_by_clause == {}
    assert result.error is not None
|
||||
Loading…
Add table
Add a link
Reference in a new issue