mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-10 08:05:14 +02:00
feat(context): emit mcp query progress stages
This commit is contained in:
parent
3ef46185d4
commit
fabc1e9e35
4 changed files with 102 additions and 6 deletions
|
|
@ -244,6 +244,50 @@ describe('createLocalProjectMcpContextPorts', () => {
|
|||
expect(connector.cleanup).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('emits sql_execution progress stages from local MCP ports', async () => {
|
||||
const project = await initKtxProject({ projectDir: tempDir });
|
||||
project.config.connections.warehouse = {
|
||||
driver: 'postgres',
|
||||
url: 'env:DATABASE_URL',
|
||||
};
|
||||
const connector = testConnector(testSnapshot(), {
|
||||
headers: ['id'],
|
||||
headerTypes: ['integer'],
|
||||
rows: [[1]],
|
||||
totalRows: 1,
|
||||
rowCount: 1,
|
||||
});
|
||||
const createConnector = vi.fn(async () => connector);
|
||||
const sqlAnalysis = {
|
||||
analyzeForFingerprint: vi.fn(),
|
||||
analyzeBatch: vi.fn(),
|
||||
validateReadOnly: vi.fn(async () => ({ ok: true, error: null })),
|
||||
};
|
||||
const progress: Array<{ progress: number; message: string }> = [];
|
||||
const ports = createLocalProjectMcpContextPorts(project, {
|
||||
sqlAnalysis,
|
||||
localScan: {
|
||||
createConnector,
|
||||
},
|
||||
});
|
||||
|
||||
const result = await ports.sqlExecution?.execute(
|
||||
{ connectionId: 'warehouse', sql: 'select id from public.orders', maxRows: 5 },
|
||||
{
|
||||
onProgress: (event) => {
|
||||
progress.push({ progress: event.progress, message: event.message });
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(result?.rowCount).toBe(1);
|
||||
expect(progress).toEqual([
|
||||
{ progress: 0, message: 'Validating SQL' },
|
||||
{ progress: 0.3, message: 'Executing' },
|
||||
{ progress: 1, message: 'Fetched 1 rows' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('rejects MCP SQL before connector execution when parser validation fails', async () => {
|
||||
const project = await initKtxProject({ projectDir: tempDir });
|
||||
project.config.connections.warehouse = {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import { createKtxDiscoverDataService } from '../search/index.js';
|
|||
import type { SqlAnalysisDialect, SqlAnalysisPort } from '../sql-analysis/index.js';
|
||||
import { compileLocalSlQuery, createKtxDictionarySearchService } from '../sl/index.js';
|
||||
import { readLocalKnowledgePage, searchLocalKnowledgePages } from '../wiki/local-knowledge.js';
|
||||
import type { KtxMcpContextPorts, KtxSqlExecutionResponse } from './types.js';
|
||||
import type { KtxMcpContextPorts, KtxMcpProgressCallback, KtxSqlExecutionResponse } from './types.js';
|
||||
|
||||
interface CreateLocalProjectMcpContextPortsOptions {
|
||||
semanticLayerCompute?: KtxSemanticLayerComputePort;
|
||||
|
|
@ -83,7 +83,9 @@ async function executeValidatedReadOnlySql(
|
|||
project: KtxLocalProject,
|
||||
options: CreateLocalProjectMcpContextPortsOptions,
|
||||
input: { connectionId: string; sql: string; maxRows: number },
|
||||
onProgress?: KtxMcpProgressCallback,
|
||||
): Promise<KtxSqlExecutionResponse> {
|
||||
await onProgress?.({ progress: 0, message: 'Validating SQL' });
|
||||
const connectionId = assertSafeConnectionId(input.connectionId);
|
||||
const connection = project.config.connections[connectionId];
|
||||
if (!connection) {
|
||||
|
|
@ -107,6 +109,7 @@ async function executeValidatedReadOnlySql(
|
|||
if (!connector.capabilities.readOnlySql || !connector.executeReadOnly) {
|
||||
throw new Error(`Connection "${connectionId}" does not support read-only SQL execution.`);
|
||||
}
|
||||
await onProgress?.({ progress: 0.3, message: 'Executing' });
|
||||
const result = await connector.executeReadOnly(
|
||||
{
|
||||
connectionId,
|
||||
|
|
@ -115,12 +118,14 @@ async function executeValidatedReadOnlySql(
|
|||
},
|
||||
{ runId: 'mcp-sql-execution' },
|
||||
);
|
||||
return {
|
||||
const response = {
|
||||
headers: result.headers,
|
||||
...(result.headerTypes ? { headerTypes: result.headerTypes } : {}),
|
||||
rows: result.rows,
|
||||
rowCount: result.rowCount ?? result.rows.length,
|
||||
};
|
||||
await onProgress?.({ progress: 1, message: `Fetched ${response.rowCount} rows` });
|
||||
return response;
|
||||
} finally {
|
||||
await cleanupConnector(connector);
|
||||
}
|
||||
|
|
@ -194,7 +199,7 @@ export function createLocalProjectMcpContextPorts(
|
|||
return null;
|
||||
}
|
||||
},
|
||||
async query(input) {
|
||||
async query(input, executionOptions) {
|
||||
if (!options.semanticLayerCompute) {
|
||||
throw new Error('sl_query requires a semantic-layer query adapter.');
|
||||
}
|
||||
|
|
@ -205,6 +210,7 @@ export function createLocalProjectMcpContextPorts(
|
|||
execute: Boolean(options.queryExecutor),
|
||||
maxRows: input.query.limit,
|
||||
queryExecutor: options.queryExecutor,
|
||||
onProgress: executionOptions?.onProgress,
|
||||
});
|
||||
},
|
||||
},
|
||||
|
|
@ -227,8 +233,8 @@ export function createLocalProjectMcpContextPorts(
|
|||
|
||||
if (options.sqlAnalysis && options.localScan?.createConnector) {
|
||||
ports.sqlExecution = {
|
||||
async execute(input) {
|
||||
return executeValidatedReadOnlySql(project, options, input);
|
||||
async execute(input, executionOptions) {
|
||||
return executeValidatedReadOnlySql(project, options, input, executionOptions?.onProgress);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -276,6 +276,43 @@ grain: []
|
|||
});
|
||||
});
|
||||
|
||||
it('emits progress while compiling and executing a local semantic-layer query', async () => {
|
||||
const progress: Array<{ progress: number; message: string }> = [];
|
||||
const queryExecutor = {
|
||||
execute: vi.fn(async () => ({
|
||||
headers: ['status', 'order_count'],
|
||||
rows: [['paid', 2]],
|
||||
totalRows: 1,
|
||||
command: 'SELECT',
|
||||
rowCount: 1,
|
||||
})),
|
||||
};
|
||||
|
||||
const result = await compileLocalSlQuery(project, {
|
||||
connectionId: 'warehouse',
|
||||
query: {
|
||||
measures: ['orders.order_count'],
|
||||
dimensions: ['orders.status'],
|
||||
limit: 25,
|
||||
},
|
||||
compute,
|
||||
execute: true,
|
||||
maxRows: 10,
|
||||
queryExecutor,
|
||||
onProgress: (event) => {
|
||||
progress.push({ progress: event.progress, message: event.message });
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.totalRows).toBe(1);
|
||||
expect(progress).toEqual([
|
||||
{ progress: 0, message: 'Compiling query' },
|
||||
{ progress: 0.3, message: 'Generating SQL' },
|
||||
{ progress: 0.6, message: 'Executing' },
|
||||
{ progress: 1, message: 'Fetched 1 rows' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('requires a query executor for executed mode', async () => {
|
||||
await expect(
|
||||
compileLocalSlQuery(project, {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import type { KtxSqlQueryExecutorPort } from '../connections/index.js';
|
||||
import type { KtxSemanticLayerComputePort } from '../daemon/index.js';
|
||||
import type { KtxMcpProgressCallback } from '../mcp/types.js';
|
||||
import type { KtxLocalProject } from '../project/index.js';
|
||||
import { loadLocalSlSourceRecords } from './local-sl.js';
|
||||
import { toResolvedWire } from './semantic-layer.service.js';
|
||||
|
|
@ -15,6 +16,7 @@ export interface CompileLocalSlQueryOptions {
|
|||
execute?: boolean;
|
||||
maxRows?: number;
|
||||
queryExecutor?: KtxSqlQueryExecutorPort;
|
||||
onProgress?: KtxMcpProgressCallback;
|
||||
}
|
||||
|
||||
export interface CompileLocalSlQueryResult extends SemanticLayerQueryExecutionResult {
|
||||
|
|
@ -92,15 +94,20 @@ export async function compileLocalSlQuery(
|
|||
project: KtxLocalProject,
|
||||
options: CompileLocalSlQueryOptions,
|
||||
): Promise<CompileLocalSlQueryResult> {
|
||||
await options.onProgress?.({ progress: 0, message: 'Compiling query' });
|
||||
const connectionId = resolveLocalConnectionId(project, options.connectionId);
|
||||
const dialect = dialectForDriver(project.config.connections[connectionId]?.driver);
|
||||
const sources = await loadComputableSources(project, connectionId);
|
||||
|
||||
await options.onProgress?.({ progress: 0.3, message: 'Generating SQL' });
|
||||
const response = await options.compute.query({
|
||||
sources: await loadComputableSources(project, connectionId),
|
||||
sources,
|
||||
dialect,
|
||||
query: options.query,
|
||||
});
|
||||
|
||||
if (!options.execute) {
|
||||
await options.onProgress?.({ progress: 1, message: 'Fetched 0 rows' });
|
||||
return {
|
||||
connectionId,
|
||||
dialect: response.dialect,
|
||||
|
|
@ -123,6 +130,7 @@ export async function compileLocalSlQuery(
|
|||
}
|
||||
|
||||
const maxRows = options.maxRows ?? options.query.limit;
|
||||
await options.onProgress?.({ progress: 0.6, message: 'Executing' });
|
||||
const execution = await options.queryExecutor.execute({
|
||||
connectionId,
|
||||
projectDir: project.projectDir,
|
||||
|
|
@ -130,6 +138,7 @@ export async function compileLocalSlQuery(
|
|||
sql: response.sql,
|
||||
maxRows,
|
||||
});
|
||||
await options.onProgress?.({ progress: 1, message: `Fetched ${execution.totalRows} rows` });
|
||||
|
||||
return {
|
||||
connectionId,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue