mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-16 08:25:14 +02:00
feat(daemon): validate read-only SQL with sqlglot
This commit is contained in:
parent
de9f4d97e7
commit
aa4431b295
4 changed files with 189 additions and 0 deletions
|
|
@ -51,7 +51,10 @@ from ktx_daemon.source_generation import (
|
|||
from ktx_daemon.sql_analysis import (
|
||||
AnalyzeSqlBatchRequest,
|
||||
AnalyzeSqlBatchResponse,
|
||||
ValidateReadOnlySqlRequest,
|
||||
ValidateReadOnlySqlResponse,
|
||||
analyze_sql_batch_response,
|
||||
validate_read_only_sql_response,
|
||||
)
|
||||
from ktx_daemon.table_identifier import (
|
||||
ParseTableIdentifierBatchRequest,
|
||||
|
|
@ -198,6 +201,19 @@ def create_app(
|
|||
detail=f"Table identifier parsing failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post("/sql/validate-read-only", response_model=ValidateReadOnlySqlResponse)
|
||||
async def sql_validate_read_only(
|
||||
request: ValidateReadOnlySqlRequest,
|
||||
) -> ValidateReadOnlySqlResponse:
|
||||
try:
|
||||
return validate_read_only_sql_response(request)
|
||||
except Exception as error:
|
||||
logger.exception("SQL read-only validation failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"SQL read-only validation failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post("/sql/analyze-batch", response_model=AnalyzeSqlBatchResponse)
|
||||
async def sql_analyze_batch(
|
||||
request: AnalyzeSqlBatchRequest,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,46 @@ class AnalyzeSqlBatchResponse(BaseModel):
|
|||
results: dict[str, AnalyzeSqlBatchResult]
|
||||
|
||||
|
||||
class ValidateReadOnlySqlRequest(BaseModel):
|
||||
dialect: str
|
||||
sql: str
|
||||
|
||||
|
||||
class ValidateReadOnlySqlResponse(BaseModel):
|
||||
ok: bool
|
||||
error: str | None = None
|
||||
|
||||
|
||||
_READ_ONLY_ROOT_TYPES = (exp.Select, exp.Union)
|
||||
_READ_WRITE_NODE_TYPES = (
|
||||
exp.Alter,
|
||||
exp.Analyze,
|
||||
exp.Cache,
|
||||
exp.Command,
|
||||
exp.Commit,
|
||||
exp.Copy,
|
||||
exp.Create,
|
||||
exp.Delete,
|
||||
exp.Describe,
|
||||
exp.Drop,
|
||||
exp.Execute,
|
||||
exp.Grant,
|
||||
exp.Insert,
|
||||
exp.Merge,
|
||||
exp.Pragma,
|
||||
exp.Refresh,
|
||||
exp.Revoke,
|
||||
exp.Rollback,
|
||||
exp.Set,
|
||||
exp.Show,
|
||||
exp.Transaction,
|
||||
exp.TruncateTable,
|
||||
exp.Uncache,
|
||||
exp.Update,
|
||||
exp.Use,
|
||||
)
|
||||
|
||||
|
||||
def _ordered_unique(values: list[str]) -> list[str]:
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
|
|
@ -137,6 +177,42 @@ def _analyze_payload(
|
|||
return _analyze_one(item_id, sql, dialect)
|
||||
|
||||
|
||||
def validate_read_only_sql_response(
|
||||
request: ValidateReadOnlySqlRequest,
|
||||
) -> ValidateReadOnlySqlResponse:
|
||||
try:
|
||||
statements = sqlglot.parse(request.sql, read=request.dialect)
|
||||
except sqlglot.errors.SqlglotError as exc:
|
||||
return ValidateReadOnlySqlResponse(ok=False, error=f"Invalid expression: {exc}")
|
||||
|
||||
if len(statements) != 1:
|
||||
return ValidateReadOnlySqlResponse(
|
||||
ok=False,
|
||||
error="Only one SQL statement can be executed.",
|
||||
)
|
||||
|
||||
tree = statements[0]
|
||||
if tree is None:
|
||||
return ValidateReadOnlySqlResponse(
|
||||
ok=False,
|
||||
error="SQL did not parse to a statement.",
|
||||
)
|
||||
if not isinstance(tree, _READ_ONLY_ROOT_TYPES):
|
||||
return ValidateReadOnlySqlResponse(
|
||||
ok=False,
|
||||
error=f"SQL contains read/write operation: {type(tree).__name__}",
|
||||
)
|
||||
|
||||
for node in tree.walk():
|
||||
if isinstance(node, _READ_WRITE_NODE_TYPES):
|
||||
return ValidateReadOnlySqlResponse(
|
||||
ok=False,
|
||||
error=f"SQL contains read/write operation: {type(node).__name__}",
|
||||
)
|
||||
|
||||
return ValidateReadOnlySqlResponse(ok=True, error=None)
|
||||
|
||||
|
||||
def _worker_count(request: AnalyzeSqlBatchRequest) -> int:
|
||||
if len(request.items) <= 1:
|
||||
return 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue