2026-05-11 16:56:50 +02:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from ktx_daemon.sql_analysis import (
|
|
|
|
|
AnalyzeSqlBatchItem,
|
|
|
|
|
AnalyzeSqlBatchRequest,
|
2026-05-15 02:35:09 +02:00
|
|
|
ValidateReadOnlySqlRequest,
|
2026-05-11 22:35:07 +02:00
|
|
|
_columns_from_nodes,
|
2026-05-11 16:56:50 +02:00
|
|
|
analyze_sql_batch_response,
|
2026-05-15 02:35:09 +02:00
|
|
|
validate_read_only_sql_response,
|
2026-05-11 16:56:50 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_analyze_sql_batch_extracts_tables_and_clause_columns() -> None:
|
|
|
|
|
response = analyze_sql_batch_response(
|
|
|
|
|
AnalyzeSqlBatchRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
items=[
|
|
|
|
|
AnalyzeSqlBatchItem(
|
|
|
|
|
id="orders_by_customer",
|
|
|
|
|
sql=(
|
|
|
|
|
"select o.status, count(*) "
|
|
|
|
|
"from public.orders o "
|
|
|
|
|
"join public.customers c on o.customer_id = c.id "
|
|
|
|
|
"where o.created_at >= current_date - interval '30 day' "
|
|
|
|
|
"group by o.status"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
],
|
|
|
|
|
max_workers=1,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
result = response.results["orders_by_customer"]
|
|
|
|
|
assert result.error is None
|
2026-06-03 17:19:42 +02:00
|
|
|
assert [item.model_dump() for item in result.tables_touched] == [
|
|
|
|
|
{"catalog": None, "db": "public", "name": "orders"},
|
|
|
|
|
{"catalog": None, "db": "public", "name": "customers"},
|
|
|
|
|
]
|
2026-05-11 16:56:50 +02:00
|
|
|
assert result.columns_by_clause == {
|
|
|
|
|
"select": ["status"],
|
|
|
|
|
"where": ["created_at"],
|
|
|
|
|
"join": ["customer_id", "id"],
|
|
|
|
|
"groupBy": ["status"],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_analyze_sql_batch_returns_per_item_parse_errors() -> None:
|
|
|
|
|
response = analyze_sql_batch_response(
|
|
|
|
|
AnalyzeSqlBatchRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
items=[AnalyzeSqlBatchItem(id="broken", sql="select * from where")],
|
|
|
|
|
max_workers=1,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
result = response.results["broken"]
|
|
|
|
|
assert result.tables_touched == []
|
|
|
|
|
assert result.columns_by_clause == {}
|
|
|
|
|
assert result.error is not None
|
2026-05-11 22:35:07 +02:00
|
|
|
|
|
|
|
|
|
2026-06-03 17:19:42 +02:00
|
|
|
def test_analyze_sql_batch_qualifies_bare_table_from_catalog() -> None:
|
|
|
|
|
response = analyze_sql_batch_response(
|
|
|
|
|
AnalyzeSqlBatchRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
catalog={
|
|
|
|
|
"tables": [
|
|
|
|
|
{
|
|
|
|
|
"catalog": None,
|
|
|
|
|
"db": "orbit_raw",
|
|
|
|
|
"name": "accounts",
|
|
|
|
|
"columns": ["id"],
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"catalog": None,
|
|
|
|
|
"db": "orbit_analytics",
|
|
|
|
|
"name": "orders",
|
|
|
|
|
"columns": ["id"],
|
|
|
|
|
},
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
items=[AnalyzeSqlBatchItem(id="bare", sql="select id from accounts")],
|
|
|
|
|
max_workers=1,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert [item.model_dump() for item in response.results["bare"].tables_touched] == [
|
|
|
|
|
{"catalog": None, "db": "orbit_raw", "name": "accounts"}
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_analyze_sql_batch_returns_all_ambiguous_modeled_matches() -> None:
|
|
|
|
|
response = analyze_sql_batch_response(
|
|
|
|
|
AnalyzeSqlBatchRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
catalog={
|
|
|
|
|
"tables": [
|
|
|
|
|
{
|
|
|
|
|
"catalog": None,
|
|
|
|
|
"db": "orbit_raw",
|
|
|
|
|
"name": "events",
|
|
|
|
|
"columns": ["id"],
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"catalog": None,
|
|
|
|
|
"db": "orbit_analytics",
|
|
|
|
|
"name": "events",
|
|
|
|
|
"columns": ["id"],
|
|
|
|
|
},
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
items=[AnalyzeSqlBatchItem(id="ambiguous", sql="select id from events")],
|
|
|
|
|
max_workers=1,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert [
|
|
|
|
|
item.model_dump() for item in response.results["ambiguous"].tables_touched
|
|
|
|
|
] == [
|
|
|
|
|
{"catalog": None, "db": "orbit_raw", "name": "events"},
|
|
|
|
|
{"catalog": None, "db": "orbit_analytics", "name": "events"},
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_analyze_sql_batch_leaves_unresolved_bare_refs_unqualified() -> None:
|
|
|
|
|
response = analyze_sql_batch_response(
|
|
|
|
|
AnalyzeSqlBatchRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
catalog={
|
|
|
|
|
"tables": [{"catalog": None, "db": "orbit_raw", "name": "accounts"}]
|
|
|
|
|
},
|
|
|
|
|
items=[AnalyzeSqlBatchItem(id="missing", sql="select * from invoices")],
|
|
|
|
|
max_workers=1,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert [
|
|
|
|
|
item.model_dump() for item in response.results["missing"].tables_touched
|
|
|
|
|
] == [{"catalog": None, "db": None, "name": "invoices"}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_analyze_sql_batch_returns_bigquery_project_dataset_table_refs() -> None:
|
|
|
|
|
response = analyze_sql_batch_response(
|
|
|
|
|
AnalyzeSqlBatchRequest(
|
|
|
|
|
dialect="bigquery",
|
|
|
|
|
catalog={
|
|
|
|
|
"tables": [
|
|
|
|
|
{
|
|
|
|
|
"catalog": "demo-project",
|
|
|
|
|
"db": "orbit_analytics",
|
|
|
|
|
"name": "orders",
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
items=[
|
|
|
|
|
AnalyzeSqlBatchItem(
|
|
|
|
|
id="bq",
|
|
|
|
|
sql="select * from `demo-project.orbit_analytics.orders`",
|
|
|
|
|
)
|
|
|
|
|
],
|
|
|
|
|
max_workers=1,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert [item.model_dump() for item in response.results["bq"].tables_touched] == [
|
|
|
|
|
{"catalog": "demo-project", "db": "orbit_analytics", "name": "orders"}
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
2026-05-11 22:35:07 +02:00
|
|
|
def test_columns_from_nodes_ignores_non_expression_clause_values() -> None:
|
|
|
|
|
assert _columns_from_nodes([True, False, None]) == []
|
2026-05-15 02:35:09 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_read_only_sql_accepts_select_and_with_queries() -> None:
|
|
|
|
|
select_response = validate_read_only_sql_response(
|
|
|
|
|
ValidateReadOnlySqlRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
sql="select id, status from public.orders where status = 'paid'",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
with_response = validate_read_only_sql_response(
|
|
|
|
|
ValidateReadOnlySqlRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
sql=(
|
|
|
|
|
"with paid as (select * from public.orders where status = 'paid') "
|
|
|
|
|
"select count(*) from paid"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert select_response.ok is True
|
|
|
|
|
assert select_response.error is None
|
|
|
|
|
assert with_response.ok is True
|
|
|
|
|
assert with_response.error is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_read_only_sql_rejects_cte_dml() -> None:
|
|
|
|
|
response = validate_read_only_sql_response(
|
|
|
|
|
ValidateReadOnlySqlRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
sql="with x as (insert into audit.events values (1) returning *) select * from x",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert response.ok is False
|
|
|
|
|
assert response.error == "SQL contains read/write operation: Insert"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_read_only_sql_rejects_multi_statement_payloads() -> None:
|
|
|
|
|
response = validate_read_only_sql_response(
|
|
|
|
|
ValidateReadOnlySqlRequest(
|
|
|
|
|
dialect="postgres",
|
|
|
|
|
sql="select * from public.orders; delete from public.orders",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert response.ok is False
|
|
|
|
|
assert response.error == "Only one SQL statement can be executed."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_read_only_sql_rejects_commands_and_pragmas() -> None:
|
|
|
|
|
command_response = validate_read_only_sql_response(
|
|
|
|
|
ValidateReadOnlySqlRequest(dialect="postgres", sql="call refresh_stats()")
|
|
|
|
|
)
|
|
|
|
|
pragma_response = validate_read_only_sql_response(
|
|
|
|
|
ValidateReadOnlySqlRequest(dialect="sqlite", sql="pragma table_info(users)")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert command_response.ok is False
|
|
|
|
|
assert command_response.error == "SQL contains read/write operation: Command"
|
|
|
|
|
assert pragma_response.ok is False
|
|
|
|
|
assert pragma_response.error == "SQL contains read/write operation: Pragma"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_read_only_sql_reports_parse_errors() -> None:
|
|
|
|
|
response = validate_read_only_sql_response(
|
|
|
|
|
ValidateReadOnlySqlRequest(dialect="postgres", sql="select * from where")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert response.ok is False
|
|
|
|
|
assert response.error is not None
|
|
|
|
|
assert "Invalid expression" in response.error
|