mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
feat(query-history): scope mining to modeled schemas by default (#258)
* feat(query-history): structure SQL analysis table refs
* feat(query-history): qualify SQL analysis table refs
* feat(query-history): wire modeled scope floor through ingest
* chore(query-history): verify scope floor
* test(query-history): align daemon SQL batch endpoint contract
* feat(query-history): build scope from same-run scan catalog
* feat(query-history): fail open on scope-floor catalog failures
* chore(query-history): verify scope-floor v1 closure
* refactor(query-history): share scope membership
* feat(setup): apply derived query history filters
* docs: document derived query history filters
* fix(query-history): redact filter picker LLM prompt SQL
* fix(setup): run filter picker SQL analysis through managed daemon
* chore(query-history): verify filter picker v1 closure
* fix(query-history): fail open on partial service-account attribution
* fix(query-history): aggregate BigQuery users by execution count
* fix(query-history): aggregate Snowflake users by execution count
* fix(query-history): use BigQuery query info hash
This commit is contained in:
parent
ce1516b357
commit
e70ae1e63b
42 changed files with 3090 additions and 274 deletions
|
|
@ -2,15 +2,32 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import sqlglot
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlglot import exp
|
||||
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
|
||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||
|
||||
SqlAnalysisClause = Literal["select", "where", "join", "groupBy", "having", "orderBy"]
|
||||
|
||||
|
||||
class SqlAnalysisTableRef(BaseModel):
    """Structured reference to a table: optional qualifiers plus a name."""

    # Outermost qualifier; None when the reference does not carry one.
    catalog: str | None = None
    # Middle qualifier; None when the reference does not carry one.
    db: str | None = None
    # Table name; the only required part of the reference.
    name: str
|
||||
|
||||
|
||||
class SqlAnalysisCatalogTable(SqlAnalysisTableRef):
    """A table reference extended with the table's column names."""

    # Column names for the table; defaults to an empty list.
    columns: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AnalyzeSqlCatalog(BaseModel):
    """Collection of catalog tables supplied alongside a SQL analysis request."""

    # Catalog entries; defaults to an empty list.
    tables: list[SqlAnalysisCatalogTable] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AnalyzeSqlBatchItem(BaseModel):
    """One SQL statement in a batch, tagged with a caller-chosen identifier."""

    # Identifier paired with this item's result.
    id: str
    # SQL text to analyze.
    sql: str
|
||||
|
|
@ -19,11 +36,12 @@ class AnalyzeSqlBatchItem(BaseModel):
|
|||
class AnalyzeSqlBatchRequest(BaseModel):
    """Request payload for analyzing a batch of SQL statements."""

    # SQL dialect name used when analyzing every item in the batch.
    dialect: str
    # Statements to analyze.
    items: list[AnalyzeSqlBatchItem]
    # Optional table catalog; None means no catalog was supplied.
    catalog: AnalyzeSqlCatalog | None = None
    # Optional worker count; validated to lie within 1..32 when given.
    max_workers: int | None = Field(default=None, ge=1, le=32)
|
||||
|
||||
|
||||
class AnalyzeSqlBatchResult(BaseModel):
    """Per-item outcome of SQL analysis."""

    # Structured references to the tables the statement touches.
    # (The superseded ``list[str]`` declaration of this field is removed; a
    # second declaration of the same name would override it anyway.)
    tables_touched: list[SqlAnalysisTableRef] = Field(default_factory=list)
    # Column names grouped by the clause in which they appear.
    columns_by_clause: dict[SqlAnalysisClause, list[str]] = Field(default_factory=dict)
    # Error text for the item; None when no error was recorded.
    error: str | None = None
|
||||
|
||||
|
|
@ -82,17 +100,76 @@ def _ordered_unique(values: list[str]) -> list[str]:
|
|||
return result
|
||||
|
||||
|
||||
def _table_ref(table: exp.Table) -> str:
|
||||
parts: list[str] = []
|
||||
def _normalize_identifier(value: str | None, dialect: str) -> str | None:
    """Return ``value`` normalized as an identifier for ``dialect``.

    ``None`` is passed through unchanged. Otherwise the string is wrapped in a
    sqlglot identifier, normalized with ``normalize_identifiers`` and the
    resulting name is returned as a plain ``str``.
    """
    if value is None:
        return None
    ident = exp.to_identifier(value)
    # Mark the identifier as a table name before normalizing.
    # NOTE(review): presumably this lets dialects with table-specific case
    # rules apply them - confirm against the sqlglot version in use.
    ident.meta["is_table"] = True
    return str(normalize_identifiers(ident, dialect=dialect).name)
|
||||
|
||||
|
||||
def _normalized_ref(ref: SqlAnalysisTableRef, dialect: str) -> SqlAnalysisTableRef:
    """Build a new table reference whose parts are normalized for ``dialect``.

    Missing ``catalog``/``db`` parts stay ``None``. If normalizing the name
    yields an empty value, the original ``ref.name`` is kept.
    """
    catalog = _normalize_identifier(ref.catalog, dialect)
    db = _normalize_identifier(ref.db, dialect)
    name = _normalize_identifier(ref.name, dialect)
    return SqlAnalysisTableRef(catalog=catalog, db=db, name=name or ref.name)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _CatalogIndex:
    """Immutable lookup tables over the normalized catalog entries."""

    # Exact lookup keyed by the (catalog, db, name) triple.
    by_full: dict[tuple[str | None, str | None, str], SqlAnalysisTableRef]
    # Every catalog entry sharing a given table name.
    by_name: dict[str, list[SqlAnalysisTableRef]]
|
||||
|
||||
|
||||
def _catalog_index(
    catalog: AnalyzeSqlCatalog | None, dialect: str
) -> _CatalogIndex | None:
    """Index the catalog's tables by full key and by bare name.

    Returns ``None`` when no catalog is given or it lists no tables. Each
    entry is normalized for ``dialect`` before being indexed; a later entry
    with the same (catalog, db, name) key replaces the earlier one in the
    exact-match index.
    """
    if catalog is None or not catalog.tables:
        return None

    full_lookup: dict[tuple[str | None, str | None, str], SqlAnalysisTableRef] = {}
    name_lookup: dict[str, list[SqlAnalysisTableRef]] = {}
    for entry in catalog.tables:
        normalized = _normalized_ref(entry, dialect)
        full_lookup[(normalized.catalog, normalized.db, normalized.name)] = normalized
        name_lookup.setdefault(normalized.name, []).append(normalized)
    return _CatalogIndex(by_full=full_lookup, by_name=name_lookup)
|
||||
|
||||
|
||||
def _raw_table_ref(table: exp.Table, dialect: str) -> SqlAnalysisTableRef | None:
    """Convert a sqlglot table node into a normalized structured reference.

    Returns ``None`` when the node has no table name. The ``catalog`` and
    ``db`` parts are taken from the node's args and are ``None`` when absent
    or unnamed.

    The leftover statements from the previous string-building implementation
    (appending to an undefined ``parts`` list and returning a dotted string)
    are removed; they made the structured return below unreachable.
    """
    if not table.name:
        return None
    catalog = table.args.get("catalog")
    db = table.args.get("db")
    return _normalized_ref(
        SqlAnalysisTableRef(
            catalog=str(catalog.name)
            if catalog is not None and getattr(catalog, "name", None)
            else None,
            db=str(db.name) if db is not None and getattr(db, "name", None) else None,
            name=str(table.name),
        ),
        dialect,
    )
|
||||
|
||||
|
||||
def _resolve_table_refs(
    raw: SqlAnalysisTableRef,
    catalog: _CatalogIndex | None,
) -> list[SqlAnalysisTableRef]:
    """Resolve a raw table reference against the catalog index.

    Resolution order:
    1. No catalog index: return the raw reference as-is.
    2. Exact (catalog, db, name) hit: return that catalog entry.
    3. The reference carries a ``db`` but had no exact hit: return it as-is.
    4. Bare name: return every catalog entry sharing that name, or a fully
       unqualified reference when there is none.
    """
    if catalog is None:
        return [raw]

    exact_match = catalog.by_full.get((raw.catalog, raw.db, raw.name))
    if exact_match is not None:
        return [exact_match]

    if raw.db is not None:
        return [raw]

    candidates = catalog.by_name.get(raw.name, [])
    return candidates or [SqlAnalysisTableRef(catalog=None, db=None, name=raw.name)]
|
||||
|
||||
|
||||
def _column_name(column: exp.Column) -> str:
|
||||
|
|
@ -146,33 +223,48 @@ def _columns_by_clause(tree: exp.Expression) -> dict[SqlAnalysisClause, list[str
|
|||
return result
|
||||
|
||||
|
||||
def _table_refs(
    tree: exp.Expression, dialect: str, catalog: _CatalogIndex | None
) -> list[SqlAnalysisTableRef]:
    """Collect the distinct tables referenced by ``tree``, in first-seen order.

    The tree is passed through ``normalize_identifiers`` and ``qualify_tables``
    and every table node is then converted to a structured reference and
    resolved against ``catalog``. References to CTEs defined in the statement
    are excluded.

    NOTE(review): the sqlglot passes are applied to ``tree`` itself rather than
    to a copy - confirm whether callers rely on (or are affected by) the tree
    being rewritten.
    """
    normalized_tree = normalize_identifiers(tree, dialect=dialect)
    qualified_tree = qualify_tables(normalized_tree, dialect=dialect)
    cte_names = {cte.alias_or_name.lower() for cte in qualified_tree.find_all(exp.CTE)}
    refs: list[SqlAnalysisTableRef] = []
    seen: set[tuple[str | None, str | None, str]] = set()
    for table in qualified_tree.find_all(exp.Table):
        raw = _raw_table_ref(table, dialect)
        if raw is None:
            continue
        # Only an unqualified name can refer to a CTE. A schema-qualified table
        # that merely shares a CTE's name (e.g. ``public.orders`` inside
        # ``WITH orders AS (...)``) is a real table and must be reported.
        is_unqualified = raw.catalog is None and raw.db is None
        if is_unqualified and table.name.lower() in cte_names:
            continue
        for ref in _resolve_table_refs(raw, catalog):
            key = (ref.catalog, ref.db, ref.name)
            if key not in seen:
                seen.add(key)
                refs.append(ref)
    return refs
|
||||
|
||||
|
||||
def _analyze_one(
    item_id: str, sql: str, dialect: str, catalog: _CatalogIndex | None
) -> tuple[str, AnalyzeSqlBatchResult]:
    """Analyze one SQL statement and return ``(item_id, result)``.

    A statement that sqlglot cannot parse yields a result carrying only the
    error text. Otherwise the result holds the structured table references and
    the columns grouped by clause.

    The superseded three-argument signature, the string-based table collection
    and the duplicated ``tables_touched`` keyword are removed.
    """
    try:
        tree = sqlglot.parse_one(sql, read=dialect)
    except sqlglot.errors.SqlglotError as exc:
        return item_id, AnalyzeSqlBatchResult(error=str(exc))

    # Keyword arguments are evaluated in order, so the table pass runs on the
    # tree before the column pass does.
    return item_id, AnalyzeSqlBatchResult(
        tables_touched=_table_refs(tree, dialect, catalog),
        columns_by_clause=_columns_by_clause(tree),
        error=None,
    )
|
||||
|
||||
|
||||
def _analyze_payload(
    payload: tuple[str, str, str, _CatalogIndex | None],
) -> tuple[str, AnalyzeSqlBatchResult]:
    """Unpack an ``(item_id, sql, dialect, catalog)`` tuple and analyze it.

    Single-argument adapter around ``_analyze_one``. The superseded
    three-element parameter, unpacking and call are removed.
    """
    item_id, sql, dialect, catalog = payload
    return _analyze_one(item_id, sql, dialect, catalog)
|
||||
|
||||
|
||||
def validate_read_only_sql_response(
|
||||
|
|
@ -222,7 +314,8 @@ def _worker_count(request: AnalyzeSqlBatchRequest) -> int:
|
|||
def analyze_sql_batch_response(
|
||||
request: AnalyzeSqlBatchRequest,
|
||||
) -> AnalyzeSqlBatchResponse:
|
||||
payloads = [(item.id, item.sql, request.dialect) for item in request.items]
|
||||
catalog = _catalog_index(request.catalog, request.dialect)
|
||||
payloads = [(item.id, item.sql, request.dialect, catalog) for item in request.items]
|
||||
if _worker_count(request) == 1:
|
||||
analyzed = [_analyze_payload(payload) for payload in payloads]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -368,7 +368,9 @@ def test_sql_analyze_batch_endpoint_returns_per_item_results() -> None:
|
|||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["results"]["orders"]["tables_touched"] == ["public.orders"]
|
||||
assert body["results"]["orders"]["tables_touched"] == [
|
||||
{"catalog": None, "db": "public", "name": "orders"}
|
||||
]
|
||||
assert body["results"]["orders"]["columns_by_clause"] == {
|
||||
"select": ["status"],
|
||||
"where": ["created_at"],
|
||||
|
|
|
|||
|
|
@ -32,7 +32,10 @@ def test_analyze_sql_batch_extracts_tables_and_clause_columns() -> None:
|
|||
|
||||
result = response.results["orders_by_customer"]
|
||||
assert result.error is None
|
||||
assert result.tables_touched == ["public.orders", "public.customers"]
|
||||
assert [item.model_dump() for item in result.tables_touched] == [
|
||||
{"catalog": None, "db": "public", "name": "orders"},
|
||||
{"catalog": None, "db": "public", "name": "customers"},
|
||||
]
|
||||
assert result.columns_by_clause == {
|
||||
"select": ["status"],
|
||||
"where": ["created_at"],
|
||||
|
|
@ -56,6 +59,114 @@ def test_analyze_sql_batch_returns_per_item_parse_errors() -> None:
|
|||
assert result.error is not None
|
||||
|
||||
|
||||
def test_analyze_sql_batch_qualifies_bare_table_from_catalog() -> None:
    """A bare table name matching one catalog entry is reported with that entry's db."""
    modeled_tables = [
        {"catalog": None, "db": "orbit_raw", "name": "accounts", "columns": ["id"]},
        {"catalog": None, "db": "orbit_analytics", "name": "orders", "columns": ["id"]},
    ]
    request = AnalyzeSqlBatchRequest(
        dialect="postgres",
        catalog={"tables": modeled_tables},
        items=[AnalyzeSqlBatchItem(id="bare", sql="select id from accounts")],
        max_workers=1,
    )

    touched = analyze_sql_batch_response(request).results["bare"].tables_touched

    assert [ref.model_dump() for ref in touched] == [
        {"catalog": None, "db": "orbit_raw", "name": "accounts"}
    ]
|
||||
|
||||
|
||||
def test_analyze_sql_batch_returns_all_ambiguous_modeled_matches() -> None:
    """A bare name shared by several catalog entries yields all of them, in catalog order."""
    modeled_tables = [
        {"catalog": None, "db": "orbit_raw", "name": "events", "columns": ["id"]},
        {"catalog": None, "db": "orbit_analytics", "name": "events", "columns": ["id"]},
    ]
    request = AnalyzeSqlBatchRequest(
        dialect="postgres",
        catalog={"tables": modeled_tables},
        items=[AnalyzeSqlBatchItem(id="ambiguous", sql="select id from events")],
        max_workers=1,
    )

    touched = analyze_sql_batch_response(request).results["ambiguous"].tables_touched

    assert [ref.model_dump() for ref in touched] == [
        {"catalog": None, "db": "orbit_raw", "name": "events"},
        {"catalog": None, "db": "orbit_analytics", "name": "events"},
    ]
|
||||
|
||||
|
||||
def test_analyze_sql_batch_leaves_unresolved_bare_refs_unqualified() -> None:
    """A bare name absent from the catalog is reported without catalog or db."""
    request = AnalyzeSqlBatchRequest(
        dialect="postgres",
        catalog={"tables": [{"catalog": None, "db": "orbit_raw", "name": "accounts"}]},
        items=[AnalyzeSqlBatchItem(id="missing", sql="select * from invoices")],
        max_workers=1,
    )

    touched = analyze_sql_batch_response(request).results["missing"].tables_touched

    assert [ref.model_dump() for ref in touched] == [
        {"catalog": None, "db": None, "name": "invoices"}
    ]
|
||||
|
||||
|
||||
def test_analyze_sql_batch_returns_bigquery_project_dataset_table_refs() -> None:
    """A backtick-quoted BigQuery path maps to catalog/db/name parts."""
    modeled_tables = [
        {"catalog": "demo-project", "db": "orbit_analytics", "name": "orders"}
    ]
    query = AnalyzeSqlBatchItem(
        id="bq",
        sql="select * from `demo-project.orbit_analytics.orders`",
    )
    request = AnalyzeSqlBatchRequest(
        dialect="bigquery",
        catalog={"tables": modeled_tables},
        items=[query],
        max_workers=1,
    )

    touched = analyze_sql_batch_response(request).results["bq"].tables_touched

    assert [ref.model_dump() for ref in touched] == [
        {"catalog": "demo-project", "db": "orbit_analytics", "name": "orders"}
    ]
|
||||
|
||||
|
||||
def test_columns_from_nodes_ignores_non_expression_clause_values() -> None:
    """Booleans and ``None`` among the clause values contribute no columns."""
    non_expression_values = [True, False, None]
    assert _columns_from_nodes(non_expression_values) == []
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue