// NOTE(review): unified diff for the Snowflake connector (connector.ts, connector.test.ts) plus the
// new identifiers.test.ts; configured warehouse/database/schema/role names are now validated with
// assertSafeSnowflakeIdentifier and quoted with quoteSnowflakeIdentifier before use in SHOW/USE/FROM.
// NOTE(review): the removed SnowflakeSdkDriver.listTables returned [] for an empty `schemas` array;
// the added code treats an empty array like `undefined` and applies no schema filter. Confirm intended.
// NOTE(review): the same INFORMATION_SCHEMA.TABLES query is added to both SnowflakeSdkDriver.listTables
// and KtxSnowflakeScanConnector.listTables; the connector now calls getDriver().query(...) rather than
// getDriver().listTables(...). Consider keeping the SQL in one place.
diff --git a/packages/cli/src/connectors/snowflake/connector.test.ts b/packages/cli/src/connectors/snowflake/connector.test.ts index 7b2e600f..7e17117e 100644 --- a/packages/cli/src/connectors/snowflake/connector.test.ts +++ b/packages/cli/src/connectors/snowflake/connector.test.ts @@ -215,6 +215,75 @@ describe('KtxSnowflakeScanConnector', () => { expect(driver.cleanup).toHaveBeenCalledTimes(1); }); + it('lists tables across schemas with one information schema query', async () => { + const queries: Array<{ sql: string; params?: unknown }> = []; + const driverFactory: KtxSnowflakeDriverFactory = { + createDriver: vi.fn(() => ({ + test: vi.fn(async () => ({ success: true })), + query: vi.fn(async (sql: string, params?: unknown) => { + queries.push({ sql, params }); + return { + headers: ['TABLE_SCHEMA', 'TABLE_NAME', 'TABLE_TYPE'], + rows: [ + ['MART', 'ORDERS', 'BASE TABLE'], + ['PUBLIC', 'ORDER_SUMMARY', 'VIEW'], + ], + totalRows: 2, + rowCount: 2, + }; + }), + getSchemaMetadata: vi.fn(async () => []), + listSchemas: vi.fn(async () => []), + listTables: vi.fn(async () => []), + cleanup: vi.fn(async () => undefined), + })), + }; + const connector = new KtxSnowflakeScanConnector({ + connectionId: 'warehouse', + connection: { + driver: 'snowflake', + authMethod: 'password', + account: 'acct', + warehouse: 'WH', + database: 'ANALYTICS', + schema_name: 'PUBLIC', + username: 'reader', + password: 'fixture-pass', // pragma: allowlist secret + }, + driverFactory, + }); + + await expect(connector.listTables(['MART', 'PUBLIC'])).resolves.toEqual([ + { schema: 'MART', name: 'ORDERS', kind: 'table' }, + { schema: 'PUBLIC', name: 'ORDER_SUMMARY', kind: 'view' }, + ]); + + expect(queries).toHaveLength(1); + expect(queries[0]?.sql).toContain('FROM "ANALYTICS".INFORMATION_SCHEMA.TABLES'); + expect(queries[0]?.sql).toContain('AND TABLE_SCHEMA IN (?, ?)'); + expect(queries[0]?.params).toEqual(['ANALYTICS', 'MART', 'PUBLIC']); + }); + + it('rejects unsafe Snowflake identifiers before 
driver creation', () => { + expect( + () => + new KtxSnowflakeScanConnector({ + connectionId: 'warehouse', + connection: { + driver: 'snowflake', + authMethod: 'password', + account: 'acct', + warehouse: 'WH;DROP', + database: 'ANALYTICS', + schema_name: 'PUBLIC', + username: 'reader', + password: 'fixture-pass', // pragma: allowlist secret + }, + driverFactory: fakeDriverFactory(), + }), + ).toThrow('Invalid Snowflake warehouse identifier "WH;DROP"'); + }); + it('converts a native snapshot into a live-database introspection snapshot', async () => { const introspection = createSnowflakeLiveDatabaseIntrospection({ connections: { diff --git a/packages/cli/src/connectors/snowflake/connector.ts b/packages/cli/src/connectors/snowflake/connector.ts index 5b0dfcca..41263d4b 100644 --- a/packages/cli/src/connectors/snowflake/connector.ts +++ b/packages/cli/src/connectors/snowflake/connector.ts @@ -6,6 +6,7 @@ import { assertReadOnlySql, limitSqlForExecution } from '../../context/connectio import { createKtxConnectorCapabilities, type KtxColumnSampleInput, type KtxColumnSampleResult, type KtxColumnStatsInput, type KtxColumnStatsResult, type KtxQueryResult, type KtxReadOnlyQueryInput, type KtxScanConnector, type KtxScanContext, type KtxScanInput, type KtxSchemaColumn, type KtxSchemaSnapshot, type KtxSchemaTable, type KtxTableRef, type KtxTableSampleInput, type KtxTableListEntry, type KtxTableSampleResult } from '../../context/scan/types.js'; import * as snowflake from 'snowflake-sdk'; import { KtxSnowflakeDialect } from './dialect.js'; +import { assertSafeSnowflakeIdentifier, quoteSnowflakeIdentifier } from './identifiers.js'; export interface KtxSnowflakeConnectionConfig { driver?: string; @@ -206,16 +207,23 @@ export function snowflakeConnectionConfigFromConfig(input: { if (!username) { throw new Error(`Native Snowflake connector requires connections.${input.connectionId}.username`); } + assertSafeSnowflakeIdentifier(warehouse, 'warehouse'); + 
assertSafeSnowflakeIdentifier(database, 'database'); + const resolvedSchemas = schemaNames(input.connection!, env); + for (const schema of resolvedSchemas) { + assertSafeSnowflakeIdentifier(schema, 'schema'); + } const resolved: KtxSnowflakeResolvedConnectionConfig = { authMethod, account, warehouse, database, - schemas: schemaNames(input.connection!, env), + schemas: resolvedSchemas, username, }; const role = stringConfigValue(input.connection, 'role', env); if (role) { + assertSafeSnowflakeIdentifier(role, 'role'); resolved.role = role; } if (authMethod === 'rsa') { @@ -322,33 +330,30 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { } async listSchemas(): Promise { - const result = await this.query(`SHOW SCHEMAS IN DATABASE "${this.resolved.database}"`); + const result = await this.query( + `SHOW SCHEMAS IN DATABASE ${quoteSnowflakeIdentifier(this.resolved.database, 'database')}`, + ); return result.rows.map((row) => String(row[1])).filter((name) => name !== 'INFORMATION_SCHEMA'); } async listTables(schemas?: string[]): Promise { - const filterSchemas = schemas ?? (await this.listSchemas()); - if (filterSchemas.length === 0) return []; - const entries: KtxTableListEntry[] = []; - for (const schemaName of filterSchemas) { - const result = await this.query( - ` - SELECT TABLE_NAME, TABLE_TYPE - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = ? AND TABLE_CATALOG = ? - ORDER BY TABLE_NAME - `, - [schemaName, this.resolved.database], - ); - for (const row of result.rows) { - entries.push({ - schema: schemaName, - name: String(row[0]), - kind: String(row[1]) === 'VIEW' ? 'view' : 'table', - }); - } - } - return entries; + const filters = schemas && schemas.length > 0 ? schemas.map(() => '?').join(', ') : null; + const result = await this.query( + ` + SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE + FROM ${quoteSnowflakeIdentifier(this.resolved.database, 'database')}.INFORMATION_SCHEMA.TABLES + WHERE TABLE_CATALOG = ? 
+ AND TABLE_SCHEMA <> 'INFORMATION_SCHEMA' + ${filters ? `AND TABLE_SCHEMA IN (${filters})` : ''} + ORDER BY TABLE_SCHEMA, TABLE_NAME + `, + [this.resolved.database, ...(schemas ?? [])], + ); + return result.rows.map((row) => ({ + schema: String(row[0]), + name: String(row[1]), + kind: String(row[2]) === 'VIEW' ? ('view' as const) : ('table' as const), + })); } async cleanup(): Promise { @@ -414,11 +419,20 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { private async setConnectionContext(connection: snowflake.Connection): Promise { if (this.resolved.role) { - await this.executeSnowflakeQuery(connection, `USE ROLE "${this.resolved.role}"`); + await this.executeSnowflakeQuery(connection, `USE ROLE ${quoteSnowflakeIdentifier(this.resolved.role, 'role')}`); } - await this.executeSnowflakeQuery(connection, `USE WAREHOUSE "${this.resolved.warehouse}"`); - await this.executeSnowflakeQuery(connection, `USE DATABASE "${this.resolved.database}"`); - await this.executeSnowflakeQuery(connection, `USE SCHEMA "${this.resolved.schemas[0] ?? 'PUBLIC'}"`); + await this.executeSnowflakeQuery( + connection, + `USE WAREHOUSE ${quoteSnowflakeIdentifier(this.resolved.warehouse, 'warehouse')}`, + ); + await this.executeSnowflakeQuery( + connection, + `USE DATABASE ${quoteSnowflakeIdentifier(this.resolved.database, 'database')}`, + ); + await this.executeSnowflakeQuery( + connection, + `USE SCHEMA ${quoteSnowflakeIdentifier(this.resolved.schemas[0] ?? 'PUBLIC', 'schema')}`, + ); } private async executeSnowflakeQuery( @@ -601,8 +615,24 @@ export class KtxSnowflakeScanConnector implements KtxScanConnector { return this.getDriver().listSchemas(); } - listTables(schemas?: string[]): Promise { - return this.getDriver().listTables(schemas); + async listTables(schemas?: string[]): Promise { + const filters = schemas && schemas.length > 0 ? 
schemas.map(() => '?').join(', ') : null; + const result = await this.getDriver().query( + ` + SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE + FROM ${quoteSnowflakeIdentifier(this.resolved.database, 'database')}.INFORMATION_SCHEMA.TABLES + WHERE TABLE_CATALOG = ? + AND TABLE_SCHEMA <> 'INFORMATION_SCHEMA' + ${filters ? `AND TABLE_SCHEMA IN (${filters})` : ''} + ORDER BY TABLE_SCHEMA, TABLE_NAME + `, + [this.resolved.database, ...(schemas ?? [])], + ); + return result.rows.map((row) => ({ + schema: String(row[0]), + name: String(row[1]), + kind: String(row[2]) === 'VIEW' ? ('view' as const) : ('table' as const), + })); } async cleanup(): Promise { diff --git a/packages/cli/src/connectors/snowflake/identifiers.test.ts b/packages/cli/src/connectors/snowflake/identifiers.test.ts new file mode 100644 index 00000000..d2c3e448 --- /dev/null +++ b/packages/cli/src/connectors/snowflake/identifiers.test.ts @@ -0,0 +1,18 @@ +import { describe, expect, it } from 'vitest'; +import { assertSafeSnowflakeIdentifier, quoteSnowflakeIdentifier } from './identifiers.js'; + +describe('Snowflake identifier guards', () => { + it('quotes simple Snowflake identifiers', () => { + expect(quoteSnowflakeIdentifier('ANALYTICS_DB', 'database')).toBe('"ANALYTICS_DB"'); + expect(quoteSnowflakeIdentifier('ROLE_1$', 'role')).toBe('"ROLE_1$"'); + }); + + it('rejects configured identifiers with field and value in the error', () => { + expect(() => assertSafeSnowflakeIdentifier('bad.db', 'database')).toThrow( + 'Invalid Snowflake database identifier "bad.db"; use a simple unquoted identifier matching /^[A-Za-z_][A-Za-z0-9_$]*$/', + ); + expect(() => assertSafeSnowflakeIdentifier('WH"DROP', 'warehouse')).toThrow( + 'Invalid Snowflake warehouse identifier "WH\\"DROP"; use a simple unquoted identifier matching /^[A-Za-z_][A-Za-z0-9_$]*$/', + ); + }); +}); diff --git a/packages/cli/src/connectors/snowflake/identifiers.ts b/packages/cli/src/connectors/snowflake/identifiers.ts new file mode 100644 index 
// File: packages/cli/src/connectors/snowflake/identifiers.ts

/**
 * Shape of an identifier that is accepted for interpolation into Snowflake SQL:
 * a letter or underscore followed by letters, digits, underscores, or `$`.
 * Anything else (dots, quotes, semicolons, whitespace, empty string) is rejected.
 */
const SNOWFLAKE_SIMPLE_IDENTIFIER = /^[A-Za-z_][A-Za-z0-9_$]*$/;

/**
 * Validates that `value` is a simple Snowflake identifier and returns it unchanged.
 *
 * @param value - Candidate identifier (e.g. a configured warehouse or database name).
 * @param field - Human-readable name of the setting, used only in the error message.
 * @returns The same `value` when it matches {@link SNOWFLAKE_SIMPLE_IDENTIFIER}.
 * @throws Error when `value` is not a string or does not match the pattern; the
 *   message names the field, the offending value, and the required pattern.
 */
export function assertSafeSnowflakeIdentifier(value: string, field: string): string {
  // The static type says `string`, but callers may pass values read from untyped
  // configuration. RegExp.prototype.test() coerces its argument to a string, so a
  // runtime `undefined` or `null` would become "undefined"/"null" and wrongly pass.
  const candidate: unknown = value;
  if (typeof candidate !== 'string' || !SNOWFLAKE_SIMPLE_IDENTIFIER.test(candidate)) {
    // JSON.stringify keeps quotes/escapes visible for strings; it yields `undefined`
    // for non-serializable values, so fall back to String() for non-strings.
    const shown = typeof candidate === 'string' ? JSON.stringify(candidate) : String(candidate);
    throw new Error(
      `Invalid Snowflake ${field} identifier ${shown}; use a simple unquoted identifier matching ${SNOWFLAKE_SIMPLE_IDENTIFIER}`,
    );
  }
  return candidate;
}

/**
 * Validates `value` with {@link assertSafeSnowflakeIdentifier} and wraps it in
 * double quotes for use in a SQL statement.
 *
 * @param value - Candidate identifier.
 * @param field - Name of the setting, forwarded to the validation error message.
 * @returns The identifier surrounded by double quotes, e.g. `"ANALYTICS_DB"`.
 * @throws Error when validation fails (see {@link assertSafeSnowflakeIdentifier}).
 */
export function quoteSnowflakeIdentifier(value: string, field: string): string {
  return `"${assertSafeSnowflakeIdentifier(value, field)}"`;
}