mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-10 08:05:14 +02:00
feat(context): add ingest SQL verification tool
This commit is contained in:
parent
52ac0e7664
commit
a4b0e37254
2 changed files with 156 additions and 0 deletions
|
|
@ -0,0 +1,54 @@
|
|||
import { describe, expect, it, vi } from 'vitest';
|
||||
import type { SlConnectionCatalogPort } from '../../../sl/index.js';
|
||||
import type { ToolContext } from '../../../tools/index.js';
|
||||
import { SqlExecutionTool } from './sql-execution.tool.js';
|
||||
|
||||
describe('SqlExecutionTool', () => {
|
||||
const connections = {
|
||||
executeQuery: vi.fn(),
|
||||
} as unknown as SlConnectionCatalogPort & { executeQuery: ReturnType<typeof vi.fn> };
|
||||
const tool = new SqlExecutionTool(connections);
|
||||
const context: ToolContext = {
|
||||
sourceId: 'ingest',
|
||||
messageId: 'm1',
|
||||
userId: 'system',
|
||||
session: { allowedConnectionNames: new Set(['warehouse']) } as any,
|
||||
};
|
||||
|
||||
it('wraps read-only SQL with a capped row limit', async () => {
|
||||
connections.executeQuery.mockResolvedValue({ headers: ['status'], rows: [['paid']], totalRows: 1 });
|
||||
|
||||
const result = await tool.call(
|
||||
{ connectionName: 'warehouse', sql: 'select status from public.orders', rowLimit: 5 },
|
||||
context,
|
||||
);
|
||||
|
||||
expect(connections.executeQuery).toHaveBeenCalledWith(
|
||||
'warehouse',
|
||||
'select * from (select status from public.orders) as ktx_query_result limit 5',
|
||||
);
|
||||
expect(result.markdown).toContain('| status |');
|
||||
expect(result.structured.wrappedSql).toContain('limit 5');
|
||||
});
|
||||
|
||||
it.each(['insert into x values (1)', 'drop table x', 'vacuum'])('rejects mutating SQL: %s', async (sql) => {
|
||||
connections.executeQuery.mockClear();
|
||||
|
||||
const result = await tool.call({ connectionName: 'warehouse', sql }, context);
|
||||
|
||||
expect(result.markdown).toContain('Only read-only SELECT/WITH queries can be executed locally.');
|
||||
expect(connections.executeQuery).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('surfaces connector errors verbatim', async () => {
|
||||
connections.executeQuery.mockRejectedValue(new Error('relation "orbit_analytics.customer" does not exist'));
|
||||
|
||||
const result = await tool.call(
|
||||
{ connectionName: 'warehouse', sql: 'select 1 from orbit_analytics.customer', rowLimit: 1 },
|
||||
context,
|
||||
);
|
||||
|
||||
expect(result.markdown).toContain('relation "orbit_analytics.customer" does not exist');
|
||||
expect(result.structured.error).toContain('relation "orbit_analytics.customer" does not exist');
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
import { z } from 'zod';
|
||||
import { assertReadOnlySql, limitSqlForExecution } from '../../../connections/index.js';
|
||||
import type { SlConnectionCatalogPort } from '../../../sl/index.js';
|
||||
import { BaseTool, type ToolContext, type ToolOutput } from '../../../tools/index.js';
|
||||
|
||||
const sqlExecutionInputSchema = z.object({
|
||||
connectionName: z.string().regex(/^[a-zA-Z0-9][a-zA-Z0-9_-]*$/),
|
||||
sql: z.string().min(1),
|
||||
rowLimit: z.number().int().positive().max(1000).optional().default(100),
|
||||
});
|
||||
|
||||
type SqlExecutionInput = z.infer<typeof sqlExecutionInputSchema>;
|
||||
|
||||
export interface SqlExecutionStructured {
|
||||
headers: string[];
|
||||
rows: unknown[][];
|
||||
rowCount: number;
|
||||
truncated: boolean;
|
||||
sql: string;
|
||||
wrappedSql: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
function markdownTable(headers: string[], rows: unknown[][], totalRows: number): string {
|
||||
if (headers.length === 0) {
|
||||
return rows.length === 0 ? 'Query returned no rows.' : JSON.stringify(rows.slice(0, 20));
|
||||
}
|
||||
const visible = rows.slice(0, 20);
|
||||
const lines = [
|
||||
`| ${headers.join(' | ')} |`,
|
||||
`| ${headers.map(() => '---').join(' | ')} |`,
|
||||
...visible.map((row) => `| ${row.map((value) => String(value ?? '')).join(' | ')} |`),
|
||||
];
|
||||
if (totalRows > visible.length) {
|
||||
lines.push(`... +${totalRows - visible.length} more rows`);
|
||||
}
|
||||
return lines.join('\n');
|
||||
}
|
||||
|
||||
export class SqlExecutionTool extends BaseTool<typeof sqlExecutionInputSchema> {
|
||||
readonly name = 'sql_execution';
|
||||
|
||||
constructor(private readonly connections: SlConnectionCatalogPort) {
|
||||
super();
|
||||
}
|
||||
|
||||
get description(): string {
|
||||
return 'Run a single read-only SELECT or WITH probe against an allowed warehouse connection and return a capped markdown table or the warehouse error.';
|
||||
}
|
||||
|
||||
get inputSchema() {
|
||||
return sqlExecutionInputSchema;
|
||||
}
|
||||
|
||||
async call(input: SqlExecutionInput, context: ToolContext): Promise<ToolOutput<SqlExecutionStructured>> {
|
||||
const allowed = context.session?.allowedConnectionNames;
|
||||
if (allowed && !allowed.has(input.connectionName)) {
|
||||
return {
|
||||
markdown: `Connection "${input.connectionName}" is not available to this ingest stage.`,
|
||||
structured: {
|
||||
headers: [],
|
||||
rows: [],
|
||||
rowCount: 0,
|
||||
truncated: false,
|
||||
sql: input.sql,
|
||||
wrappedSql: '',
|
||||
error: 'connection_not_allowed',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
let sql: string;
|
||||
let wrappedSql: string;
|
||||
try {
|
||||
sql = assertReadOnlySql(input.sql);
|
||||
wrappedSql = limitSqlForExecution(sql, input.rowLimit);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
markdown: message,
|
||||
structured: { headers: [], rows: [], rowCount: 0, truncated: false, sql: input.sql, wrappedSql: '', error: message },
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await this.connections.executeQuery(input.connectionName, wrappedSql);
|
||||
const headers = result.headers ?? [];
|
||||
const rows = result.rows ?? [];
|
||||
const rowCount = result.totalRows ?? rows.length;
|
||||
return {
|
||||
markdown: markdownTable(headers, rows, rowCount),
|
||||
structured: { headers, rows, rowCount, truncated: rowCount > rows.length, sql, wrappedSql },
|
||||
};
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
markdown: `SQL execution failed: ${message}`,
|
||||
structured: { headers: [], rows: [], rowCount: 0, truncated: false, sql, wrappedSql, error: message },
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue