diff --git a/python/ktx-daemon/src/ktx_daemon/app.py b/python/ktx-daemon/src/ktx_daemon/app.py index 76325719..0d7016dd 100644 --- a/python/ktx-daemon/src/ktx_daemon/app.py +++ b/python/ktx-daemon/src/ktx_daemon/app.py @@ -51,7 +51,10 @@ from ktx_daemon.source_generation import ( from ktx_daemon.sql_analysis import ( AnalyzeSqlBatchRequest, AnalyzeSqlBatchResponse, + ValidateReadOnlySqlRequest, + ValidateReadOnlySqlResponse, analyze_sql_batch_response, + validate_read_only_sql_response, ) from ktx_daemon.table_identifier import ( ParseTableIdentifierBatchRequest, @@ -198,6 +201,19 @@ def create_app( detail=f"Table identifier parsing failed: {error}", ) from error + @app.post("/sql/validate-read-only", response_model=ValidateReadOnlySqlResponse) + async def sql_validate_read_only( + request: ValidateReadOnlySqlRequest, + ) -> ValidateReadOnlySqlResponse: + try: + return validate_read_only_sql_response(request) + except Exception as error: + logger.exception("SQL read-only validation failed: %s", error) + raise HTTPException( + status_code=500, + detail=f"SQL read-only validation failed: {error}", + ) from error + @app.post("/sql/analyze-batch", response_model=AnalyzeSqlBatchResponse) async def sql_analyze_batch( request: AnalyzeSqlBatchRequest, diff --git a/python/ktx-daemon/src/ktx_daemon/sql_analysis.py b/python/ktx-daemon/src/ktx_daemon/sql_analysis.py index d5deb240..ebecf83c 100644 --- a/python/ktx-daemon/src/ktx_daemon/sql_analysis.py +++ b/python/ktx-daemon/src/ktx_daemon/sql_analysis.py @@ -34,6 +34,46 @@ class AnalyzeSqlBatchResponse(BaseModel): results: dict[str, AnalyzeSqlBatchResult] +class ValidateReadOnlySqlRequest(BaseModel): + dialect: str + sql: str + + +class ValidateReadOnlySqlResponse(BaseModel): + ok: bool + error: str | None = None + + +_READ_ONLY_ROOT_TYPES = (exp.Select, exp.Union) +_READ_WRITE_NODE_TYPES = ( + exp.Alter, + exp.Analyze, + exp.Cache, + exp.Command, + exp.Commit, + exp.Copy, + exp.Create, + exp.Delete, + exp.Describe, + exp.Drop, + exp.Execute, + exp.Grant, + exp.Insert, + exp.Merge, + exp.Pragma, + exp.Refresh, + exp.Revoke, + exp.Rollback, + exp.Set, + exp.Show, + exp.Transaction, + exp.TruncateTable, + exp.Uncache, + exp.Update, + exp.Use, +) + + def _ordered_unique(values: list[str]) -> list[str]: seen: set[str] = set() result: list[str] = [] @@ -137,6 +177,42 @@ def _analyze_payload( return _analyze_one(item_id, sql, dialect) +def validate_read_only_sql_response( + request: ValidateReadOnlySqlRequest, +) -> ValidateReadOnlySqlResponse: + try: + statements = sqlglot.parse(request.sql, read=request.dialect) + except sqlglot.errors.SqlglotError as exc: + return ValidateReadOnlySqlResponse(ok=False, error=f"Invalid expression: {exc}") + + if len(statements) != 1: + return ValidateReadOnlySqlResponse( + ok=False, + error="Only one SQL statement can be executed.", + ) + + tree = statements[0] + if tree is None: + return ValidateReadOnlySqlResponse( + ok=False, + error="SQL did not parse to a statement.", + ) + if not isinstance(tree, _READ_ONLY_ROOT_TYPES): + return ValidateReadOnlySqlResponse( + ok=False, + error=f"SQL contains read/write operation: {type(tree).__name__}", + ) + + for node in tree.walk(): + if isinstance(node, _READ_WRITE_NODE_TYPES): + return ValidateReadOnlySqlResponse( + ok=False, + error=f"SQL contains read/write operation: {type(node).__name__}", + ) + + return ValidateReadOnlySqlResponse(ok=True, error=None) + + def _worker_count(request: AnalyzeSqlBatchRequest) -> int: if len(request.items) <= 1: return 1 diff --git a/python/ktx-daemon/tests/test_app.py b/python/ktx-daemon/tests/test_app.py index eb2c3d68..3c1ce18d 100644 --- a/python/ktx-daemon/tests/test_app.py +++ b/python/ktx-daemon/tests/test_app.py @@ -280,6 +280,30 @@ def test_sql_parse_table_identifier_endpoint() -> None: assert body["results"]["template"]["reason"] == "looker_template_unresolved" +def test_sql_validate_read_only_endpoint() -> None: + client = TestClient(create_app()) + + ok_response = client.post( + "/sql/validate-read-only", + json={"dialect": "postgres", "sql": "select * from public.orders"}, + ) + bad_response = client.post( + "/sql/validate-read-only", + json={ + "dialect": "postgres", + "sql": "with x as (insert into audit.events values (1) returning *) select * from x", + }, + ) + + assert ok_response.status_code == 200 + assert ok_response.json() == {"ok": True, "error": None} + assert bad_response.status_code == 200 + assert bad_response.json() == { + "ok": False, + "error": "SQL contains read/write operation: Insert", + } + + def test_sql_analyze_batch_endpoint_returns_per_item_results() -> None: client = TestClient(create_app()) diff --git a/python/ktx-daemon/tests/test_sql_analysis.py b/python/ktx-daemon/tests/test_sql_analysis.py index c1fc35f8..855d16fd 100644 --- a/python/ktx-daemon/tests/test_sql_analysis.py +++ b/python/ktx-daemon/tests/test_sql_analysis.py @@ -3,8 +3,10 @@ from __future__ import annotations from ktx_daemon.sql_analysis import ( AnalyzeSqlBatchItem, AnalyzeSqlBatchRequest, + ValidateReadOnlySqlRequest, _columns_from_nodes, analyze_sql_batch_response, + validate_read_only_sql_response, ) @@ -56,3 +58,74 @@ def test_analyze_sql_batch_returns_per_item_parse_errors() -> None: def test_columns_from_nodes_ignores_non_expression_clause_values() -> None: assert _columns_from_nodes([True, False, None]) == [] + + +def test_validate_read_only_sql_accepts_select_and_with_queries() -> None: + select_response = validate_read_only_sql_response( + ValidateReadOnlySqlRequest( + dialect="postgres", + sql="select id, status from public.orders where status = 'paid'", + ) + ) + with_response = validate_read_only_sql_response( + ValidateReadOnlySqlRequest( + dialect="postgres", + sql=( + "with paid as (select * from public.orders where status = 'paid') " + "select count(*) from paid" + ), + ) + ) + + assert select_response.ok is True + assert select_response.error is None + assert with_response.ok is True + assert with_response.error is None + + +def test_validate_read_only_sql_rejects_cte_dml() -> None: + response = validate_read_only_sql_response( + ValidateReadOnlySqlRequest( + dialect="postgres", + sql="with x as (insert into audit.events values (1) returning *) select * from x", + ) + ) + + assert response.ok is False + assert response.error == "SQL contains read/write operation: Insert" + + +def test_validate_read_only_sql_rejects_multi_statement_payloads() -> None: + response = validate_read_only_sql_response( + ValidateReadOnlySqlRequest( + dialect="postgres", + sql="select * from public.orders; delete from public.orders", + ) + ) + + assert response.ok is False + assert response.error == "Only one SQL statement can be executed." + + +def test_validate_read_only_sql_rejects_commands_and_pragmas() -> None: + command_response = validate_read_only_sql_response( + ValidateReadOnlySqlRequest(dialect="postgres", sql="call refresh_stats()") + ) + pragma_response = validate_read_only_sql_response( + ValidateReadOnlySqlRequest(dialect="sqlite", sql="pragma table_info(users)") + ) + + assert command_response.ok is False + assert command_response.error == "SQL contains read/write operation: Command" + assert pragma_response.ok is False + assert pragma_response.error == "SQL contains read/write operation: Pragma" + + +def test_validate_read_only_sql_reports_parse_errors() -> None: + response = validate_read_only_sql_response( + ValidateReadOnlySqlRequest(dialect="postgres", sql="select * from where") + ) + + assert response.ok is False + assert response.error is not None + assert "Invalid expression" in response.error