From dfdb13171d1bb5a709b8b72480df8d0c19d6e745 Mon Sep 17 00:00:00 2001 From: Luca Martial Date: Wed, 13 May 2026 15:28:52 -0700 Subject: [PATCH] feat(cli): add database tree picker for schema and table scope selection Replace inline multiselect prompts in setup-databases with a new database-tree-picker that uses the generic tree picker TUI. This gives database scope selection the same grouped tree UI as the Notion page picker, combining schema and table selection into a single step. Co-Authored-By: Claude Opus 4.6 (1M context) --- packages/cli/src/database-tree-picker.test.ts | 188 ++++++++++ packages/cli/src/database-tree-picker.ts | 210 +++++++++++ packages/cli/src/setup-databases.test.ts | 160 ++++++--- packages/cli/src/setup-databases.ts | 332 +++++++----------- 4 files changed, 647 insertions(+), 243 deletions(-) create mode 100644 packages/cli/src/database-tree-picker.test.ts create mode 100644 packages/cli/src/database-tree-picker.ts diff --git a/packages/cli/src/database-tree-picker.test.ts b/packages/cli/src/database-tree-picker.test.ts new file mode 100644 index 00000000..5559ee42 --- /dev/null +++ b/packages/cli/src/database-tree-picker.test.ts @@ -0,0 +1,188 @@ +import { describe, expect, it, vi } from 'vitest'; +import { + pickDatabaseScope, + type DatabaseTreePickerRenderer, + type PickDatabaseScopeArgs, +} from './database-tree-picker.js'; +import type { TreePickerChrome, TreePickerResult } from './tree-picker-tui.js'; +import type { PickerState } from './tree-picker-state.js'; + +function makeIo() { + let stdout = ''; + let stderr = ''; + return { + io: { + stdout: { isTTY: true, write: (chunk: string) => { stdout += chunk; } }, + stderr: { write: (chunk: string) => { stderr += chunk; } }, + }, + stdout: () => stdout, + stderr: () => stderr, + }; +} + +function captureRenderer(): { + renderer: DatabaseTreePickerRenderer; + capture: { chrome?: TreePickerChrome; state?: PickerState }; + setResult: (result: TreePickerResult) => void; +} { + const capture: { chrome?: TreePickerChrome; state?: PickerState } = {}; + let nextResult: TreePickerResult = { kind: 'quit' }; + const renderer: DatabaseTreePickerRenderer = vi.fn(async (chrome, state) => { + capture.chrome = chrome; + capture.state = state; + return nextResult; + }); + return { + renderer, + capture, + setResult: (result) => { + nextResult = result; + }, + }; +} + +const discovered = [ + { schema: 'analytics', name: 'customers', kind: 'table' as const }, + { schema: 'analytics', name: 'orders', kind: 'table' as const }, + { schema: 'public', name: 'events', kind: 'view' as const }, + { schema: 'public', name: 'sessions', kind: 'table' as const }, +]; + +function baseArgs(overrides: Partial = {}): PickDatabaseScopeArgs { + return { + connectionId: 'warehouse', + schemaNoun: 'schema', + schemaNounPlural: 'schemas', + discovered, + existing: { enabledTables: [] }, + defaultSchemas: ['analytics'], + supportsSchemaScope: true, + ...overrides, + }; +} + +describe('pickDatabaseScope', () => { + it('builds a 2-level tree (schemas as parents, tables as children) and uses save-empty action', async () => { + const { renderer, capture, setResult } = captureRenderer(); + setResult({ kind: 'quit' }); + + await pickDatabaseScope(baseArgs(), makeIo().io, renderer); + + expect(capture.state?.skipEmptyAction).toBe('save-empty'); + const schemaIds = capture.state?.tree.filter((n) => n.parentId === null).map((n) => n.id); + const tableIds = capture.state?.tree.filter((n) => n.parentId !== null).map((n) => n.id); + expect((schemaIds ?? []).sort()).toEqual(['analytics', 'public']); + expect((tableIds ?? []).sort()).toEqual([ + 'analytics.customers', + 'analytics.orders', + 'public.events', + 'public.sessions', + ]); + expect(capture.state?.byId.get('public.events')?.title).toBe('events (view)'); + }); + + it('pre-checks default schemas at the parent level when no existing selection', async () => { + const { renderer, capture, setResult } = captureRenderer(); + setResult({ kind: 'quit' }); + + await pickDatabaseScope(baseArgs({ defaultSchemas: ['analytics'] }), makeIo().io, renderer); + + expect([...(capture.state?.checked ?? [])]).toEqual(['analytics']); + }); + + it('collapses an existing full-schema selection back into the parent check', async () => { + const { renderer, capture, setResult } = captureRenderer(); + setResult({ kind: 'quit' }); + + await pickDatabaseScope( + baseArgs({ existing: { enabledTables: ['analytics.customers', 'analytics.orders'] } }), + makeIo().io, + renderer, + ); + + expect([...(capture.state?.checked ?? [])]).toEqual(['analytics']); + }); + + it('keeps a partial existing selection at the leaf level', async () => { + const { renderer, capture, setResult } = captureRenderer(); + setResult({ kind: 'quit' }); + + await pickDatabaseScope( + baseArgs({ existing: { enabledTables: ['analytics.customers'] } }), + makeIo().io, + renderer, + ); + + expect([...(capture.state?.checked ?? [])]).toEqual(['analytics.customers']); + }); + + it('expands a selected schema parent into all its tables and derives activeSchemas', async () => { + const { renderer, setResult } = captureRenderer(); + setResult({ kind: 'save', selectedIds: ['analytics'] }); + + const result = await pickDatabaseScope(baseArgs(), makeIo().io, renderer); + + expect(result).toEqual({ + kind: 'selected', + activeSchemas: ['analytics'], + enabledTables: ['analytics.customers', 'analytics.orders'], + }); + }); + + it('combines parent and individual leaf selections without duplicate tables', async () => { + const { renderer, setResult } = captureRenderer(); + setResult({ kind: 'save', selectedIds: ['analytics', 'public.events'] }); + + const result = await pickDatabaseScope(baseArgs(), makeIo().io, renderer); + + expect(result).toEqual({ + kind: 'selected', + activeSchemas: ['analytics', 'public'], + enabledTables: ['analytics.customers', 'analytics.orders', 'public.events'], + }); + }); + + it('treats empty save as enable-all', async () => { + const { renderer, setResult } = captureRenderer(); + setResult({ kind: 'save', selectedIds: [] }); + + const result = await pickDatabaseScope(baseArgs(), makeIo().io, renderer); + + expect(result).toEqual({ + kind: 'selected', + activeSchemas: ['analytics', 'public'], + enabledTables: [ + 'analytics.customers', + 'analytics.orders', + 'public.events', + 'public.sessions', + ], + }); + }); + + it('omits activeSchemas when the driver does not support a schema scope', async () => { + const { renderer, setResult } = captureRenderer(); + setResult({ kind: 'save', selectedIds: ['analytics'] }); + + const result = await pickDatabaseScope( + baseArgs({ supportsSchemaScope: false }), + makeIo().io, + renderer, + ); + + expect(result).toEqual({ + kind: 'selected', + activeSchemas: [], + enabledTables: ['analytics.customers', 'analytics.orders'], + }); + }); + + it('returns back when the picker quits', async () => { + const { renderer, setResult } = captureRenderer(); + setResult({ kind: 'quit' }); + + const result = await pickDatabaseScope(baseArgs(), makeIo().io, renderer); + + expect(result).toEqual({ kind: 'back' }); + }); +}); diff --git a/packages/cli/src/database-tree-picker.ts b/packages/cli/src/database-tree-picker.ts new file mode 100644 index 00000000..d494003d --- /dev/null +++ b/packages/cli/src/database-tree-picker.ts @@ -0,0 +1,210 @@ +import type { KtxTableListEntry } from '@ktx/context/scan'; +import type { KtxCliIo } from './cli-runtime.js'; +import { profileMark } from './startup-profile.js'; +import { + buildInitialState, + buildPickerTree, + type PickerState, + type TreePickerNode, + type TreePickerNodeInput, +} from './tree-picker-state.js'; +import { + renderTreePickerTui, + type TreePickerChrome, + type TreePickerResult, + type TreePickerTuiIo, +} from './tree-picker-tui.js'; + +profileMark('module:database-tree-picker'); + +const DATABASE_SCRIPTED_MODE_HINT = + 'Database picker requires a TTY. Use --no-input and the relevant flags for scripted mode.'; + +export type DatabaseTreePickerRenderer = ( + chrome: TreePickerChrome, + initialState: PickerState, + io: TreePickerTuiIo, +) => Promise; + +function defaultRenderer( + chrome: TreePickerChrome, + initialState: PickerState, + io: TreePickerTuiIo, +): Promise { + return renderTreePickerTui({ chrome, initialState }, io, { scriptedModeHint: DATABASE_SCRIPTED_MODE_HINT }); +} + +export type DatabaseScopePickResult = + | { kind: 'selected'; activeSchemas: string[]; enabledTables: string[] } + | { kind: 'back' }; + +export interface PickDatabaseScopeArgs { + connectionId: string; + schemaNoun: string; + schemaNounPlural: string; + discovered: readonly KtxTableListEntry[]; + existing: { enabledTables: readonly string[] }; + defaultSchemas: readonly string[]; + supportsSchemaScope: boolean; +} + +function qualifiedTableId(entry: KtxTableListEntry): string { + return `${entry.schema}.${entry.name}`; +} + +function tableTitle(entry: KtxTableListEntry): string { + return entry.kind === 'view' ? `${entry.name} (view)` : entry.name; +} + +function buildTreeInputs(discovered: readonly KtxTableListEntry[]): { + inputs: TreePickerNodeInput[]; + schemaIds: string[]; + allTables: string[]; +} { + const schemaSeen = new Set(); + const schemaIds: string[] = []; + for (const entry of discovered) { + if (!schemaSeen.has(entry.schema)) { + schemaSeen.add(entry.schema); + schemaIds.push(entry.schema); + } + } + const inputs: TreePickerNodeInput[] = []; + for (const schema of schemaIds) { + inputs.push({ id: schema, title: schema, archived: false, parentId: null }); + } + for (const entry of discovered) { + inputs.push({ + id: qualifiedTableId(entry), + title: tableTitle(entry), + archived: false, + parentId: entry.schema, + }); + } + return { inputs, schemaIds, allTables: discovered.map(qualifiedTableId) }; +} + +function initialSelectionForExisting( + existing: readonly string[], + byId: Map, +): string[] { + const tableIds = new Set( + [...byId.values()].filter((node) => node.parentId !== null).map((node) => node.id), + ); + const existingTables = new Set(existing.filter((id) => tableIds.has(id))); + const schemaChildren = new Map(); + for (const node of byId.values()) { + if (node.parentId === null && node.childIds.length > 0) { + schemaChildren.set(node.id, [...node.childIds]); + } + } + const result: string[] = []; + for (const [schema, children] of schemaChildren) { + const allChecked = children.length > 0 && children.every((childId) => existingTables.has(childId)); + if (allChecked) { + result.push(schema); + for (const childId of children) { + existingTables.delete(childId); + } + } + } + for (const id of existingTables) { + result.push(id); + } + return result; +} + +function initialSelectionFromDefaults( + defaultSchemas: readonly string[], + schemaIds: readonly string[], +): string[] { + const valid = new Set(schemaIds); + const filtered = defaultSchemas.filter((s) => valid.has(s)); + return filtered.length > 0 ? filtered : [...schemaIds]; +} + +function expandSelectedToTables( + selectedIds: readonly string[], + byId: Map, +): string[] { + const expanded: string[] = []; + const seen = new Set(); + for (const id of selectedIds) { + const node = byId.get(id); + if (!node) continue; + if (node.childIds.length === 0) { + if (node.parentId !== null && !seen.has(id)) { + seen.add(id); + expanded.push(id); + } + continue; + } + for (const childId of node.childIds) { + if (!seen.has(childId)) { + seen.add(childId); + expanded.push(childId); + } + } + } + return expanded; +} + +function schemasFromEnabledTables(enabledTables: readonly string[]): string[] { + const seen = new Set(); + const result: string[] = []; + for (const qualified of enabledTables) { + const schema = qualified.split('.')[0] ?? ''; + if (schema.length === 0 || seen.has(schema)) continue; + seen.add(schema); + result.push(schema); + } + return result; +} + +export async function pickDatabaseScope( + args: PickDatabaseScopeArgs, + io: KtxCliIo, + render: DatabaseTreePickerRenderer = defaultRenderer, +): Promise { + const { inputs, schemaIds, allTables } = buildTreeInputs(args.discovered); + const tree = buildPickerTree(inputs); + const byId = new Map(tree.map((node) => [node.id, node])); + const tableCount = allTables.length; + const schemaCount = schemaIds.length; + + const initialSelection = + args.existing.enabledTables.length > 0 + ? initialSelectionForExisting(args.existing.enabledTables, byId) + : initialSelectionFromDefaults(args.defaultSchemas, schemaIds); + + const initialState = buildInitialState({ + tree, + existingSelectedIds: initialSelection, + skipEmptyAction: 'save-empty', + }); + + const schemaWordPlural = schemaCount === 1 ? args.schemaNoun : args.schemaNounPlural; + const subtitleLines = [ + `Connection: ${args.connectionId}`, + `Found ${tableCount} ${tableCount === 1 ? 'table' : 'tables'} across ${schemaCount} ${schemaWordPlural}.`, + `Toggle a ${args.schemaNoun} to enable all of its tables, or expand to pick individual tables.`, + ]; + + const chrome: TreePickerChrome = { + title: `Choose tables to enable for ${args.connectionId}`, + subtitleLines, + skipEmptyMessage: + 'Nothing selected. Enable all tables? Press Enter to enable all or Escape to go back.', + }; + + const result = await render(chrome, initialState, io as TreePickerTuiIo); + if (result.kind === 'quit') { + return { kind: 'back' }; + } + + const enabledTables = + result.selectedIds.length === 0 ? allTables : expandSelectedToTables(result.selectedIds, byId); + const activeSchemas = args.supportsSchemaScope ? schemasFromEnabledTables(enabledTables) : []; + + return { kind: 'selected', activeSchemas, enabledTables }; +} diff --git a/packages/cli/src/setup-databases.test.ts b/packages/cli/src/setup-databases.test.ts index fa4ca3f2..d3a55fba 100644 --- a/packages/cli/src/setup-databases.test.ts +++ b/packages/cli/src/setup-databases.test.ts @@ -5,10 +5,15 @@ import { initKtxProject, parseKtxProjectConfig, readKtxSetupState, writeKtxSetup import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { type KtxSetupDatabaseDriver, + type KtxSetupDatabasesDeps, type KtxSetupDatabasesPromptAdapter, runKtxSetupDatabasesStep, } from './setup-databases.js'; import type { KtxCliIo } from './cli-runtime.js'; +import type { + DatabaseScopePickResult, + PickDatabaseScopeArgs, +} from './database-tree-picker.js'; function makeIo() { let stdout = ''; @@ -32,6 +37,43 @@ function makeIo() { }; } +type ScopePick = + | 'back' + | 'enable-all' + | { schemas: string[]; tables: string[] }; + +interface PickerStubs { + pickDatabaseScope: KtxSetupDatabasesDeps['pickDatabaseScope']; + scopeCalls: PickDatabaseScopeArgs[]; +} + +function makePickerStubs(options: { scopes?: ScopePick[] } = {}): PickerStubs { + const queue: ScopePick[] = [...(options.scopes ?? [])]; + const scopeCalls: PickDatabaseScopeArgs[] = []; + return { + scopeCalls, + pickDatabaseScope: vi.fn(async (args: PickDatabaseScopeArgs): Promise => { + scopeCalls.push(args); + const next = queue.shift(); + if (next === undefined || next === 'enable-all') { + const enabledTables = args.discovered.map((t) => `${t.schema}.${t.name}`); + const activeSchemas = args.supportsSchemaScope + ? Array.from(new Set(args.discovered.map((t) => t.schema))) + : []; + return { kind: 'selected', activeSchemas, enabledTables }; + } + if (next === 'back') { + return { kind: 'back' }; + } + return { + kind: 'selected', + activeSchemas: args.supportsSchemaScope ? next.schemas : [], + enabledTables: next.tables, + }; + }), + }; +} + function makePromptAdapter(options: { multiselectValues?: string[][]; selectValues?: string[]; @@ -819,7 +861,6 @@ describe('setup databases step', () => { await writeKtxSetupState(tempDir, { completed_steps: ['databases'] }); const prompts = makePromptAdapter({ textValues: ['env:DATABASE_URL'], - multiselectValues: [['analytics']], }); let primaryMenuCount = 0; vi.mocked(prompts.select).mockImplementation(async (options) => { @@ -835,11 +876,21 @@ describe('setup databases step', () => { const scanConnection = vi.fn(async () => 0); const listSchemas = vi.fn(async () => ['analytics', 'public']); const listTables = vi.fn(async () => [{ schema: 'analytics', name: 'customers', kind: 'table' as const }]); + const pickers = makePickerStubs({ + scopes: [{ schemas: ['analytics'], tables: ['analytics.customers'] }], + }); const result = await runKtxSetupDatabasesStep( { projectDir: tempDir, inputMode: 'auto', skipDatabases: false, databaseSchemas: [] }, makeIo().io, - { prompts, testConnection, scanConnection, listSchemas, listTables }, + { + prompts, + testConnection, + scanConnection, + listSchemas, + listTables, + pickDatabaseScope: pickers.pickDatabaseScope, + }, ); expect(result).toEqual({ status: 'ready', projectDir: tempDir, connectionIds: ['warehouse'] }); @@ -848,7 +899,7 @@ describe('setup databases step', () => { placeholder: 'env:DATABASE_URL', initialValue: 'env:DATABASE_URL', }); - expect(listTables).toHaveBeenCalledWith(tempDir, 'warehouse'); + expect(listTables).toHaveBeenCalledWith(tempDir, 'warehouse', ['analytics', 'public']); expect(testConnection).toHaveBeenCalledWith(tempDir, 'warehouse', expect.anything()); expect(scanConnection).toHaveBeenCalledWith(tempDir, 'warehouse', expect.anything()); const config = parseKtxProjectConfig(await readFile(join(tempDir, 'ktx.yaml'), 'utf-8')); @@ -882,7 +933,6 @@ describe('setup databases step', () => { await writeKtxSetupState(tempDir, { completed_steps: ['databases'] }); const prompts = makePromptAdapter({ textValues: ['env:DATABASE_URL'], - multiselectValues: [['public'], ['public.customers', 'public.orders']], }); let primaryMenuCount = 0; vi.mocked(prompts.select).mockImplementation(async (options) => { @@ -892,7 +942,6 @@ describe('setup databases step', () => { } if (options.message === 'Primary source to edit') return 'warehouse'; if (options.message === 'How do you want to connect to PostgreSQL?') return 'url'; - if (options.message.startsWith('Tables found in selected schemas')) return 'customize'; return 'back'; }); const listSchemas = vi.fn(async () => ['orbit_analytics', 'orbit_raw', 'public']); @@ -901,6 +950,9 @@ describe('setup databases step', () => { { schema: 'public', name: 'orders', kind: 'table' as const }, { schema: 'public', name: 'products', kind: 'table' as const }, ]); + const pickers = makePickerStubs({ + scopes: [{ schemas: ['public'], tables: ['public.customers', 'public.orders'] }], + }); const result = await runKtxSetupDatabasesStep( { projectDir: tempDir, inputMode: 'auto', skipDatabases: false, databaseSchemas: [] }, @@ -911,29 +963,17 @@ describe('setup databases step', () => { scanConnection: vi.fn(async () => 0), listSchemas, listTables, + pickDatabaseScope: pickers.pickDatabaseScope, }, ); expect(result).toEqual({ status: 'ready', projectDir: tempDir, connectionIds: ['warehouse'] }); - expect(prompts.multiselect).toHaveBeenNthCalledWith(1, { - message: expect.stringContaining('PostgreSQL schemas to scan'), - options: [ - { value: 'orbit_analytics', label: 'orbit_analytics' }, - { value: 'orbit_raw', label: 'orbit_raw' }, - { value: 'public', label: 'public' }, - ], - initialValues: ['public'], - required: true, - }); - expect(prompts.multiselect).toHaveBeenNthCalledWith(2, { - message: expect.stringContaining('Tables to enable for warehouse'), - options: [ - { value: 'public.customers', label: 'public.customers' }, - { value: 'public.orders', label: 'public.orders' }, - { value: 'public.products', label: 'public.products' }, - ], - initialValues: ['public.customers', 'public.orders'], - required: true, + expect(pickers.scopeCalls).toHaveLength(1); + expect(pickers.scopeCalls[0]).toMatchObject({ + connectionId: 'warehouse', + schemaNoun: 'schema', + supportsSchemaScope: true, + existing: { enabledTables: ['public.customers', 'public.orders'] }, }); const config = parseKtxProjectConfig(await readFile(join(tempDir, 'ktx.yaml'), 'utf-8')); expect(config.connections.warehouse).toMatchObject({ @@ -965,7 +1005,6 @@ describe('setup databases step', () => { await writeKtxSetupState(tempDir, { completed_steps: ['databases'] }); const prompts = makePromptAdapter({ textValues: ['env:DATABASE_URL'], - multiselectValues: [['back']], }); let primaryMenuCount = 0; vi.mocked(prompts.select).mockImplementation(async (options) => { @@ -980,19 +1019,29 @@ describe('setup databases step', () => { const testConnection = vi.fn(async () => 0); const scanConnection = vi.fn(async () => 0); const listSchemas = vi.fn(async () => ['analytics', 'public']); - const listTables = vi.fn(async () => [{ schema: 'analytics', name: 'customers', kind: 'table' as const }]); + const listTables = vi.fn(async () => [ + { schema: 'analytics', name: 'customers', kind: 'table' as const }, + { schema: 'public', name: 'orders', kind: 'table' as const }, + ]); + const pickers = makePickerStubs({ scopes: ['back'] }); const result = await runKtxSetupDatabasesStep( { projectDir: tempDir, inputMode: 'auto', skipDatabases: false, databaseSchemas: [] }, makeIo().io, - { prompts, testConnection, scanConnection, listSchemas, listTables }, + { + prompts, + testConnection, + scanConnection, + listSchemas, + listTables, + pickDatabaseScope: pickers.pickDatabaseScope, + }, ); expect(result).toEqual({ status: 'ready', projectDir: tempDir, connectionIds: ['warehouse'] }); expect(primaryMenuCount).toBe(2); expect(testConnection).toHaveBeenCalledWith(tempDir, 'warehouse', expect.anything()); expect(scanConnection).not.toHaveBeenCalled(); - expect(listTables).not.toHaveBeenCalled(); const config = parseKtxProjectConfig(await readFile(join(tempDir, 'ktx.yaml'), 'utf-8')); expect(config.connections.warehouse).toMatchObject({ url: 'env:DATABASE_URL', @@ -1031,7 +1080,6 @@ describe('setup databases step', () => { } if (options.message === 'Primary source to edit') return 'warehouse'; if (options.message === 'How do you want to connect to PostgreSQL?') return 'url'; - if (options.message.startsWith('Tables found in selected schemas')) return 'back'; return 'back'; }); const testConnection = vi.fn(async () => 0); @@ -1041,16 +1089,24 @@ describe('setup databases step', () => { { schema: 'public', name: 'customers', kind: 'table' as const }, { schema: 'public', name: 'orders', kind: 'table' as const }, ]); + const pickers = makePickerStubs({ scopes: ['back'] }); const result = await runKtxSetupDatabasesStep( { projectDir: tempDir, inputMode: 'auto', skipDatabases: false, databaseSchemas: [] }, makeIo().io, - { prompts, testConnection, scanConnection, listSchemas, listTables }, + { + prompts, + testConnection, + scanConnection, + listSchemas, + listTables, + pickDatabaseScope: pickers.pickDatabaseScope, + }, ); expect(result).toEqual({ status: 'ready', projectDir: tempDir, connectionIds: ['warehouse'] }); expect(primaryMenuCount).toBe(2); - expect(listTables).toHaveBeenCalledWith(tempDir, 'warehouse'); + expect(listTables).toHaveBeenCalledWith(tempDir, 'warehouse', ['public']); expect(scanConnection).not.toHaveBeenCalled(); const config = parseKtxProjectConfig(await readFile(join(tempDir, 'ktx.yaml'), 'utf-8')); expect(config.connections.warehouse).toMatchObject({ @@ -1083,19 +1139,18 @@ describe('setup databases step', () => { await writeKtxSetupState(tempDir, { completed_steps: ['databases'] }); const prompts = makePromptAdapter({ textValues: ['env:DATABASE_URL'], - multiselectValues: [['public']], }); vi.mocked(prompts.select).mockImplementation(async (options) => { if (options.message === 'Primary sources already configured: warehouse\nWhat would you like to do?') return 'edit'; if (options.message === 'Primary source to edit') return 'warehouse'; if (options.message === 'How do you want to connect to PostgreSQL?') return 'url'; - if (options.message.startsWith('Tables found in selected schemas')) return 'all'; return 'back'; }); const listTables = vi.fn(async () => [ { schema: 'public', name: 'customers', kind: 'table' as const }, { schema: 'public', name: 'orders', kind: 'table' as const }, ]); + const pickers = makePickerStubs({ scopes: ['enable-all'] }); const result = await runKtxSetupDatabasesStep( { projectDir: tempDir, inputMode: 'auto', skipDatabases: false, databaseSchemas: [] }, @@ -1105,6 +1160,7 @@ describe('setup databases step', () => { testConnection: vi.fn(async () => 0), scanConnection: vi.fn(async () => 1), listTables, + pickDatabaseScope: pickers.pickDatabaseScope, }, ); @@ -1390,7 +1446,6 @@ describe('setup databases step', () => { const prompts = makePromptAdapter({ selectValues: ['url'], textValues: ['', 'env:DATABASE_URL'], - multiselectValues: [['orbit_analytics', 'orbit_raw']], }); const testConnection = vi.fn(async () => 0); const scanConnection = vi.fn(async asyncScanProjectDir => { @@ -1401,6 +1456,19 @@ describe('setup databases step', () => { return 0; }); const listSchemas = vi.fn(async () => ['orbit_analytics', 'orbit_raw', 'public']); + const listTables = vi.fn(async () => [ + { schema: 'orbit_analytics', name: 'events', kind: 'table' as const }, + { schema: 'orbit_raw', name: 'inputs', kind: 'table' as const }, + { schema: 'public', name: 'misc', kind: 'table' as const }, + ]); + const pickers = makePickerStubs({ + scopes: [ + { + schemas: ['orbit_analytics', 'orbit_raw'], + tables: ['orbit_analytics.events', 'orbit_raw.inputs'], + }, + ], + }); const result = await runKtxSetupDatabasesStep( { @@ -1411,20 +1479,24 @@ describe('setup databases step', () => { skipDatabases: false, }, io.io, - { prompts, testConnection, scanConnection, listSchemas }, + { + prompts, + testConnection, + scanConnection, + listSchemas, + listTables, + pickDatabaseScope: pickers.pickDatabaseScope, + }, ); expect(result.status).toBe('ready'); expect(listSchemas).toHaveBeenCalledWith(tempDir, 'postgres-warehouse'); - expect(prompts.multiselect).toHaveBeenCalledWith({ - message: expect.stringContaining('PostgreSQL schemas to scan'), - options: [ - { value: 'orbit_analytics', label: 'orbit_analytics' }, - { value: 'orbit_raw', label: 'orbit_raw' }, - { value: 'public', label: 'public' }, - ], - initialValues: ['orbit_analytics', 'orbit_raw'], - required: true, + expect(pickers.scopeCalls).toHaveLength(1); + expect(pickers.scopeCalls[0]).toMatchObject({ + connectionId: 'postgres-warehouse', + schemaNoun: 'schema', + schemaNounPlural: 'schemas', + defaultSchemas: ['orbit_analytics', 'orbit_raw'], }); const config = parseKtxProjectConfig(await readFile(join(tempDir, 'ktx.yaml'), 'utf-8')); expect(config.connections['postgres-warehouse']).toMatchObject({ diff --git a/packages/cli/src/setup-databases.ts b/packages/cli/src/setup-databases.ts index 9db80689..c21ab6d1 100644 --- a/packages/cli/src/setup-databases.ts +++ b/packages/cli/src/setup-databases.ts @@ -14,6 +14,11 @@ import { import type { KtxTableListEntry } from '@ktx/context/scan'; import type { KtxCliIo } from './cli-runtime.js'; import { runKtxConnection } from './connection.js'; +import { + pickDatabaseScope as defaultPickDatabaseScope, + type DatabaseScopePickResult, + type PickDatabaseScopeArgs, +} from './database-tree-picker.js'; import { withMultiselectNavigation, withTextInputNavigation } from './prompt-navigation.js'; import { runKtxScan } from './scan.js'; import { writeProjectLocalSecretReference } from './setup-secrets.js'; @@ -90,7 +95,8 @@ export interface KtxSetupDatabasesDeps { scanConnection?: (projectDir: string, connectionId: string, io: KtxCliIo) => Promise; rebuildNativeSqlite?: (io: KtxCliIo) => Promise; listSchemas?: (projectDir: string, connectionId: string) => Promise; - listTables?: (projectDir: string, connectionId: string) => Promise; + listTables?: (projectDir: string, connectionId: string, schemas?: string[]) => Promise; + pickDatabaseScope?: (args: PickDatabaseScopeArgs, io: KtxCliIo) => Promise; historicSqlProbe?: KtxSetupHistoricSqlProbe; } @@ -363,11 +369,15 @@ function configuredSchemas(connection: KtxProjectConnectionConfig | undefined, d return values.length > 0 ? values : undefined; } -async function defaultListTables(projectDir: string, connectionId: string): Promise { +async function defaultListTables( + projectDir: string, + connectionId: string, + schemasOverride?: string[], +): Promise { const project = await loadKtxProject({ projectDir }); const connection = project.config.connections[connectionId]; const driver = normalizeDriver(connection?.driver); - const schemas = driver ? configuredSchemas(connection, driver) : undefined; + const schemas = schemasOverride ?? (driver ? configuredSchemas(connection, driver) : undefined); if (driver === 'postgres') { const { KtxPostgresScanConnector, isKtxPostgresConnectionConfig } = await import('@ktx/connector-postgres'); @@ -1271,145 +1281,98 @@ async function writeScopeConfig(input: { }); } -async function clearScopeConfig(projectDir: string, connectionId: string): Promise { - const project = await loadKtxProject({ projectDir }); - const connection = project.config.connections[connectionId]; - if (!connection) return; - const driver = normalizeDriver(connection.driver); - if (!driver) return; - const spec = SCOPE_DISCOVERY_SPECS[driver]; - if (!spec) return; - const cleaned = Object.fromEntries( - Object.entries(connection).filter( - ([key]) => key !== spec.configArrayField && key !== spec.configSingleField && key !== 'enabled_tables', - ), - ) as KtxProjectConnectionConfig; - await writeConnectionConfig({ projectDir, connectionId, connection: cleaned }); -} - -async function maybeConfigureSchemaScope(input: { +async function maybeConfigureDatabaseScope(input: { projectDir: string; connectionId: string; args: KtxSetupDatabasesArgs; - prompts: KtxSetupDatabasesPromptAdapter; deps: KtxSetupDatabasesDeps; io: KtxCliIo; forcePrompt?: boolean; -}): Promise { - const project = await loadKtxProject({ projectDir: input.projectDir }); - const connection = project.config.connections[input.connectionId]; - const driver = normalizeDriver(connection?.driver); - if (!driver) return 'ready'; - - const spec = SCOPE_DISCOVERY_SPECS[driver]; - if (!spec) return 'ready'; - - const arrayVal = connection?.[spec.configArrayField]; - if (Array.isArray(arrayVal) && arrayVal.length > 0 && input.forcePrompt !== true) { - return 'ready'; - } - - if (input.args.databaseSchemas.length > 0) { - await writeScopeConfig({ - projectDir: input.projectDir, - connectionId: input.connectionId, - values: input.args.databaseSchemas, - spec, - }); - return 'ready'; - } - - writeSetupSection(input.io, `Discovering ${spec.promptLabel.toLowerCase()}`, [ - `Connecting to ${input.connectionId}…`, - ]); - - let discovered: string[]; - try { - discovered = unique( - await (input.deps.listSchemas ?? defaultListSchemas)(input.projectDir, input.connectionId), - ); - } catch (error) { - const detail = error instanceof Error ? error.message : String(error); - input.io.stderr.write( - input.forcePrompt === true - ? `Could not discover ${spec.promptLabel.toLowerCase()} for ${input.connectionId}; edit was not saved. ` + - `Pass --database-schema to set it explicitly. ${detail}\n` - : `Could not discover ${spec.promptLabel.toLowerCase()} for ${input.connectionId}; continuing with existing ${spec.noun} scope. ` + - `Pass --database-schema to set it explicitly. ${detail}\n`, - ); - return input.forcePrompt === true ? 'failed' : 'ready'; - } - if (discovered.length === 0) { - return 'ready'; - } - - let selected: string[]; - if (input.args.inputMode === 'disabled' || discovered.length === 1) { - const preconfigured = configuredScopeValues(connection, spec).filter((v) => discovered.includes(v)); - selected = preconfigured.length > 0 ? preconfigured : discovered; - } else { - const preconfigured = configuredScopeValues(connection, spec).filter((v) => discovered.includes(v)); - const initialValues = preconfigured.length > 0 ? preconfigured : spec.defaultSelection(discovered); - const choices = await input.prompts.multiselect({ - message: withMultiselectNavigation( - `${spec.promptLabel} to scan\n` + - `KTX found multiple ${spec.nounPlural}. Select every ${spec.noun} agents should use.`, - ), - options: discovered.map((v) => ({ value: v, label: v })), - initialValues, - required: true, - }); - if (choices.includes('back')) { - return 'back'; - } - selected = choices.length > 0 ? choices : initialValues; - } - - await writeScopeConfig({ - projectDir: input.projectDir, - connectionId: input.connectionId, - values: selected, - spec, - }); - const capitalNounPlural = spec.nounPlural[0]!.toUpperCase() + spec.nounPlural.slice(1); - writeSetupSection(input.io, `${capitalNounPlural} saved for ${input.connectionId}`, [ - `✓ ${selected.join(', ')}`, - ]); - return 'ready'; -} - -async function maybeConfigureTableScope(input: { - projectDir: string; - connectionId: string; - args: KtxSetupDatabasesArgs; - prompts: KtxSetupDatabasesPromptAdapter; - io: KtxCliIo; - deps: KtxSetupDatabasesDeps; - forcePrompt?: boolean; }): Promise { const project = await loadKtxProject({ projectDir: input.projectDir }); const connection = project.config.connections[input.connectionId]; const driver = normalizeDriver(connection?.driver); if (!driver || driver === 'sqlite') return 'ready'; + const spec = SCOPE_DISCOVERY_SPECS[driver]; const existingTables = connection?.enabled_tables; - if (Array.isArray(existingTables) && existingTables.length > 0 && input.forcePrompt !== true) { + const hasExistingTables = Array.isArray(existingTables) && existingTables.length > 0; + const existingScope = spec ? configuredScopeValues(connection, spec) : []; + const hasExistingScope = !spec || existingScope.length > 0; + + if (hasExistingTables && hasExistingScope && input.forcePrompt !== true) { return 'ready'; } + const cliSchemas = input.args.databaseSchemas; + if (input.args.inputMode === 'disabled') { + if (spec) { + let scopeToWrite: string[] = cliSchemas; + if (scopeToWrite.length === 0) { + try { + scopeToWrite = unique( + await (input.deps.listSchemas ?? defaultListSchemas)(input.projectDir, input.connectionId), + ); + } catch (error) { + const detail = error instanceof Error ? error.message : String(error); + input.io.stderr.write( + `Could not discover ${spec.promptLabel.toLowerCase()} for ${input.connectionId}; ${detail}\n`, + ); + return 'ready'; + } + } + if (scopeToWrite.length > 0) { + await writeScopeConfig({ + projectDir: input.projectDir, + connectionId: input.connectionId, + values: scopeToWrite, + spec, + }); + const capitalNounPlural = spec.nounPlural[0]!.toUpperCase() + spec.nounPlural.slice(1); + writeSetupSection(input.io, `${capitalNounPlural} saved for ${input.connectionId}`, [ + `✓ ${scopeToWrite.join(', ')}`, + ]); + } + } return 'ready'; } + if (spec && cliSchemas.length > 0) { + await writeScopeConfig({ + projectDir: input.projectDir, + connectionId: input.connectionId, + values: cliSchemas, + spec, + }); + } + writeSetupSection(input.io, 'Discovering tables', [ `Connecting to ${input.connectionId}…`, ]); + const schemasFilter = await (async (): Promise => { + if (cliSchemas.length > 0) return cliSchemas; + if (!spec) return []; + try { + return unique( + await (input.deps.listSchemas ?? defaultListSchemas)(input.projectDir, input.connectionId), + ); + } catch (error) { + const detail = error instanceof Error ? error.message : String(error); + input.io.stderr.write( + `Could not discover ${spec.promptLabel.toLowerCase()} for ${input.connectionId}; ${detail}\n`, + ); + return []; + } + })(); + let discovered: KtxTableListEntry[]; try { discovered = await (input.deps.listTables ?? defaultListTables)( input.projectDir, input.connectionId, + schemasFilter.length > 0 ? schemasFilter : undefined, ); } catch (error) { const detail = error instanceof Error ? error.message : String(error); @@ -1429,84 +1392,72 @@ async function maybeConfigureTableScope(input: { } const allQualified = discovered.map((t) => `${t.schema}.${t.name}`); + const schemasInDiscovery = unique(discovered.map((t) => t.schema)); + + const defaultSchemas = (() => { + if (cliSchemas.length > 0) return cliSchemas; + if (!spec) return schemasInDiscovery; + return spec.defaultSelection(schemasInDiscovery); + })(); + + const existingEnabled = + hasExistingTables && input.forcePrompt === true + ? (existingTables ?? []).filter( + (table): table is string => typeof table === 'string' && allQualified.includes(table), + ) + : []; + + let activeSchemas: string[]; + let enabledTables: string[]; if (discovered.length === 1) { - await writeConnectionConfig({ - projectDir: input.projectDir, - connectionId: input.connectionId, - connection: { ...connection!, enabled_tables: allQualified }, - }); - writeSetupSection(input.io, `Tables enabled for ${input.connectionId}`, [ - `✓ ${allQualified[0]}`, - ]); - return 'ready'; - } - - const bySchema = new Map(); - for (const entry of discovered) { - const existing = bySchema.get(entry.schema) ?? []; - existing.push(entry); - bySchema.set(entry.schema, existing); - } - const schemaList = [...bySchema.keys()].sort(); - const schemaSummary = schemaList.map((s) => `${s} (${bySchema.get(s)!.length})`).join(', '); - - let selected: string[] | null = null; - - while (selected === null) { - const action = await input.prompts.select({ - message: `Tables found in selected schemas\n` + - `${discovered.length} tables across ${schemaList.length} ${schemaList.length === 1 ? 'schema' : 'schemas'}: ${schemaSummary}`, - options: [ - { value: 'all', label: 'Enable all tables' }, - { value: 'customize', label: 'Customize which tables to enable' }, - { value: 'back', label: 'Back' }, - ], - }); - - if (action === 'back') { + enabledTables = allQualified; + activeSchemas = spec ? schemasInDiscovery : []; + } else { + const pickResult = await (input.deps.pickDatabaseScope ?? defaultPickDatabaseScope)( + { + connectionId: input.connectionId, + schemaNoun: spec?.noun ?? 'schema', + schemaNounPlural: spec?.nounPlural ?? 'schemas', + discovered, + existing: { enabledTables: existingEnabled }, + defaultSchemas, + supportsSchemaScope: spec !== undefined, + }, + input.io, + ); + if (pickResult.kind === 'back') { return 'back'; } - - if (action === 'all') { - selected = allQualified; - } else { - const choices = await input.prompts.multiselect({ - message: withMultiselectNavigation( - `Tables to enable for ${input.connectionId}\n` + - `Deselect any tables agents should not use.`, - ), - options: discovered.map((t) => { - const qualified = `${t.schema}.${t.name}`; - const suffix = t.kind === 'view' ? ' (view)' : ''; - return { value: qualified, label: `${qualified}${suffix}` }; - }), - initialValues: - Array.isArray(existingTables) && input.forcePrompt === true - ? existingTables.filter((table): table is string => typeof table === 'string' && allQualified.includes(table)) - : allQualified, - required: true, - }); - - if (choices.includes('back')) { - continue; - } - if (choices.length === 0) { - input.io.stdout.write('│ KTX needs at least one table enabled. Select a table or press Escape to go back.\n'); - continue; - } - selected = choices; - } + enabledTables = pickResult.enabledTables; + activeSchemas = pickResult.activeSchemas; } + if (spec) { + await writeScopeConfig({ + projectDir: input.projectDir, + connectionId: input.connectionId, + values: activeSchemas, + spec, + }); + } + const refreshedProject = await loadKtxProject({ projectDir: input.projectDir }); + const currentConnection = refreshedProject.config.connections[input.connectionId]; + if (!currentConnection) return 'ready'; await writeConnectionConfig({ projectDir: input.projectDir, connectionId: input.connectionId, - connection: { ...connection!, enabled_tables: selected }, + connection: { ...currentConnection, enabled_tables: enabledTables }, }); + if (spec && activeSchemas.length > 0) { + const capitalNounPlural = spec.nounPlural[0]!.toUpperCase() + spec.nounPlural.slice(1); + writeSetupSection(input.io, `${capitalNounPlural} saved for ${input.connectionId}`, [ + `✓ ${activeSchemas.join(', ')}`, + ]); + } writeSetupSection(input.io, `Tables enabled for ${input.connectionId}`, [ - `✓ ${selected.length}/${discovered.length} tables enabled`, + `✓ ${enabledTables.length}/${discovered.length} tables enabled`, ]); return 'ready'; } @@ -1638,26 +1589,9 @@ async function validateAndScanConnection(input: { const testLines = ['✓ Connection test passed', `Driver: ${driverDisplay}`]; writeSetupSection(input.io, `Testing ${input.connectionId}`, testLines); - while (true) { - const schemaStatus = await maybeConfigureSchemaScope({ ...input, forcePrompt: input.forceScopeAndTables }); - if (schemaStatus !== 'ready') { - return schemaStatus; - } - - const tableStatus = await maybeConfigureTableScope({ ...input, forcePrompt: input.forceScopeAndTables }); - if (tableStatus === 'ready') { - break; - } - - if (input.forceScopeAndTables) { - return tableStatus; - } - - if (tableStatus === 'failed') { - return 'failed'; - } - - await clearScopeConfig(input.projectDir, input.connectionId); + const scopeStatus = await maybeConfigureDatabaseScope({ ...input, forcePrompt: input.forceScopeAndTables }); + if (scopeStatus !== 'ready') { + return scopeStatus; } await maybeRunHistoricSqlSetupProbe({