From e56eabb22b19b08293d8f24cc31e2acbe1fa7c80 Mon Sep 17 00:00:00 2001 From: Andrey Avtomonov Date: Fri, 22 May 2026 16:53:31 +0200 Subject: [PATCH] feat(scan): pool Snowflake sessions --- .../connectors/snowflake/connector.test.ts | 132 +++++++++++++++++ .../cli/src/connectors/snowflake/connector.ts | 133 +++++++++--------- 2 files changed, 198 insertions(+), 67 deletions(-) diff --git a/packages/cli/src/connectors/snowflake/connector.test.ts b/packages/cli/src/connectors/snowflake/connector.test.ts index 7e17117e..fed02a75 100644 --- a/packages/cli/src/connectors/snowflake/connector.test.ts +++ b/packages/cli/src/connectors/snowflake/connector.test.ts @@ -1,4 +1,11 @@ import { describe, expect, it, vi } from 'vitest'; + +const createPool = vi.hoisted(() => vi.fn()); + +vi.mock('snowflake-sdk', () => ({ + createPool, +})); + import { createSnowflakeLiveDatabaseIntrospection } from '../../connectors/snowflake/live-database-introspection.js'; import { isKtxSnowflakeConnectionConfig, KtxSnowflakeScanConnector, snowflakeConnectionConfigFromConfig, type KtxSnowflakeDriver, type KtxSnowflakeDriverFactory } from '../../connectors/snowflake/connector.js'; @@ -63,6 +70,38 @@ function fakeDriverFactory(): KtxSnowflakeDriverFactory { return { createDriver: vi.fn(() => driver) }; } +function fakeSnowflakeStatement(headers: string[] = ['ONE']) { + return { + getColumns: () => headers.map((header) => ({ getName: () => header, getType: () => 'TEXT' })), + }; +} + +function installSnowflakePoolMock() { + const executedSql: string[] = []; + const connection = { + execute: vi.fn( + (input: { + sqlText: string; + complete: ( + error: Error | null, + statement: ReturnType, + rows: Array>, + ) => void; + }) => { + executedSql.push(input.sqlText); + input.complete(null, fakeSnowflakeStatement(), [{ ONE: 1 }]); + }, + ), + }; + const pool = { + use: vi.fn(async (fn: (conn: typeof connection) => Promise) => fn(connection)), + drain: vi.fn(async () => undefined), + clear: vi.fn(async () => undefined), + }; + createPool.mockReturnValue(pool); + return { connection, pool, executedSql }; +} + describe('KtxSnowflakeScanConnector', () => { it('resolves Snowflake connection configuration safely', () => { expect( @@ -99,6 +138,99 @@ describe('KtxSnowflakeScanConnector', () => { }); }); + it('defaults and validates Snowflake maxSessions', () => { + const baseConnection = { + driver: 'snowflake', + authMethod: 'password', + account: 'acct', + warehouse: 'WH', + database: 'ANALYTICS', + schema_name: 'PUBLIC', + username: 'reader', + password: 'fixture-pass', // pragma: allowlist secret + } as const; + + expect( + snowflakeConnectionConfigFromConfig({ + connectionId: 'warehouse', + connection: baseConnection, + }), + ).toMatchObject({ maxSessions: 4 }); + + expect( + snowflakeConnectionConfigFromConfig({ + connectionId: 'warehouse', + connection: { ...baseConnection, maxSessions: 8 }, + }), + ).toMatchObject({ maxSessions: 8 }); + + for (const maxSessions of [0, -1, 1.5, Number.NaN]) { + expect(() => + snowflakeConnectionConfigFromConfig({ + connectionId: 'warehouse', + connection: { ...baseConnection, maxSessions }, + }), + ).toThrow('connections.warehouse.maxSessions must be a positive integer'); + } + }); + + it('uses one lazy Snowflake pool and drains it during cleanup', async () => { + const { pool, executedSql } = installSnowflakePoolMock(); + const close = vi.fn(async () => undefined); + const connector = new KtxSnowflakeScanConnector({ + connectionId: 'warehouse', + connection: { + driver: 'snowflake', + authMethod: 'password', + account: 'acct', + warehouse: 'WH', + database: 'ANALYTICS', + schema_name: 'PUBLIC', + username: 'reader', + password: 'fixture-pass', // pragma: allowlist secret + role: 'ANALYST', + maxSessions: 3, + }, + sdkOptionsProvider: { + resolve: vi.fn(async () => ({ sdkOptions: { application: 'ktx-test' }, close })), + }, + }); + + expect(createPool).not.toHaveBeenCalled(); + + await connector.executeReadOnly({ connectionId: 'warehouse', sql: 'select 1', maxRows: 1 }, { runId: 'run-1' }); + await connector.executeReadOnly({ connectionId: 'warehouse', sql: 'select 1', maxRows: 1 }, { runId: 'run-1' }); + + expect(createPool).toHaveBeenCalledTimes(1); + expect(createPool).toHaveBeenCalledWith( + expect.objectContaining({ + account: 'acct', + username: 'reader', + warehouse: 'WH', + database: 'ANALYTICS', + schema: 'PUBLIC', + role: 'ANALYST', + password: 'fixture-pass', // pragma: allowlist secret + clientSessionKeepAlive: true, + clientSessionKeepAliveHeartbeatFrequency: 900, + application: 'ktx-test', + }), + expect.objectContaining({ + min: 0, + max: 3, + evictionRunIntervalMillis: 30_000, + acquireTimeoutMillis: 60_000, + }), + ); + expect(pool.use).toHaveBeenCalledTimes(2); + expect(executedSql.some((sql) => /^USE\s+/i.test(sql.trim()))).toBe(false); + + await connector.cleanup(); + expect(pool.drain).toHaveBeenCalledBefore(pool.clear); + expect(pool.clear).toHaveBeenCalledTimes(1); + expect(close).toHaveBeenCalledTimes(1); + }); + it('introspects schema, primary keys, comments, row counts, and dimensions', async () => { const connector = new KtxSnowflakeScanConnector({ connectionId: 'warehouse', diff --git a/packages/cli/src/connectors/snowflake/connector.ts b/packages/cli/src/connectors/snowflake/connector.ts index 41263d4b..81455af6 100644 --- a/packages/cli/src/connectors/snowflake/connector.ts +++ b/packages/cli/src/connectors/snowflake/connector.ts @@ -21,6 +21,7 @@ export interface KtxSnowflakeConnectionConfig { privateKey?: string; passphrase?: string; role?: string; + maxSessions?: number; [key: string]: unknown; } @@ -35,6 +36,7 @@ export interface KtxSnowflakeResolvedConnectionConfig { privateKey?: string; passphrase?: string; role?: string; + maxSessions: number; } export interface KtxSnowflakeRawColumnMetadata { @@ -123,6 +125,23 @@ function stringConfigValue( return typeof value === 'string' && value.trim().length > 0 ? resolveStringReference(value.trim(), env) : undefined; } +function positiveIntegerConfigValue(input: { + connection: KtxSnowflakeConnectionConfig; + key: keyof KtxSnowflakeConnectionConfig; + connectionId: string; + defaultValue: number; +}): number { + const value = input.connection[input.key]; + if (value === undefined) { + return input.defaultValue; + } + const numberValue = Number(value); + if (!Number.isInteger(numberValue) || numberValue < 1) { + throw new Error(`connections.${input.connectionId}.${String(input.key)} must be a positive integer`); + } + return numberValue; +} + function schemaNames(connection: KtxSnowflakeConnectionConfig, env: NodeJS.ProcessEnv): string[] { if (Array.isArray(connection.schema_names) && connection.schema_names.length > 0) { return connection.schema_names @@ -220,6 +239,12 @@ export function snowflakeConnectionConfigFromConfig(input: { database, schemas: resolvedSchemas, username, + maxSessions: positiveIntegerConfigValue({ + connection: input.connection, + key: 'maxSessions', + connectionId: input.connectionId, + defaultValue: 4, + }), }; const role = stringConfigValue(input.connection, 'role', env); if (role) { @@ -255,6 +280,7 @@ class DefaultSnowflakeDriverFactory implements KtxSnowflakeDriverFactory { class SnowflakeSdkDriver implements KtxSnowflakeDriver { private closeSdkOptions: Array<() => Promise> = []; + private pool: ReturnType | null = null; constructor( private readonly resolved: KtxSnowflakeResolvedConnectionConfig, @@ -275,16 +301,21 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { } async query(sql: string, params?: unknown): Promise { - let connection: snowflake.Connection | null = null; + const binds = Array.isArray(params) ? toSnowflakeBinds(params) : undefined; try { - connection = await this.createConnection(); - const binds = Array.isArray(params) ? toSnowflakeBinds(params) : undefined; - const result = await this.executeSnowflakeQuery(connection, sql, binds); + const pool = await this.getPool(); + const result = await pool.use(async (connection: snowflake.Connection) => + this.executeSnowflakeQuery(connection, sql, binds), + ); return { ...result, totalRows: result.rows.length, rowCount: result.rows.length }; - } finally { - if (connection) { - await this.destroyConnection(connection); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + if (/timeout/i.test(message) && /pool|acquire/i.test(message)) { + throw new Error( + "Snowflake session pool exhausted after 60s - consider lowering maxSessions or increasing your account's concurrent-statement limit.", + ); } + throw error; } } @@ -357,27 +388,41 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { } async cleanup(): Promise { + const pool = this.pool; + this.pool = null; + if (pool) { + // Drain before clear so in-flight Snowflake statements finish before idle + // sessions are closed. + await pool.drain(); + await pool.clear(); + } const closers = this.closeSdkOptions; this.closeSdkOptions = []; - await Promise.all(closers.map((close) => close())); + await Promise.all(closers.map((close) => Promise.resolve(close()))); } private async runTest(): Promise<{ success: boolean; error?: string }> { - let connection: snowflake.Connection | null = null; try { - connection = await this.createConnection(); - await this.executeSnowflakeQuery(connection, 'SELECT 1'); + await this.query('SELECT 1'); return { success: true }; } catch (error) { return { success: false, error: error instanceof Error ? error.message : String(error) }; - } finally { - if (connection) { - await this.destroyConnection(connection); - } } } - private async createConnection(): Promise { + private async getPool(): Promise> { + if (!this.pool) { + this.pool = snowflake.createPool(await this.resolveConnectionOptions(), { + min: 0, + max: this.resolved.maxSessions, + evictionRunIntervalMillis: 30_000, + acquireTimeoutMillis: 60_000, + }); + } + return this.pool; + } + + private async resolveConnectionOptions(): Promise { const patch = await this.sdkOptionsProvider?.resolve({ account: this.resolved.account, connection: { ...this.resolved, driver: 'snowflake' }, @@ -392,47 +437,13 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { database: this.resolved.database, schema: this.resolved.schemas[0] ?? 'PUBLIC', role: this.resolved.role, + clientSessionKeepAlive: true, + clientSessionKeepAliveHeartbeatFrequency: 900, ...patch?.sdkOptions, }; - const connectionConfig: snowflake.ConnectionOptions = - this.resolved.authMethod === 'rsa' - ? { ...baseConfig, authenticator: 'SNOWFLAKE_JWT', privateKey: this.decryptPrivateKey() } - : { ...baseConfig, password: this.resolved.password }; - const connection = snowflake.createConnection(connectionConfig); - return new Promise((resolveConnection, rejectConnection) => { - connection.connect((error, connected) => { - if (error) { - rejectConnection(error); - return; - } - const resolvedConnection = connected ?? connection; - this.setConnectionContext(resolvedConnection).then( - () => resolveConnection(resolvedConnection), - (contextError) => { - resolvedConnection.destroy(() => undefined); - rejectConnection(contextError); - }, - ); - }); - }); - } - - private async setConnectionContext(connection: snowflake.Connection): Promise { - if (this.resolved.role) { - await this.executeSnowflakeQuery(connection, `USE ROLE ${quoteSnowflakeIdentifier(this.resolved.role, 'role')}`); - } - await this.executeSnowflakeQuery( - connection, - `USE WAREHOUSE ${quoteSnowflakeIdentifier(this.resolved.warehouse, 'warehouse')}`, - ); - await this.executeSnowflakeQuery( - connection, - `USE DATABASE ${quoteSnowflakeIdentifier(this.resolved.database, 'database')}`, - ); - await this.executeSnowflakeQuery( - connection, - `USE SCHEMA ${quoteSnowflakeIdentifier(this.resolved.schemas[0] ?? 'PUBLIC', 'schema')}`, - ); + return this.resolved.authMethod === 'rsa' + ? { ...baseConfig, authenticator: 'SNOWFLAKE_JWT', privateKey: this.decryptPrivateKey() } + : { ...baseConfig, password: this.resolved.password }; } private async executeSnowflakeQuery( @@ -461,18 +472,6 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { }); } - private destroyConnection(connection: snowflake.Connection): Promise { - return new Promise((resolveDestroy, rejectDestroy) => { - connection.destroy((error) => { - if (error) { - rejectDestroy(error); - return; - } - resolveDestroy(); - }); - }); - } - private decryptPrivateKey(): string { if (!this.resolved.privateKey) { throw new Error('Private key is required for RSA authentication');