import { randomUUID } from 'node:crypto'; import type { ToolAnnotations } from '@modelcontextprotocol/sdk/types.js'; import { z } from 'zod'; import type { KtxCliIo } from '../../cli-runtime.js'; import type { MemoryAgentInput } from '../../context/memory/types.js'; import { emitTelemetryEvent, mcpTelemetrySampleRate, shouldEmitMcpTelemetry } from '../../telemetry/index.js'; import { scrubErrorClass } from '../../telemetry/scrubber.js'; import type { KtxMcpContextPorts, KtxMcpProgressCallback, KtxMcpServerLike, KtxMcpToolHandlerContext, KtxMcpToolResult, KtxMcpUserContext, NonArrayObject, } from './types.js'; export interface RegisterKtxContextToolsDeps { server: KtxMcpServerLike; ports: KtxMcpContextPorts; userContext: KtxMcpUserContext; projectDir?: string; io?: KtxCliIo; } const connectionIdSchema = z.string().min(1); const unknownRecordSchema = z.record(z.string(), z.unknown()); const tableRefSchema = z.object({ catalog: z.string().nullable(), db: z.string().nullable(), name: z.string(), }); const toolAnnotations = { connection_list: { title: 'Connection List', readOnlyHint: true, idempotentHint: true, openWorldHint: false }, discover_data: { title: 'Discover Data', readOnlyHint: true, openWorldHint: false }, wiki_search: { title: 'Wiki Search', readOnlyHint: true, openWorldHint: false }, wiki_read: { title: 'Wiki Read', readOnlyHint: true, idempotentHint: true, openWorldHint: false }, entity_details: { title: 'Entity Details', readOnlyHint: true, idempotentHint: true, openWorldHint: false }, dictionary_search: { title: 'Dictionary Search', readOnlyHint: true, openWorldHint: false }, sl_read_source: { title: 'Semantic Layer Read Source', readOnlyHint: true, idempotentHint: true, openWorldHint: false }, sl_query: { title: 'Semantic Layer Query', readOnlyHint: true, openWorldHint: false }, sql_execution: { title: 'SQL Execution', readOnlyHint: true, openWorldHint: false }, memory_ingest: { title: 'Memory Ingest', destructiveHint: true, openWorldHint: false }, memory_ingest_status: { title: 'Memory Ingest Status', readOnlyHint: true, openWorldHint: false }, } satisfies Record; const toolDescriptions = { connection_list: 'List configured read-only data connections available to this KTX project. Use this before connection-scoped tools when the project may have multiple warehouses.', discover_data: 'Search across KTX wiki pages, semantic-layer sources, measures, dimensions, raw tables, and columns. Example: discover_data({ query: "monthly orders by customer", connectionId: "warehouse", kinds: ["sl_source", "table"] }).', wiki_search: 'Search KTX wiki pages for reusable business context. Example: wiki_search({ query: "revenue recognition", limit: 5 }).', wiki_read: 'Read a KTX wiki page by key returned from wiki_search. Example: wiki_read({ key: "global/revenue" }).', entity_details: 'Read table and column metadata from the latest live-database scan snapshot. Example: entity_details({ connectionId: "warehouse", entities: [{ table: { catalog: null, db: "public", name: "orders" }, columns: ["id"] }] }).', dictionary_search: 'Search profile-sampled warehouse values to locate likely source columns for business values. Example: dictionary_search({ values: ["Acme Corp"], connectionId: "warehouse" }).', sl_read_source: 'Read a semantic-layer YAML source by connection id and source name. Example: sl_read_source({ connectionId: "warehouse", sourceName: "orders" }).', sl_query: 'Execute a semantic-layer query and return rows, headers, generated SQL, and plan details. Example: sl_query({ connectionId: "warehouse", measures: ["orders.order_count"], dimensions: [{ field: "orders.created_at", granularity: "month" }] }).', sql_execution: 'Execute one parser-validated read-only SQL query against a configured KTX connection. Example: sql_execution({ connectionId: "warehouse", sql: "select count(*) from public.orders", maxRows: 100 }).', memory_ingest: 'Ingest free-form markdown knowledge into durable KTX memory. Use this for business rules, metric definitions, schema gotchas, recurring findings, or explicit user requests to remember something. Example: memory_ingest({ connectionId: "warehouse", content: "ARR is reported in cents in this warehouse." }).', memory_ingest_status: 'Read the current or final status for a memory ingest run. Example: memory_ingest_status({ runId: "memory-run-1" }).', } satisfies Record; const connectionListSchema = z.object({}); const knowledgeSearchSchema = z.object({ query: z.string().min(1).describe('Natural-language wiki search query, e.g. "revenue recognition policy".'), limit: z.number().int().min(1).max(50).default(10).describe('Maximum wiki pages to return. Defaults to 10.'), }); const knowledgeReadSchema = z.object({ key: z.string().min(1).describe('Wiki page key returned by wiki_search, e.g. "global/revenue".'), }); const slReadSourceSchema = z.object({ connectionId: connectionIdSchema.describe('Connection id that owns the semantic-layer source.'), sourceName: z.string().min(1).describe('Semantic-layer source name without ".yaml", e.g. "orders".'), }); const slQueryMeasureSchema = z.union([ z.string().describe('Semantic-layer measure key, e.g. "orders.order_count".'), z.object({ expr: z.string().min(1).describe('Ad hoc aggregate expression, e.g. "sum(orders.amount)".'), name: z.string().min(1).describe('Alias for the ad hoc measure, e.g. "gross_revenue".'), }), ]); const slQueryDimensionSchema = z.object({ field: z.string().min(1).describe('Dimension to group by, e.g. "orders.created_at" or "orders.status".'), granularity: z .string() .min(1) .optional() .describe('Time grain for time dimensions: day, week, month, quarter, or year.'), }); const slQueryOrderBySchema = z.object({ field: z .string() .min(1) .describe( 'Field/measure/dimension id to order by, e.g. "orders.created_at", a dimension key like "mart_nrr_quarterly.quarter_label", or a measure alias.', ), direction: z .enum(['asc', 'desc']) .default('asc') .describe('Sort direction: "asc" or "desc". Defaults to "asc".'), }); const slQuerySchema = z.object({ connectionId: connectionIdSchema .optional() .describe('Connection id to query. Omit only when the project has exactly one configured connection.'), measures: z.array(slQueryMeasureSchema).min(1).describe('Measures to select. Use semantic-layer keys when available.'), dimensions: z .array(slQueryDimensionSchema) .default([]) .describe('Dimensions to group by. Use {field, granularity?} entries.'), filters: z .array(z.string().describe('Semantic-layer filter expression, e.g. "orders.status = paid".')) .default([]) .describe('Semantic-layer filter expressions to apply.'), segments: z .array(z.string().describe('Semantic-layer segment key to apply.')) .default([]) .describe('Semantic-layer segment keys to apply.'), order_by: z .array(slQueryOrderBySchema) .default([]) .describe('Sort clauses. Use {field, direction?} entries.'), limit: z.number().int().min(0).default(1000).describe('Maximum rows to return. Defaults to 1000.'), include_empty: z.boolean().default(true).describe('Whether to include empty dimension groups. Defaults to true.'), }); const entityDetailsTableRefSchema = z.object({ catalog: z.string().nullable().describe('Catalog/project/database. Use null when not applicable.'), db: z.string().nullable().describe('Schema/database/dataset. Use null when not applicable.'), name: z.string().min(1).describe('Table name.'), }); const entityDetailsSchema = z.object({ connectionId: connectionIdSchema.describe('Connection id whose latest scan snapshot should be read.'), entities: z .array( z.object({ table: z .union([z.string().min(1), entityDetailsTableRefSchema]) .describe('Table display string or canonical object ref.'), columns: z .array(z.string().min(1).describe('Column name to inspect.')) .optional() .describe('Optional column filter.'), }), ) .min(1) .max(20) .describe('Tables or columns to inspect. Maximum 20 entities.'), }); const dictionarySearchSchema = z.object({ values: z .array(z.string().min(1).describe('Business value to locate, e.g. "Acme Corp" or "enterprise".')) .min(1) .max(20) .describe('Values to search for in sampled warehouse dictionaries.'), connectionId: connectionIdSchema .optional() .describe('Optional connection id. Pass it when user intent pins a specific warehouse.'), }); const discoverDataKindSchema = z.enum(['wiki', 'sl_source', 'sl_measure', 'sl_dimension', 'table', 'column']); const discoverDataSchema = z.object({ query: z.string().min(1).describe('Natural-language discovery query, e.g. "monthly orders by customer".'), connectionId: connectionIdSchema .optional() .describe('Optional connection id. Pass it when user intent pins a specific warehouse.'), kinds: z.array(discoverDataKindSchema.describe('Reference kind to include.')).optional().describe('Optional kind filter.'), limit: z.number().int().min(1).max(50).default(15).optional().describe('Maximum refs to return. Defaults to 15.'), }); const sqlExecutionSchema = z.object({ connectionId: connectionIdSchema.describe('Connection id to execute against. Required for raw SQL.'), sql: z.string().min(1).describe('Parser-validated read-only SQL, e.g. "select count(*) from public.orders".'), maxRows: z.number().int().min(1).max(10_000).default(1000).optional().describe('Maximum rows to return. Defaults to 1000.'), }); const memoryIngestSchema = z.object({ content: z .string() .min(1) .describe( 'Free-form markdown to ingest. Include the knowledge itself plus any context (source, the user question, why this came up) that the memory agent should consider when triaging into wiki/SL.', ), connectionId: connectionIdSchema .optional() .describe( 'Scope this memory to a specific connection. Required when the knowledge is warehouse-specific, including measure definitions, schema gotchas, or anything tied to a particular warehouse. Omit only for global wiki knowledge.', ), }); const memoryIngestStatusSchema = z.object({ runId: z.string().min(1).describe('The memory ingest run id returned by memory_ingest.'), }); const connectionListOutputSchema = z.object({ connections: z.array( z.object({ id: z.string(), name: z.string(), connectionType: z.string(), }), ), }); const wikiSearchOutputSchema = z.object({ results: z.array( z.object({ key: z.string(), path: z.string(), scope: z.enum(['GLOBAL', 'USER']), summary: z.string(), score: z.number(), matchReasons: z.array(z.string()).optional(), lanes: z .array( z.object({ lane: z.string(), status: z.string(), requestedCandidatePoolLimit: z.number(), effectiveCandidatePoolLimit: z.number(), returnedCandidateCount: z.number(), weight: z.number(), reason: z.string().optional(), }), ) .optional(), }), ), totalFound: z.number(), }); const wikiReadOutputSchema = z.object({ key: z.string(), summary: z.string(), content: z.string(), scope: z.enum(['GLOBAL', 'USER']), tags: z.array(z.string()).optional(), refs: z.array(z.string()).optional(), slRefs: z.array(z.string()).optional(), }); const slReadSourceOutputSchema = z.object({ sourceName: z.string(), yaml: z.string(), }); const slQueryOutputSchema = z.object({ connectionId: z.string().optional(), dialect: z.string().optional(), sql: z.string(), headers: z.array(z.string()), rows: z.array(z.array(z.unknown())), totalRows: z.number(), plan: unknownRecordSchema.optional(), }); const entityDetailsSnapshotOutputSchema = z.object({ syncId: z.string(), extractedAt: z.string(), scanRunId: z.string().nullable(), }); const entityDetailsColumnOutputSchema = z.object({ name: z.string(), nativeType: z.string(), normalizedType: z.string(), dimensionType: z.enum(['time', 'string', 'number', 'boolean']), nullable: z.boolean(), primaryKey: z.boolean(), comment: z.string().nullable(), }); const entityDetailsForeignKeyOutputSchema = z.object({ fromColumn: z.string(), toCatalog: z.string().nullable(), toDb: z.string().nullable(), toTable: z.string(), toColumn: z.string(), constraintName: z.string().nullable(), }); const entityDetailsOutputSchema = z.object({ results: z.array( z.union([ z.object({ ok: z.literal(true), connectionId: z.string(), tableRef: tableRefSchema, display: z.string(), kind: z.enum(['table', 'view', 'external', 'event_stream']), comment: z.string().nullable(), estimatedRows: z.number().nullable(), columns: z.array(entityDetailsColumnOutputSchema), foreignKeys: z.array(entityDetailsForeignKeyOutputSchema), snapshot: entityDetailsSnapshotOutputSchema, }), z.object({ ok: z.literal(false), connectionId: z.string(), table: z.union([z.string(), tableRefSchema]), snapshot: entityDetailsSnapshotOutputSchema.optional(), error: z.object({ code: z.enum(['scan_missing', 'table_not_found', 'ambiguous_table', 'column_not_found']), message: z.string(), candidates: z .union([z.array(z.object({ tableRef: tableRefSchema, display: z.string() })), z.array(z.string())]) .optional(), }), }), ]), ), }); const dictionarySearchOutputSchema = z.object({ searched: z.array( z.object({ connectionId: z.string(), coverage: z.object({ sampledRows: z.number().nullable(), valuesPerColumn: z.number().nullable(), profiledColumns: z.number(), syncId: z.string().nullable(), profiledAt: z.string().nullable(), }), status: z.enum(['ready', 'no_profile_artifact', 'no_candidate_columns']), }), ), results: z.array( z.object({ value: z.string(), matches: z.array( z.object({ connectionId: z.string(), sourceName: z.string(), columnName: z.string(), matchedValue: z.string(), cardinality: z.number().nullable(), }), ), misses: z.array( z.object({ connectionId: z.string(), reason: z.enum(['no_profile_artifact', 'no_candidate_columns', 'value_not_in_sample']), }), ), }), ), }); const discoverDataOutputSchema = z.object({ refs: z.array( z.object({ kind: discoverDataKindSchema, id: z.string(), score: z.number(), summary: z.string().nullable(), snippet: z.string().nullable(), matchedOn: z.enum(['name', 'display', 'description', 'comment', 'expr', 'sample_value', 'body']), connectionId: z.string().optional(), tableRef: tableRefSchema.optional(), columnName: z.string().optional(), }), ), }); const sqlExecutionOutputSchema = z.object({ headers: z.array(z.string()), headerTypes: z.array(z.string()).optional(), rows: z.array(z.array(z.unknown())), rowCount: z.number(), }); const memoryIngestOutputSchema = z.object({ runId: z.string(), }); const memoryIngestStatusOutputSchema = z.object({ runId: z.string(), status: z.enum(['running', 'done', 'error']), stage: z.string(), done: z.boolean(), captured: z.object({ wiki: z.array(z.string()), sl: z.array(z.string()), xrefs: z.array(z.string()), }), error: z.string().nullable(), commitHash: z.string().nullable(), skillsLoaded: z.array(z.string()), signalDetected: z.boolean(), }); /** @internal */ export function jsonToolResult(structuredContent: T): KtxMcpToolResult { return { content: [{ type: 'text', text: JSON.stringify(structuredContent, null, 2) }], structuredContent, }; } function jsonErrorToolResult(text: string): KtxMcpToolResult> { return { content: [{ type: 'text', text }], isError: true, }; } function formatToolError(error: unknown): string { if (error instanceof z.ZodError) { return error.issues .map((issue) => `${issue.path.length > 0 ? issue.path.join('.') : ''}: ${issue.message}`) .join('\n'); } return error instanceof Error ? error.message : String(error); } function mcpProgressCallback(context?: KtxMcpToolHandlerContext): KtxMcpProgressCallback | undefined { const progressToken = context?._meta?.progressToken; if (progressToken === undefined || !context?.sendNotification) { return undefined; } return async (event) => { await context.sendNotification?.({ method: 'notifications/progress', params: { progressToken, progress: event.progress, ...(event.total !== undefined ? { total: event.total } : {}), message: event.message, }, }); }; } function registerParsedTool( server: KtxMcpServerLike, name: string, config: { title: string; description: string; inputSchema: unknown; outputSchema: unknown; annotations: ToolAnnotations; }, schema: TSchema, handler: (input: z.infer, context?: KtxMcpToolHandlerContext) => Promise, ): void { server.registerTool(name, config, async (input, context) => { try { return await handler(schema.parse(input), context); } catch (error) { return jsonErrorToolResult(formatToolError(error)); } }); } function instrumentMcpServer( server: KtxMcpServerLike, telemetry: { projectDir?: string; io?: KtxCliIo }, ): KtxMcpServerLike { return { registerTool(name, config, handler) { server.registerTool(name, config, async (input, context) => { const startedAt = performance.now(); try { const result = await handler(input, context); if (telemetry.io && telemetry.projectDir && shouldEmitMcpTelemetry()) { const isError = typeof result === 'object' && result !== null && 'isError' in result && result.isError === true; await emitTelemetryEvent({ name: 'mcp_request_completed', projectDir: telemetry.projectDir, io: telemetry.io, fields: { toolName: name, outcome: isError ? 'error' : 'ok', durationMs: Math.max(0, performance.now() - startedAt), sampleRate: mcpTelemetrySampleRate(), }, }); } return result; } catch (error) { if (telemetry.io && telemetry.projectDir && shouldEmitMcpTelemetry()) { const errorClass = scrubErrorClass(error); await emitTelemetryEvent({ name: 'mcp_request_completed', projectDir: telemetry.projectDir, io: telemetry.io, fields: { toolName: name, outcome: 'error', ...(errorClass ? { errorClass } : {}), durationMs: Math.max(0, performance.now() - startedAt), sampleRate: mcpTelemetrySampleRate(), }, }); } throw error; } }); }, }; } export function registerKtxContextTools(deps: RegisterKtxContextToolsDeps): void { const { ports, userContext } = deps; const server = instrumentMcpServer(deps.server, { projectDir: deps.projectDir, io: deps.io }); if (ports.connections) { const connections = ports.connections; registerParsedTool( server, 'connection_list', { title: toolAnnotations.connection_list.title!, description: toolDescriptions.connection_list, inputSchema: connectionListSchema.shape, outputSchema: connectionListOutputSchema, annotations: toolAnnotations.connection_list, }, connectionListSchema, async () => jsonToolResult({ connections: await connections.list() }), ); } if (ports.knowledge) { const knowledge = ports.knowledge; registerParsedTool( server, 'wiki_search', { title: toolAnnotations.wiki_search.title!, description: toolDescriptions.wiki_search, inputSchema: knowledgeSearchSchema.shape, outputSchema: wikiSearchOutputSchema, annotations: toolAnnotations.wiki_search, }, knowledgeSearchSchema, async (input) => jsonToolResult( await knowledge.search({ userId: userContext.userId, query: input.query, limit: input.limit, }), ), ); registerParsedTool( server, 'wiki_read', { title: toolAnnotations.wiki_read.title!, description: toolDescriptions.wiki_read, inputSchema: knowledgeReadSchema.shape, outputSchema: wikiReadOutputSchema, annotations: toolAnnotations.wiki_read, }, knowledgeReadSchema, async (input) => { const page = await knowledge.read({ userId: userContext.userId, key: input.key }); return page ? jsonToolResult(page) : jsonErrorToolResult(`Wiki page "${input.key}" was not found.`); }, ); } if (ports.semanticLayer) { const semanticLayer = ports.semanticLayer; registerParsedTool( server, 'sl_read_source', { title: toolAnnotations.sl_read_source.title!, description: toolDescriptions.sl_read_source, inputSchema: slReadSourceSchema.shape, outputSchema: slReadSourceOutputSchema, annotations: toolAnnotations.sl_read_source, }, slReadSourceSchema, async (input) => { const source = await semanticLayer.readSource(input); return source ? jsonToolResult(source) : jsonErrorToolResult(`Semantic-layer source "${input.sourceName}" was not found.`); }, ); registerParsedTool( server, 'sl_query', { title: toolAnnotations.sl_query.title!, description: toolDescriptions.sl_query, inputSchema: slQuerySchema.shape, outputSchema: slQueryOutputSchema, annotations: toolAnnotations.sl_query, }, slQuerySchema, async (input, context) => { const onProgress = mcpProgressCallback(context); return jsonToolResult( await semanticLayer.query( { connectionId: input.connectionId, query: { measures: input.measures, dimensions: input.dimensions, filters: input.filters, segments: input.segments, order_by: input.order_by, limit: input.limit, include_empty: input.include_empty, }, }, onProgress ? { onProgress } : undefined, ), ); }, ); } if (ports.entityDetails) { const entityDetails = ports.entityDetails; registerParsedTool( server, 'entity_details', { title: toolAnnotations.entity_details.title!, description: toolDescriptions.entity_details, inputSchema: entityDetailsSchema.shape, outputSchema: entityDetailsOutputSchema, annotations: toolAnnotations.entity_details, }, entityDetailsSchema, async (input) => jsonToolResult(await entityDetails.read(input)), ); } if (ports.dictionarySearch) { const dictionarySearch = ports.dictionarySearch; registerParsedTool( server, 'dictionary_search', { title: toolAnnotations.dictionary_search.title!, description: toolDescriptions.dictionary_search, inputSchema: dictionarySearchSchema.shape, outputSchema: dictionarySearchOutputSchema, annotations: toolAnnotations.dictionary_search, }, dictionarySearchSchema, async (input) => jsonToolResult(await dictionarySearch.search(input)), ); } if (ports.discover) { const discover = ports.discover; registerParsedTool( server, 'discover_data', { title: toolAnnotations.discover_data.title!, description: toolDescriptions.discover_data, inputSchema: discoverDataSchema.shape, outputSchema: discoverDataOutputSchema, annotations: toolAnnotations.discover_data, }, discoverDataSchema, async (input) => jsonToolResult({ refs: await discover.search(input) }), ); } if (ports.sqlExecution) { const sqlExecution = ports.sqlExecution; registerParsedTool( server, 'sql_execution', { title: toolAnnotations.sql_execution.title!, description: toolDescriptions.sql_execution, inputSchema: sqlExecutionSchema.shape, outputSchema: sqlExecutionOutputSchema, annotations: toolAnnotations.sql_execution, }, sqlExecutionSchema, async (input, context) => { const onProgress = mcpProgressCallback(context); return jsonToolResult( await sqlExecution.execute( { connectionId: input.connectionId, sql: input.sql, maxRows: input.maxRows ?? 1000, }, onProgress ? { onProgress } : undefined, ), ); }, ); } if (ports.memoryIngest) { const memoryIngest = ports.memoryIngest; registerParsedTool( server, 'memory_ingest', { title: toolAnnotations.memory_ingest.title!, description: toolDescriptions.memory_ingest, inputSchema: memoryIngestSchema.shape, outputSchema: memoryIngestOutputSchema, annotations: toolAnnotations.memory_ingest, }, memoryIngestSchema, async (input) => { const ingestInput: MemoryAgentInput = { userId: userContext.userId, chatId: `mcp-${randomUUID()}`, userMessage: 'Ingest external knowledge into KTX memory.', assistantMessage: input.content, connectionId: input.connectionId, sourceType: 'external_ingest', }; return jsonToolResult(await memoryIngest.ingest(ingestInput)); }, ); registerParsedTool( server, 'memory_ingest_status', { title: toolAnnotations.memory_ingest_status.title!, description: toolDescriptions.memory_ingest_status, inputSchema: memoryIngestStatusSchema.shape, outputSchema: memoryIngestStatusOutputSchema, annotations: toolAnnotations.memory_ingest_status, }, memoryIngestStatusSchema, async (input) => { const status = await memoryIngest.status(input.runId); return status ? jsonToolResult(status) : jsonErrorToolResult(`Memory ingest run "${input.runId}" was not found.`); }, ); } }