diff --git a/docs-site/content/docs/integrations/primary-sources.mdx b/docs-site/content/docs/integrations/primary-sources.mdx index da916339..028a8012 100644 --- a/docs-site/content/docs/integrations/primary-sources.mdx +++ b/docs-site/content/docs/integrations/primary-sources.mdx @@ -129,20 +129,18 @@ connections: account: xy12345 warehouse: ANALYTICS_WH database: PROD - schema_name: PUBLIC + schema_names: + - PUBLIC + - SALES + - MARKETING username: KTX_SERVICE password: env:SNOWFLAKE_PASSWORD role: ANALYST ``` -For multiple schemas: - -```yaml - schema_names: - - PUBLIC - - ANALYTICS - - STAGING -``` +`ktx setup` discovers schemas after the connection is verified and writes the +selected list to `schema_names`. You can also set this field manually. For a +single schema, `schema_name: PUBLIC` is accepted as an equivalent shorthand. ### Authentication diff --git a/packages/cli/src/connectors/snowflake/connector.ts b/packages/cli/src/connectors/snowflake/connector.ts index 5b0dfcca..3909367a 100644 --- a/packages/cli/src/connectors/snowflake/connector.ts +++ b/packages/cli/src/connectors/snowflake/connector.ts @@ -4,7 +4,8 @@ import { homedir } from 'node:os'; import { resolve } from 'node:path'; import { assertReadOnlySql, limitSqlForExecution } from '../../context/connections/read-only-sql.js'; import { createKtxConnectorCapabilities, type KtxColumnSampleInput, type KtxColumnSampleResult, type KtxColumnStatsInput, type KtxColumnStatsResult, type KtxQueryResult, type KtxReadOnlyQueryInput, type KtxScanConnector, type KtxScanContext, type KtxScanInput, type KtxSchemaColumn, type KtxSchemaSnapshot, type KtxSchemaTable, type KtxTableRef, type KtxTableSampleInput, type KtxTableListEntry, type KtxTableSampleResult } from '../../context/scan/types.js'; -import * as snowflake from 'snowflake-sdk'; +import snowflake from 'snowflake-sdk'; +import type { Bind, Binds, Connection, ConnectionOptions } from 'snowflake-sdk'; import { KtxSnowflakeDialect } from './dialect.js'; export interface KtxSnowflakeConnectionConfig { @@ -128,7 +129,8 @@ function schemaNames(connection: KtxSnowflakeConnectionConfig, env: NodeJS.Proce .filter((schema) => schema.trim().length > 0) .map((schema) => resolveStringReference(schema, env)); } - return [stringConfigValue(connection, 'schema_name', env) ?? 'PUBLIC']; + const single = stringConfigValue(connection, 'schema_name', env); + return single ? [single] : []; } function firstNumber(value: unknown): number | null { @@ -158,7 +160,7 @@ function normalizeSnowflakeValue(value: unknown, columnType?: string): unknown { return value; } -function toSnowflakeBind(value: unknown): snowflake.Bind { +function toSnowflakeBind(value: unknown): Bind { if (value === null || typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { return value; } @@ -168,7 +170,7 @@ function toSnowflakeBind(value: unknown): snowflake.Bind { return String(value); } -function toSnowflakeBinds(params: unknown[] | undefined): snowflake.Binds | undefined { +function toSnowflakeBinds(params: unknown[] | undefined): Binds | undefined { return params?.map((value) => toSnowflakeBind(value)); } @@ -267,7 +269,7 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { } async query(sql: string, params?: unknown): Promise { - let connection: snowflake.Connection | null = null; + let connection: Connection | null = null; try { connection = await this.createConnection(); const binds = Array.isArray(params) ? toSnowflakeBinds(params) : undefined; @@ -358,7 +360,7 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { } private async runTest(): Promise<{ success: boolean; error?: string }> { - let connection: snowflake.Connection | null = null; + let connection: Connection | null = null; try { connection = await this.createConnection(); await this.executeSnowflakeQuery(connection, 'SELECT 1'); @@ -372,7 +374,7 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { } } - private async createConnection(): Promise { + private async createConnection(): Promise { const patch = await this.sdkOptionsProvider?.resolve({ account: this.resolved.account, connection: { ...this.resolved, driver: 'snowflake' }, @@ -380,16 +382,17 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { if (patch?.close) { this.closeSdkOptions.push(patch.close); } - const baseConfig: snowflake.ConnectionOptions = { + const sessionSchema = this.resolved.schemas[0]; + const baseConfig: ConnectionOptions = { account: this.resolved.account, username: this.resolved.username, warehouse: this.resolved.warehouse, database: this.resolved.database, - schema: this.resolved.schemas[0] ?? 'PUBLIC', + ...(sessionSchema ? { schema: sessionSchema } : {}), role: this.resolved.role, ...patch?.sdkOptions, }; - const connectionConfig: snowflake.ConnectionOptions = + const connectionConfig: ConnectionOptions = this.resolved.authMethod === 'rsa' ? { ...baseConfig, authenticator: 'SNOWFLAKE_JWT', privateKey: this.decryptPrivateKey() } : { ...baseConfig, password: this.resolved.password }; @@ -412,7 +415,7 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { }); } - private async setConnectionContext(connection: snowflake.Connection): Promise { + private async setConnectionContext(connection: Connection): Promise { if (this.resolved.role) { await this.executeSnowflakeQuery(connection, `USE ROLE "${this.resolved.role}"`); } @@ -422,9 +425,9 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { } private async executeSnowflakeQuery( - connection: snowflake.Connection, + connection: Connection, sqlText: string, - binds?: snowflake.Binds, + binds?: Binds, ): Promise<{ headers: string[]; headerTypes?: string[]; rows: unknown[][] }> { return new Promise((resolveQuery, rejectQuery) => { connection.execute({ @@ -447,7 +450,7 @@ class SnowflakeSdkDriver implements KtxSnowflakeDriver { }); } - private destroyConnection(connection: snowflake.Connection): Promise { + private destroyConnection(connection: Connection): Promise { return new Promise((resolveDestroy, rejectDestroy) => { connection.destroy((error) => { if (error) { diff --git a/packages/cli/src/setup-databases.test.ts b/packages/cli/src/setup-databases.test.ts index 178ed076..2a90f428 100644 --- a/packages/cli/src/setup-databases.test.ts +++ b/packages/cli/src/setup-databases.test.ts @@ -458,7 +458,7 @@ describe('setup databases step', () => { { driver: 'snowflake', selectValues: ['no'], - textValues: ['', 'env:SNOWFLAKE_ACCOUNT', 'ANALYTICS_WH', 'ANALYTICS', '', 'env:SNOWFLAKE_USER', ''], + textValues: ['', 'env:SNOWFLAKE_ACCOUNT', 'ANALYTICS_WH', 'ANALYTICS', 'env:SNOWFLAKE_USER', ''], passwordValues: ['env:SNOWFLAKE_PASSWORD'], expectedTextPrompts: [ { @@ -475,11 +475,6 @@ describe('setup databases step', () => { { message: 'Snowflake database name', }, - { - message: 'Snowflake schema\nPress Enter for PUBLIC, or enter a schema name.', - placeholder: 'PUBLIC', - initialValue: 'PUBLIC', - }, { message: 'Snowflake username', }, @@ -514,6 +509,8 @@ describe('setup databases step', () => { prompts, testConnection: vi.fn(async () => 0), scanConnection: vi.fn(async () => 0), + listSchemas: vi.fn(async () => []), + listTables: vi.fn(async () => []), }, ); @@ -687,6 +684,8 @@ describe('setup databases step', () => { }); const testConnection = vi.fn(async () => 0); const scanConnection = vi.fn(async () => 0); + const listSchemas = vi.fn(async () => []); + const listTables = vi.fn(async () => []); const result = await runKtxSetupDatabasesStep( { @@ -697,7 +696,7 @@ describe('setup databases step', () => { disableQueryHistory: true, }, makeIo().io, - { prompts, testConnection, scanConnection }, + { prompts, testConnection, scanConnection, listSchemas, listTables }, ); expect(result).toEqual({ @@ -1521,6 +1520,65 @@ describe('setup databases step', () => { expect(io.stdout()).toContain('✓ orbit_analytics, orbit_raw'); }); + it('falls back to comma-separated free-text when listSchemas fails interactively', async () => { + const io = makeIo(); + const prompts = makePromptAdapter({ + selectValues: ['url'], + textValues: ['', 'env:DATABASE_URL', 'orbit_analytics, orbit_raw'], + }); + const testConnection = vi.fn(async () => 0); + const scanConnection = vi.fn(async () => 0); + const listSchemas = vi.fn(async () => { + throw new Error('permission denied to list schemas'); + }); + const listTables = vi.fn(async () => [ + { schema: 'orbit_analytics', name: 'events', kind: 'table' as const }, + { schema: 'orbit_raw', name: 'inputs', kind: 'table' as const }, + ]); + const pickers = makePickerStubs({ + scopes: [ + { + schemas: ['orbit_analytics', 'orbit_raw'], + tables: ['orbit_analytics.events', 'orbit_raw.inputs'], + }, + ], + }); + + const result = await runKtxSetupDatabasesStep( + { + projectDir: tempDir, + inputMode: 'auto', + databaseDrivers: ['postgres'], + databaseSchemas: [], + skipDatabases: false, + }, + io.io, + { + prompts, + testConnection, + scanConnection, + listSchemas, + listTables, + pickDatabaseScope: pickers.pickDatabaseScope, + }, + ); + + expect(result.status).toBe('ready'); + expect(io.stderr()).toContain('Could not discover postgresql schemas'); + expect(vi.mocked(prompts.text).mock.calls.map(([options]) => options.message)).toContain( + textInputPrompt( + 'Enter schemas for postgres-warehouse as a comma-separated list (e.g. SALES, MARKETING).', + ), + ); + expect(listTables).toHaveBeenCalledWith(tempDir, 'postgres-warehouse', [ + 'orbit_analytics', + 'orbit_raw', + ]); + expect(pickers.scopeCalls[0]).toMatchObject({ + defaultSchemas: ['orbit_analytics', 'orbit_raw'], + }); + }); + it('auto-selects all discovered Postgres schemas in non-interactive setup', async () => { const io = makeIo(); const prompts = makePromptAdapter({}); @@ -1827,7 +1885,7 @@ describe('setup databases step', () => { testConnection: vi.fn(async () => 0), scanConnection: vi.fn(async () => 0), prompts: makePromptAdapter({ - textValues: ['env:SNOWFLAKE_ACCOUNT', 'WH', 'ANALYTICS', 'PUBLIC', 'reader', ''], + textValues: ['env:SNOWFLAKE_ACCOUNT', 'WH', 'ANALYTICS', 'reader', ''], passwordValues: ['env:SNOWFLAKE_PASSWORD'], }), }, diff --git a/packages/cli/src/setup-databases.ts b/packages/cli/src/setup-databases.ts index c8e735c5..a0ed60ba 100644 --- a/packages/cli/src/setup-databases.ts +++ b/packages/cli/src/setup-databases.ts @@ -867,12 +867,6 @@ async function buildConnectionConfig(input: { stringConfigField(input.existingConnection, 'database'), ); if (database === undefined) return 'back'; - const schemaName = await promptText( - prompts, - 'Snowflake schema\nPress Enter for PUBLIC, or enter a schema name.', - stringConfigField(input.existingConnection, 'schema_name') ?? 'PUBLIC', - ); - if (schemaName === undefined) return 'back'; const username = await promptText( prompts, 'Snowflake username', @@ -894,14 +888,13 @@ async function buildConnectionConfig(input: { ); if (role === undefined) return 'back'; const resolvedPasswordRef = passwordRef ?? stringConfigField(input.existingConnection, 'password'); - if (!account || !warehouse || !database || !schemaName || !username || !resolvedPasswordRef) return null; + if (!account || !warehouse || !database || !username || !resolvedPasswordRef) return null; return { driver: 'snowflake', authMethod: 'password', account, warehouse, database, - schema_name: schemaName, username, password: resolvedPasswordRef, ...(role ? { role } : {}), @@ -1312,6 +1305,21 @@ async function writeScopeConfig(input: { }); } +async function promptCommaSeparatedScope(input: { + prompts: KtxSetupDatabasesPromptAdapter; + connectionId: string; + spec: ScopeDiscoverySpec; +}): Promise { + const example = + input.spec.nounPlural === 'datasets' ? 'sales, marketing' : 'SALES, MARKETING'; + const value = await promptText( + input.prompts, + `Enter ${input.spec.nounPlural} for ${input.connectionId} as a comma-separated list (e.g. ${example}).`, + ); + if (value === undefined) return undefined; + return unique(value.split(',').map((part) => part.trim())); +} + async function maybeConfigureDatabaseScope(input: { projectDir: string; connectionId: string; @@ -1382,11 +1390,15 @@ async function maybeConfigureDatabaseScope(input: { `Connecting to ${input.connectionId}…`, ]); - const schemasFilter = await (async (): Promise => { - if (cliSchemas.length > 0) return cliSchemas; - if (!spec) return []; + let effectiveCliSchemas = cliSchemas; + let schemasFilter: string[]; + if (effectiveCliSchemas.length > 0) { + schemasFilter = effectiveCliSchemas; + } else if (!spec) { + schemasFilter = []; + } else { try { - return unique( + schemasFilter = unique( await (input.deps.listSchemas ?? defaultListSchemas)(input.projectDir, input.connectionId), ); } catch (error) { @@ -1394,9 +1406,21 @@ async function maybeConfigureDatabaseScope(input: { input.io.stderr.write( `Could not discover ${spec.promptLabel.toLowerCase()} for ${input.connectionId}; ${detail}\n`, ); - return []; + const prompts = input.deps.prompts ?? createPromptAdapter(); + const typed = await promptCommaSeparatedScope({ prompts, connectionId: input.connectionId, spec }); + if (typed === undefined) return 'back'; + effectiveCliSchemas = typed; + schemasFilter = typed; + if (typed.length > 0) { + await writeScopeConfig({ + projectDir: input.projectDir, + connectionId: input.connectionId, + values: typed, + spec, + }); + } } - })(); + } let discovered: KtxTableListEntry[]; try { @@ -1426,7 +1450,7 @@ async function maybeConfigureDatabaseScope(input: { const schemasInDiscovery = unique(discovered.map((t) => t.schema)); const defaultSchemas = (() => { - if (cliSchemas.length > 0) return cliSchemas; + if (effectiveCliSchemas.length > 0) return effectiveCliSchemas; if (!spec) return schemasInDiscovery; return spec.defaultSelection(schemasInDiscovery); })();