feat(daemon): validate read-only SQL with sqlglot

This commit is contained in:
Andrey Avtomonov 2026-05-14 17:54:36 +02:00
parent de9f4d97e7
commit aa4431b295
4 changed files with 189 additions and 0 deletions

View file

@ -34,6 +34,46 @@ class AnalyzeSqlBatchResponse(BaseModel):
results: dict[str, AnalyzeSqlBatchResult]
class ValidateReadOnlySqlRequest(BaseModel):
dialect: str
sql: str
class ValidateReadOnlySqlResponse(BaseModel):
ok: bool
error: str | None = None
_READ_ONLY_ROOT_TYPES = (exp.Select, exp.Union)
_READ_WRITE_NODE_TYPES = (
exp.Alter,
exp.Analyze,
exp.Cache,
exp.Command,
exp.Commit,
exp.Copy,
exp.Create,
exp.Delete,
exp.Describe,
exp.Drop,
exp.Execute,
exp.Grant,
exp.Insert,
exp.Merge,
exp.Pragma,
exp.Refresh,
exp.Revoke,
exp.Rollback,
exp.Set,
exp.Show,
exp.Transaction,
exp.TruncateTable,
exp.Uncache,
exp.Update,
exp.Use,
)
def _ordered_unique(values: list[str]) -> list[str]:
seen: set[str] = set()
result: list[str] = []
@ -137,6 +177,42 @@ def _analyze_payload(
return _analyze_one(item_id, sql, dialect)
def validate_read_only_sql_response(
request: ValidateReadOnlySqlRequest,
) -> ValidateReadOnlySqlResponse:
try:
statements = sqlglot.parse(request.sql, read=request.dialect)
except sqlglot.errors.SqlglotError as exc:
return ValidateReadOnlySqlResponse(ok=False, error=f"Invalid expression: {exc}")
if len(statements) != 1:
return ValidateReadOnlySqlResponse(
ok=False,
error="Only one SQL statement can be executed.",
)
tree = statements[0]
if tree is None:
return ValidateReadOnlySqlResponse(
ok=False,
error="SQL did not parse to a statement.",
)
if not isinstance(tree, _READ_ONLY_ROOT_TYPES):
return ValidateReadOnlySqlResponse(
ok=False,
error=f"SQL contains read/write operation: {type(tree).__name__}",
)
for node in tree.walk():
if isinstance(node, _READ_WRITE_NODE_TYPES):
return ValidateReadOnlySqlResponse(
ok=False,
error=f"SQL contains read/write operation: {type(node).__name__}",
)
return ValidateReadOnlySqlResponse(ok=True, error=None)
def _worker_count(request: AnalyzeSqlBatchRequest) -> int:
if len(request.items) <= 1:
return 1