diff --git a/packages/cli/src/connectors/sqlserver/connector.ts b/packages/cli/src/connectors/sqlserver/connector.ts index 116fdea7..b70d6a8c 100644 --- a/packages/cli/src/connectors/sqlserver/connector.ts +++ b/packages/cli/src/connectors/sqlserver/connector.ts @@ -1,4 +1,4 @@ -import { assertReadOnlySql, stripTrailingSqlNoise } from '../../context/connections/read-only-sql.js'; +import { assertReadOnlySql, hoistLeadingCte, stripTrailingSqlNoise } from '../../context/connections/read-only-sql.js'; import { getDialectForDriver } from '../../context/connections/dialects.js'; import { tryConstraintQuery } from '../../context/scan/constraint-discovery.js'; import { scopedTableNames } from '../../context/scan/table-ref.js'; @@ -277,7 +277,8 @@ function limitSqlForSqlServerExecution(sqlText: string, maxRows: number | undefi if (!Number.isInteger(maxRows) || maxRows <= 0) { throw new Error('maxRows must be a positive integer.'); } - return `SELECT TOP ${maxRows} * FROM (${trimmed}) AS ktx_query_result`; + const { withPrefix, body } = hoistLeadingCte(trimmed); + return `${withPrefix}SELECT TOP ${maxRows} * FROM (${body}) AS ktx_query_result`; } export function isKtxSqlServerConnectionConfig( diff --git a/packages/cli/src/context/connections/read-only-sql.ts b/packages/cli/src/context/connections/read-only-sql.ts index 1bde80b1..d9b5c6df 100644 --- a/packages/cli/src/context/connections/read-only-sql.ts +++ b/packages/cli/src/context/connections/read-only-sql.ts @@ -97,6 +97,161 @@ export function assertReadOnlySql(sql: string): string { return trimmed; } +function isSqlIdentifierPart(char: string | undefined): boolean { + return char !== undefined && /[A-Za-z0-9_$]/.test(char); +} + +function keywordAt(sql: string, index: number, keyword: string): boolean { + if (sql.slice(index, index + keyword.length).toLowerCase() !== keyword.toLowerCase()) { + return false; + } + return !isSqlIdentifierPart(sql[index - 1]) && !isSqlIdentifierPart(sql[index + keyword.length]); +} + +function skipWhitespaceAndComments(sql: string, index: number): number { + let current = index; + while (current < sql.length) { + while (/\s/.test(sql[current] ?? '')) { + current += 1; + } + if (sql.startsWith('--', current) || sql.startsWith('/*', current)) { + current = skipQuotedOrComment(sql, current); + continue; + } + return current; + } + return current; +} + +function skipBracketIdentifier(sql: string, index: number): number { + let current = index + 1; + while (current < sql.length) { + if (sql[current] === ']') { + if (sql[current + 1] === ']') { + current += 2; + continue; + } + return current + 1; + } + current += 1; + } + return -1; +} + +function skipBacktickIdentifier(sql: string, index: number): number { + let current = index + 1; + while (current < sql.length) { + if (sql[current] === '`') { + if (sql[current + 1] === '`') { + current += 2; + continue; + } + return current + 1; + } + current += 1; + } + return -1; +} + +function skipIdentifier(sql: string, index: number): number { + if (sql[index] === '"') { + const skipped = skipQuotedOrComment(sql, index); + return skipped > index ? skipped : -1; + } + if (sql[index] === '[') { + return skipBracketIdentifier(sql, index); + } + if (sql[index] === '`') { + return skipBacktickIdentifier(sql, index); + } + let current = index; + while (isSqlIdentifierPart(sql[current])) { + current += 1; + } + return current > index ? current : -1; +} + +function skipBalancedParentheses(sql: string, index: number): number { + if (sql[index] !== '(') { + return -1; + } + + let current = index; + let depth = 0; + while (current < sql.length) { + const skipped = skipQuotedOrComment(sql, current); + if (skipped > current) { + current = skipped; + continue; + } + + if (sql[current] === '(') { + depth += 1; + } else if (sql[current] === ')') { + depth -= 1; + if (depth === 0) { + return current + 1; + } + } + current += 1; + } + + return -1; +} + +/** @internal */ +export function hoistLeadingCte(sql: string): { withPrefix: string; body: string } { + const trimmed = sql.trim(); + if (!keywordAt(trimmed, 0, 'with')) { + return { withPrefix: '', body: sql }; + } + + let current = skipWhitespaceAndComments(trimmed, 4); + if (keywordAt(trimmed, current, 'recursive')) { + current = skipWhitespaceAndComments(trimmed, current + 'recursive'.length); + } + + while (current < trimmed.length) { + current = skipIdentifier(trimmed, current); + if (current < 0) { + return { withPrefix: '', body: trimmed }; + } + + current = skipWhitespaceAndComments(trimmed, current); + if (trimmed[current] === '(') { + current = skipBalancedParentheses(trimmed, current); + if (current < 0) { + return { withPrefix: '', body: trimmed }; + } + current = skipWhitespaceAndComments(trimmed, current); + } + + if (!keywordAt(trimmed, current, 'as')) { + return { withPrefix: '', body: trimmed }; + } + + current = skipWhitespaceAndComments(trimmed, current + 2); + current = skipBalancedParentheses(trimmed, current); + if (current < 0) { + return { withPrefix: '', body: trimmed }; + } + + current = skipWhitespaceAndComments(trimmed, current); + if (trimmed[current] === ',') { + current = skipWhitespaceAndComments(trimmed, current + 1); + continue; + } + + const body = trimmed.slice(current).trimStart(); + if (!body) { + return { withPrefix: '', body: trimmed }; + } + return { withPrefix: `${trimmed.slice(0, current).trimEnd()} `, body }; + } + + return { withPrefix: '', body: trimmed }; +} + // `assertReadOnlySql` deliberately keeps trailing semicolons, comments, and // whitespace (e.g. `select 1; -- done`) — harmless for direct single-statement // execution. A row-limit subquery wrapper needs a bare expression instead: a @@ -137,5 +292,6 @@ export function limitSqlForExecution(sql: string, maxRows: number | undefined): if (!Number.isInteger(maxRows) || maxRows <= 0) { throw new KtxQueryError('maxRows must be a positive integer.'); } - return `select * from (${trimmed}) as ktx_query_result limit ${maxRows}`; + const { withPrefix, body } = hoistLeadingCte(trimmed); + return `${withPrefix}select * from (${body}) as ktx_query_result limit ${maxRows}`; } diff --git a/packages/cli/test/connectors/sqlserver/connector.test.ts b/packages/cli/test/connectors/sqlserver/connector.test.ts index b7318ab5..25184a5a 100644 --- a/packages/cli/test/connectors/sqlserver/connector.test.ts +++ b/packages/cli/test/connectors/sqlserver/connector.test.ts @@ -404,6 +404,50 @@ describe('KtxSqlServerScanConnector', () => { await connector.cleanup(); }); + it('hoists leading CTEs before applying the SQL Server TOP wrapper', async () => { + const queries: string[] = []; + const request = { + input: vi.fn((_name: string, _value: unknown) => request), + query: vi.fn(async (sql: string): Promise => { + queries.push(sql); + return result([{ value: 1 }], ['value']); + }), + }; + const poolFactory: KtxSqlServerPoolFactory = { + createPool: vi.fn(async () => ({ + request: () => request, + close: vi.fn(async () => undefined), + })), + }; + const connector = new KtxSqlServerScanConnector({ + connectionId: 'warehouse', + connection: { + driver: 'sqlserver', + host: 'db.example.test', + database: 'analytics', + username: 'reader', + schema: 'dbo', + }, + poolFactory, + }); + + await expect( + connector.executeReadOnly( + { + connectionId: 'warehouse', + sql: 'WITH child_values AS (SELECT 1 AS value) SELECT value FROM child_values', + maxRows: 1, + }, + { runId: 'scan-run-sqlserver-cte-limit' }, + ), + ).resolves.toMatchObject({ headers: ['value'], rows: [[1]], rowCount: 1 }); + + expect(queries).toEqual([ + 'WITH child_values AS (SELECT 1 AS value) SELECT TOP 1 * FROM (SELECT value FROM child_values) AS ktx_query_result', + ]); + expect(queries[0]).not.toContain('FROM (WITH'); + }); + it('limits introspection to tables in tableScope', async () => { const queries: string[] = []; const inputs: Array<{ name: string; value: unknown }> = []; diff --git a/packages/cli/test/context/connections/read-only-sql.test.ts b/packages/cli/test/context/connections/read-only-sql.test.ts index 8a090c28..51eff123 100644 --- a/packages/cli/test/context/connections/read-only-sql.test.ts +++ b/packages/cli/test/context/connections/read-only-sql.test.ts @@ -2,6 +2,7 @@ import { describe, expect, it } from 'vitest'; import { KtxExpectedError, KtxQueryError } from '../../../src/errors.js'; import { assertReadOnlySql, + hoistLeadingCte, limitSqlForExecution, stripTrailingSqlNoise, } from '../../../src/context/connections/read-only-sql.js'; @@ -81,6 +82,60 @@ describe('assertReadOnlySql', () => { }); }); +describe('hoistLeadingCte', () => { + it('leaves non-CTE SQL untouched', () => { + expect(hoistLeadingCte('select * from orders')).toEqual({ + withPrefix: '', + body: 'select * from orders', + }); + }); + + it('splits a single leading CTE from the main query', () => { + expect(hoistLeadingCte('WITH paid AS (SELECT * FROM orders) SELECT * FROM paid')).toEqual({ + withPrefix: 'WITH paid AS (SELECT * FROM orders) ', + body: 'SELECT * FROM paid', + }); + }); + + it('splits multiple CTEs, recursive CTEs, column lists, and UNION bodies', () => { + expect( + hoistLeadingCte( + 'WITH RECURSIVE nodes(id, parent_id) AS (SELECT id, parent_id FROM roots UNION ALL SELECT child.id, child.parent_id FROM child JOIN nodes ON child.parent_id = nodes.id), totals AS (SELECT count(*) AS total FROM nodes) SELECT total FROM totals', + ), + ).toEqual({ + withPrefix: + 'WITH RECURSIVE nodes(id, parent_id) AS (SELECT id, parent_id FROM roots UNION ALL SELECT child.id, child.parent_id FROM child JOIN nodes ON child.parent_id = nodes.id), totals AS (SELECT count(*) AS total FROM nodes) ', + body: 'SELECT total FROM totals', + }); + }); + + it('ignores WITH, commas, and closing parens inside literals, identifiers, and comments', () => { + expect( + hoistLeadingCte( + `WITH tricky AS ( + SELECT 'WITH x AS (SELECT 1), nope' AS "value, )" + FROM orders + WHERE note = ')' + /* comment ), next AS (SELECT 1) */ + ) SELECT * FROM tricky`, + ), + ).toEqual({ + withPrefix: `WITH tricky AS ( + SELECT 'WITH x AS (SELECT 1), nope' AS "value, )" + FROM orders + WHERE note = ')' + /* comment ), next AS (SELECT 1) */ + ) `, + body: 'SELECT * FROM tricky', + }); + }); + + it('falls back to the legacy whole-query body when the CTE is malformed', () => { + const malformed = 'WITH broken AS (SELECT * FROM orders SELECT * FROM broken'; + expect(hoistLeadingCte(malformed)).toEqual({ withPrefix: '', body: malformed }); + }); +}); + describe('limitSqlForExecution', () => { it('wraps compiled SQL and strips trailing semicolons', () => { expect(limitSqlForExecution('select * from public.orders; ', 25)).toBe( @@ -98,6 +153,18 @@ describe('limitSqlForExecution', () => { ); }); + it('hoists leading CTEs before applying the generic LIMIT wrapper', () => { + expect(limitSqlForExecution('WITH paid AS (SELECT * FROM orders) SELECT * FROM paid', 25)).toBe( + 'WITH paid AS (SELECT * FROM orders) select * from (SELECT * FROM paid) as ktx_query_result limit 25', + ); + }); + + it('keeps the generic wrapper byte-identical for non-CTE SQL', () => { + expect(limitSqlForExecution('select id, status from public.orders', 25)).toBe( + 'select * from (select id, status from public.orders) as ktx_query_result limit 25', + ); + }); + it('drops a trailing semicolon followed by a comment so the subquery stays valid', () => { // The single-statement gate accepts `select 1; -- done`; without stripping // the terminator the wrapper would embed `select 1; -- done` and comment out diff --git a/packages/cli/test/context/scan/relationship-validation.test.ts b/packages/cli/test/context/scan/relationship-validation.test.ts index cd7771d9..7e876421 100644 --- a/packages/cli/test/context/scan/relationship-validation.test.ts +++ b/packages/cli/test/context/scan/relationship-validation.test.ts @@ -8,6 +8,8 @@ import { profileKtxRelationshipSchema } from '../../../src/context/scan/relation import { validateKtxRelationshipDiscoveryCandidates } from '../../../src/context/scan/relationship-validation.js'; import type { KtxQueryResult, KtxReadOnlyQueryInput, KtxScanContext } from '../../../src/context/scan/types.js'; +// This harness runs SQL directly through SQLite; row-limit wrapper coverage lives +// in read-only-sql.test.ts and the SQL Server connector test. class InMemorySqliteExecutor { readonly db = new Database(':memory:'); queryCount = 0; diff --git a/uv.lock b/uv.lock index e66330e3..e831ec69 100644 --- a/uv.lock +++ b/uv.lock @@ -466,7 +466,7 @@ wheels = [ [[package]] name = "ktx-daemon" -version = "0.12.0" +version = "0.13.0" source = { editable = "python/ktx-daemon" } dependencies = [ { name = "fastapi" }, @@ -523,7 +523,7 @@ dev = [ [[package]] name = "ktx-sl" -version = "0.12.0" +version = "0.13.0" source = { editable = "python/ktx-sl" } dependencies = [ { name = "pydantic" },