From 65ba2f14a388a1e99e0ebeecfda0352630089e43 Mon Sep 17 00:00:00 2001 From: Andrey Avtomonov <7889985+andreybavt@users.noreply.github.com> Date: Tue, 12 May 2026 01:06:23 +0200 Subject: [PATCH] fix: allow dbt ingest to discover warehouse schemas --- .../src/sl/semantic-layer.service.test.ts | 25 ++++++ .../context/src/sl/semantic-layer.service.ts | 8 +- .../src/sl/tools/sl-discover.tool.test.ts | 80 +++++++++++++++++++ .../context/src/sl/tools/sl-discover.tool.ts | 20 +++-- 4 files changed, 122 insertions(+), 11 deletions(-) create mode 100644 packages/context/src/sl/tools/sl-discover.tool.test.ts diff --git a/packages/context/src/sl/semantic-layer.service.test.ts b/packages/context/src/sl/semantic-layer.service.test.ts index 3adde085..c89cbc28 100644 --- a/packages/context/src/sl/semantic-layer.service.test.ts +++ b/packages/context/src/sl/semantic-layer.service.test.ts @@ -38,6 +38,31 @@ const baseTable: SemanticLayerSource = { measures: [], }; +describe('listConnectionIdsWithNames', () => { + it('discovers local KTX connection ids from semantic-layer directories', async () => { + const configService = { + listFiles: vi.fn().mockResolvedValue({ + files: [ + 'semantic-layer/warehouse/_schema/public.yaml', + 'semantic-layer/dbt-main/orders.yaml', + 'semantic-layer/.gitkeep', + ], + }), + }; + const catalog = connectionCatalog(); + catalog.listEnabledConnections.mockImplementation(async (ids: string[]) => + ids.map((id) => ({ id, name: id, connectionType: id === 'warehouse' ? 'postgres' : 'dbt' })), + ); + const service = new SemanticLayerService(configService as never, catalog, pythonPort); + + await expect(service.listConnectionIdsWithNames()).resolves.toEqual([ + { id: 'dbt-main', name: 'dbt-main', connectionType: 'dbt' }, + { id: 'warehouse', name: 'warehouse', connectionType: 'postgres' }, + ]); + expect(catalog.listEnabledConnections).toHaveBeenCalledWith(['dbt-main', 'warehouse']); + }); +}); + describe('composeOverlay', () => { it('carries top-level segments from overlay into the composed source', () => { const overlay = { diff --git a/packages/context/src/sl/semantic-layer.service.ts b/packages/context/src/sl/semantic-layer.service.ts index ffae0b12..938763fe 100644 --- a/packages/context/src/sl/semantic-layer.service.ts +++ b/packages/context/src/sl/semantic-layer.service.ts @@ -12,6 +12,7 @@ interface WriteSourceOptions { } const SL_DIR_PREFIX = 'semantic-layer'; +const CONNECTION_ID_PATTERN = /^[a-zA-Z0-9][a-zA-Z0-9_-]*$/; function formatPortError(error: unknown, fallback: string): string { if (typeof error === 'string') { @@ -61,11 +62,12 @@ export class SemanticLayerService { async listConnectionIds(): Promise { try { const result = await this.configService.listFiles(SL_DIR_PREFIX); - // Directories under semantic-layer/ are connectionIds (UUIDs) - const uuidPattern = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; + // Directories under semantic-layer/ are connectionIds. Local KTX projects use + // readable ids like "warehouse" and "dbt-main", not only UUIDs. return result.files .map((f) => f.replace(`${SL_DIR_PREFIX}/`, '').split('/')[0]) - .filter((name, i, arr) => uuidPattern.test(name) && arr.indexOf(name) === i); + .filter((name, i, arr) => CONNECTION_ID_PATTERN.test(name) && arr.indexOf(name) === i) + .sort(); } catch { return []; } diff --git a/packages/context/src/sl/tools/sl-discover.tool.test.ts b/packages/context/src/sl/tools/sl-discover.tool.test.ts new file mode 100644 index 00000000..3277d45d --- /dev/null +++ b/packages/context/src/sl/tools/sl-discover.tool.test.ts @@ -0,0 +1,80 @@ +import { describe, expect, it, vi } from 'vitest'; +import type { ToolContext, ToolSession } from '../../tools/index.js'; +import { createTouchedSlSources } from '../../tools/index.js'; +import type { SemanticLayerSource } from '../types.js'; +import { SlDiscoverTool } from './sl-discover.tool.js'; + +function makeTool() { + const semanticLayerService = { + listConnectionIdsWithNames: vi.fn(async () => [] as Array<{ id: string; name: string; connectionType: string }>), + loadAllSources: vi.fn(async () => [] as SemanticLayerSource[]), + }; + const slSearchService = { + search: vi.fn(async () => []), + }; + const tool = new SlDiscoverTool( + { + semanticLayerService: semanticLayerService as never, + slSearchService: slSearchService as never, + authorResolver: { resolve: vi.fn() }, + }, + { maxSources: 25, minRrfScore: 0, maxDetailedSources: 5 }, + ); + return { tool, semanticLayerService, slSearchService }; +} + +function makeContext(overrides: Partial = {}): ToolContext { + return { + sourceId: 'src', + messageId: 'msg', + userId: 'user', + ...overrides, + }; +} + +function makeSession(semanticLayerService: Record): ToolSession { + return { + connectionId: 'dbt-main', + isWorktreeScoped: true, + preHead: 'base', + touchedSlSources: createTouchedSlSources(), + actions: [], + semanticLayerService: semanticLayerService as never, + wikiService: {} as never, + configService: {} as never, + gitService: {} as never, + }; +} + +describe('SlDiscoverTool - session-scoped reads', () => { + it('discovers sources through context.session.semanticLayerService when a session is present', async () => { + const { tool, semanticLayerService } = makeTool(); + const sessionSemanticLayerService = { + listConnectionIdsWithNames: vi.fn().mockResolvedValue([ + { id: 'warehouse', name: 'warehouse', connectionType: 'postgres' }, + ]), + loadAllSources: vi.fn().mockResolvedValue([ + { + name: 'orders', + table: 'public.orders', + grain: ['order_id'], + columns: [{ name: 'order_id', type: 'string' }], + measures: [], + joins: [], + }, + ]), + }; + + const result = await tool.call({}, makeContext({ session: makeSession(sessionSemanticLayerService) })); + + expect(result.structured.totalSources).toBe(1); + expect(result.structured.sources[0]).toMatchObject({ + connectionId: 'warehouse', + name: 'orders', + columnCount: 1, + }); + expect(sessionSemanticLayerService.listConnectionIdsWithNames).toHaveBeenCalled(); + expect(sessionSemanticLayerService.loadAllSources).toHaveBeenCalledWith('warehouse'); + expect(semanticLayerService.listConnectionIdsWithNames).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/context/src/sl/tools/sl-discover.tool.ts b/packages/context/src/sl/tools/sl-discover.tool.ts index ed7c1854..97426b40 100644 --- a/packages/context/src/sl/tools/sl-discover.tool.ts +++ b/packages/context/src/sl/tools/sl-discover.tool.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; import { DEFAULT_PRIORITY, resolveDescription } from '../descriptions.js'; +import type { SemanticLayerService } from '../semantic-layer.service.js'; import type { SemanticLayerSource } from '../types.js'; import type { ToolContext, ToolOutput } from '../../tools/index.js'; import { BaseSemanticLayerTool, type BaseSemanticLayerToolDeps } from './base-semantic-layer.tool.js'; @@ -66,13 +67,14 @@ Use this to understand what data is available before writing a semantic_query. return slDiscoverInputSchema; } - async call(input: SlDiscoverInput, _context: ToolContext): Promise> { + async call(input: SlDiscoverInput, context: ToolContext): Promise> { const { query, sourceName } = input; + const semanticLayerService = context.session?.semanticLayerService ?? this.semanticLayerService; // Resolve connectionId: use provided value, or auto-detect let connectionId = input.connectionId; if (!connectionId) { - const connections = await this.semanticLayerService.listConnectionIdsWithNames(); + const connections = await semanticLayerService.listConnectionIdsWithNames(); if (connections.length === 0) { return { markdown: 'No semantic layer sources found. Run a schema scan first.', @@ -92,14 +94,14 @@ Use this to understand what data is available before writing a semantic_query. structured: { sources: [], totalSources: 0 }, }; } - return this.discoverAcrossConnections(connections, query); + return this.discoverAcrossConnections(semanticLayerService, connections, query); } } // If inspecting a specific source — show the SL interface (columns, measures, joins) // without the raw SQL. Use `sl_read_source` to see the full YAML including SQL. if (sourceName) { - const sources = await this.semanticLayerService.loadAllSources(connectionId); + const sources = await semanticLayerService.loadAllSources(connectionId); const source = sources.find((s) => s.name === sourceName); if (!source) { return { @@ -136,19 +138,20 @@ Use this to understand what data is available before writing a semantic_query. } // Single connection: list all sources - const connections = await this.semanticLayerService.listConnectionIdsWithNames(); + const connections = await semanticLayerService.listConnectionIdsWithNames(); const connInfo = connections.find((c) => c.id === connectionId); - return this.discoverForConnection(connectionId, connInfo?.name ?? connectionId, query); + return this.discoverForConnection(semanticLayerService, connectionId, connInfo?.name ?? connectionId, query); } private async discoverAcrossConnections( + semanticLayerService: SemanticLayerService, connections: Array<{ id: string; name: string; connectionType: string }>, query?: string, ): Promise> { // Load sources from all connections in parallel const results = await Promise.all( connections.map(async (conn) => { - const sources = await this.semanticLayerService.loadAllSources(conn.id); + const sources = await semanticLayerService.loadAllSources(conn.id); let filtered = sources; if (query) { filtered = await this.filterByQuery(conn.id, sources, query); @@ -205,11 +208,12 @@ Use this to understand what data is available before writing a semantic_query. } private async discoverForConnection( + semanticLayerService: SemanticLayerService, connectionId: string, connectionName: string, query?: string, ): Promise> { - const sources = await this.semanticLayerService.loadAllSources(connectionId); + const sources = await semanticLayerService.loadAllSources(connectionId); if (sources.length === 0) { return {