diff --git a/python/ktx-daemon/src/ktx_daemon/app.py b/python/ktx-daemon/src/ktx_daemon/app.py index 272b0c24..76325719 100644 --- a/python/ktx-daemon/src/ktx_daemon/app.py +++ b/python/ktx-daemon/src/ktx_daemon/app.py @@ -48,6 +48,11 @@ from ktx_daemon.source_generation import ( GenerateSourcesResponse, generate_sources_response, ) +from ktx_daemon.sql_analysis import ( + AnalyzeSqlBatchRequest, + AnalyzeSqlBatchResponse, + analyze_sql_batch_response, +) from ktx_daemon.table_identifier import ( ParseTableIdentifierBatchRequest, ParseTableIdentifierBatchResponse, @@ -193,6 +198,19 @@ def create_app( detail=f"Table identifier parsing failed: {error}", ) from error + @app.post("/sql/analyze-batch", response_model=AnalyzeSqlBatchResponse) + async def sql_analyze_batch( + request: AnalyzeSqlBatchRequest, + ) -> AnalyzeSqlBatchResponse: + try: + return analyze_sql_batch_response(request) + except Exception as error: + logger.exception("SQL batch analysis failed: %s", error) + raise HTTPException( + status_code=500, + detail=f"SQL batch analysis failed: {error}", + ) from error + @app.post( "/semantic-layer/generate-sources", response_model=GenerateSourcesResponse ) diff --git a/python/ktx-daemon/src/ktx_daemon/sql_analysis.py b/python/ktx-daemon/src/ktx_daemon/sql_analysis.py new file mode 100644 index 00000000..80e6d85b --- /dev/null +++ b/python/ktx-daemon/src/ktx_daemon/sql_analysis.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import os +from concurrent.futures import ProcessPoolExecutor +from typing import Literal + +import sqlglot +from pydantic import BaseModel, ConfigDict, Field +from sqlglot import exp + +SqlAnalysisClause = Literal["select", "where", "join", "groupBy", "having", "orderBy"] + + +class AnalyzeSqlBatchItem(BaseModel): + id: str + sql: str + + +class AnalyzeSqlBatchRequest(BaseModel): + dialect: str + items: list[AnalyzeSqlBatchItem] + max_workers: int | None = Field(default=None, ge=1, le=32) + + +class AnalyzeSqlBatchResult(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + tables_touched: list[str] = Field(default_factory=list) + columns_by_clause: dict[SqlAnalysisClause, list[str]] = Field(default_factory=dict) + error: str | None = None + + +class AnalyzeSqlBatchResponse(BaseModel): + results: dict[str, AnalyzeSqlBatchResult] + + +def _ordered_unique(values: list[str]) -> list[str]: + seen: set[str] = set() + result: list[str] = [] + for value in values: + if value and value not in seen: + seen.add(value) + result.append(value) + return result + + +def _table_ref(table: exp.Table) -> str: + parts: list[str] = [] + catalog = table.args.get("catalog") + db = table.args.get("db") + if catalog is not None and getattr(catalog, "name", None): + parts.append(str(catalog.name)) + if db is not None and getattr(db, "name", None): + parts.append(str(db.name)) + if table.name: + parts.append(str(table.name)) + return ".".join(parts) + + +def _column_name(column: exp.Column) -> str: + return str(column.name) + + +def _columns_from_nodes(nodes: list[exp.Expression | None]) -> list[str]: + names: list[str] = [] + for node in nodes: + if node is None: + continue + names.extend(_column_name(column) for column in node.find_all(exp.Column)) + return _ordered_unique(names) + + +def _columns_by_clause(tree: exp.Expression) -> dict[SqlAnalysisClause, list[str]]: + result: dict[SqlAnalysisClause, list[str]] = {} + + select_columns = _columns_from_nodes(list(tree.expressions)) + if select_columns: + result["select"] = select_columns + + where_columns = _columns_from_nodes([tree.args.get("where")]) + if where_columns: + result["where"] = where_columns + + join_columns = _columns_from_nodes( + [join.args.get("on") for join in tree.args.get("joins") or []] + ) + if join_columns: + result["join"] = join_columns + + group = tree.args.get("group") + group_columns = _columns_from_nodes( + list(group.expressions) if group is not None else [] + ) + if group_columns: + result["groupBy"] = group_columns + + having_columns = _columns_from_nodes([tree.args.get("having")]) + if having_columns: + result["having"] = having_columns + + order = tree.args.get("order") + order_columns = _columns_from_nodes( + list(order.expressions) if order is not None else [] + ) + if order_columns: + result["orderBy"] = order_columns + + return result + + +def _analyze_one( + item_id: str, sql: str, dialect: str +) -> tuple[str, AnalyzeSqlBatchResult]: + try: + tree = sqlglot.parse_one(sql, read=dialect) + except sqlglot.errors.SqlglotError as exc: + return item_id, AnalyzeSqlBatchResult(error=str(exc)) + + cte_names = {cte.alias_or_name.lower() for cte in tree.find_all(exp.CTE)} + table_refs = [ + table_ref + for table_ref in (_table_ref(table) for table in tree.find_all(exp.Table)) + if table_ref and table_ref.split(".")[-1].lower() not in cte_names + ] + + return item_id, AnalyzeSqlBatchResult( + tables_touched=_ordered_unique(table_refs), + columns_by_clause=_columns_by_clause(tree), + error=None, + ) + + +def _analyze_payload(payload: tuple[str, str, str]) -> tuple[str, AnalyzeSqlBatchResult]: + item_id, sql, dialect = payload + return _analyze_one(item_id, sql, dialect) + + +def _worker_count(request: AnalyzeSqlBatchRequest) -> int: + if len(request.items) <= 1: + return 1 + if request.max_workers is not None: + return min(request.max_workers, len(request.items)) + return min(os.cpu_count() or 1, len(request.items), 8) + + +def analyze_sql_batch_response( + request: AnalyzeSqlBatchRequest, +) -> AnalyzeSqlBatchResponse: + payloads = [(item.id, item.sql, request.dialect) for item in request.items] + if _worker_count(request) == 1: + analyzed = [_analyze_payload(payload) for payload in payloads] + else: + with ProcessPoolExecutor(max_workers=_worker_count(request)) as executor: + analyzed = list(executor.map(_analyze_payload, payloads)) + + return AnalyzeSqlBatchResponse( + results={item_id: result for item_id, result in analyzed} + ) diff --git a/python/ktx-daemon/tests/test_app.py b/python/ktx-daemon/tests/test_app.py index cd5c4f16..eb2c3d68 100644 --- a/python/ktx-daemon/tests/test_app.py +++ b/python/ktx-daemon/tests/test_app.py @@ -280,6 +280,37 @@ def test_sql_parse_table_identifier_endpoint() -> None: assert body["results"]["template"]["reason"] == "looker_template_unresolved" +def test_sql_analyze_batch_endpoint_returns_per_item_results() -> None: + client = TestClient(create_app()) + + response = client.post( + "/sql/analyze-batch", + json={ + "dialect": "postgres", + "max_workers": 1, + "items": [ + { + "id": "orders", + "sql": "select status from public.orders where created_at is not null", + }, + {"id": "broken", "sql": "select * from where"}, + ], + }, + ) + + assert response.status_code == 200 + body = response.json() + assert body["results"]["orders"]["tables_touched"] == ["public.orders"] + assert body["results"]["orders"]["columns_by_clause"] == { + "select": ["status"], + "where": ["created_at"], + } + assert body["results"]["orders"]["error"] is None + assert body["results"]["broken"]["tables_touched"] == [] + assert body["results"]["broken"]["columns_by_clause"] == {} + assert body["results"]["broken"]["error"] is not None + + def test_semantic_query_endpoint_returns_sql() -> None: client = TestClient(create_app()) diff --git a/python/ktx-daemon/tests/test_sql_analysis.py b/python/ktx-daemon/tests/test_sql_analysis.py new file mode 100644 index 00000000..ac800a09 --- /dev/null +++ b/python/ktx-daemon/tests/test_sql_analysis.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from ktx_daemon.sql_analysis import ( + AnalyzeSqlBatchItem, + AnalyzeSqlBatchRequest, + analyze_sql_batch_response, +) + + +def test_analyze_sql_batch_extracts_tables_and_clause_columns() -> None: + response = analyze_sql_batch_response( + AnalyzeSqlBatchRequest( + dialect="postgres", + items=[ + AnalyzeSqlBatchItem( + id="orders_by_customer", + sql=( + "select o.status, count(*) " + "from public.orders o " + "join public.customers c on o.customer_id = c.id " + "where o.created_at >= current_date - interval '30 day' " + "group by o.status" + ), + ) + ], + max_workers=1, + ) + ) + + result = response.results["orders_by_customer"] + assert result.error is None + assert result.tables_touched == ["public.orders", "public.customers"] + assert result.columns_by_clause == { + "select": ["status"], + "where": ["created_at"], + "join": ["customer_id", "id"], + "groupBy": ["status"], + } + + +def test_analyze_sql_batch_returns_per_item_parse_errors() -> None: + response = analyze_sql_batch_response( + AnalyzeSqlBatchRequest( + dialect="postgres", + items=[AnalyzeSqlBatchItem(id="broken", sql="select * from where")], + max_workers=1, + ) + ) + + result = response.results["broken"] + assert result.tables_touched == [] + assert result.columns_by_clause == {} + assert result.error is not None