diff --git a/packages/context/src/ingest/tools/warehouse-verification/sql-execution.tool.test.ts b/packages/context/src/ingest/tools/warehouse-verification/sql-execution.tool.test.ts
new file mode 100644
index 00000000..1cc63cac
--- /dev/null
+++ b/packages/context/src/ingest/tools/warehouse-verification/sql-execution.tool.test.ts
@@ -0,0 +1,65 @@
+import { describe, expect, it, vi } from 'vitest';
+import type { SlConnectionCatalogPort } from '../../../sl/index.js';
+import type { ToolContext } from '../../../tools/index.js';
+import { SqlExecutionTool } from './sql-execution.tool.js';
+
+describe('SqlExecutionTool', () => {
+  const connections = {
+    executeQuery: vi.fn(),
+  } as unknown as SlConnectionCatalogPort & { executeQuery: ReturnType<typeof vi.fn> };
+  const tool = new SqlExecutionTool(connections);
+  const context: ToolContext = {
+    sourceId: 'ingest',
+    messageId: 'm1',
+    userId: 'system',
+    session: { allowedConnectionNames: new Set(['warehouse']) } as any,
+  };
+
+  it('wraps read-only SQL with a capped row limit', async () => {
+    connections.executeQuery.mockResolvedValue({ headers: ['status'], rows: [['paid']], totalRows: 1 });
+
+    const result = await tool.call(
+      { connectionName: 'warehouse', sql: 'select status from public.orders', rowLimit: 5 },
+      context,
+    );
+
+    expect(connections.executeQuery).toHaveBeenCalledWith(
+      'warehouse',
+      'select * from (select status from public.orders) as ktx_query_result limit 5',
+    );
+    expect(result.markdown).toContain('| status |');
+    expect(result.structured.wrappedSql).toContain('limit 5');
+  });
+
+  it.each(['insert into x values (1)', 'drop table x', 'vacuum'])('rejects mutating SQL: %s', async (sql) => {
+    connections.executeQuery.mockClear();
+
+    const result = await tool.call({ connectionName: 'warehouse', sql }, context);
+
+    expect(result.markdown).toContain('Only read-only SELECT/WITH queries can be executed locally.');
+    expect(connections.executeQuery).not.toHaveBeenCalled();
+  });
+
+  it('surfaces connector errors verbatim', async () => {
+    connections.executeQuery.mockRejectedValue(new Error('relation "orbit_analytics.customer" does not exist'));
+
+    const result = await tool.call(
+      { connectionName: 'warehouse', sql: 'select 1 from orbit_analytics.customer', rowLimit: 1 },
+      context,
+    );
+
+    expect(result.markdown).toContain('relation "orbit_analytics.customer" does not exist');
+    expect(result.structured.error).toContain('relation "orbit_analytics.customer" does not exist');
+  });
+
+  it('escapes pipes and line breaks so a cell cannot break the markdown table', async () => {
+    connections.executeQuery.mockResolvedValue({ headers: ['note'], rows: [['a|b\nc']], totalRows: 1 });
+
+    const result = await tool.call(
+      { connectionName: 'warehouse', sql: 'select note from public.orders', rowLimit: 1 },
+      context,
+    );
+
+    expect(result.markdown).toContain('| a\\|b c |');
+  });
+});
diff --git a/packages/context/src/ingest/tools/warehouse-verification/sql-execution.tool.ts b/packages/context/src/ingest/tools/warehouse-verification/sql-execution.tool.ts
new file mode 100644
index 00000000..459e480f
--- /dev/null
+++ b/packages/context/src/ingest/tools/warehouse-verification/sql-execution.tool.ts
@@ -0,0 +1,156 @@
+import { z } from 'zod';
+import { assertReadOnlySql, limitSqlForExecution } from '../../../connections/index.js';
+import type { SlConnectionCatalogPort } from '../../../sl/index.js';
+import { BaseTool, type ToolContext, type ToolOutput } from '../../../tools/index.js';
+
+/** Upper bound on the number of result rows rendered into the markdown preview. */
+const MAX_MARKDOWN_ROWS = 20;
+
+const sqlExecutionInputSchema = z.object({
+  connectionName: z.string().regex(/^[a-zA-Z0-9][a-zA-Z0-9_-]*$/),
+  sql: z.string().min(1),
+  rowLimit: z.number().int().positive().max(1000).optional().default(100),
+});
+
+type SqlExecutionInput = z.infer<typeof sqlExecutionInputSchema>;
+
+/** Machine-readable result returned next to the markdown rendering. */
+export interface SqlExecutionStructured {
+  headers: string[];
+  rows: unknown[][];
+  rowCount: number;
+  truncated: boolean;
+  sql: string;
+  wrappedSql: string;
+  error?: string;
+}
+
+/** Builds the structured payload shared by every outcome that produced no rows. */
+function failureStructured(sql: string, wrappedSql: string, error: string): SqlExecutionStructured {
+  return { headers: [], rows: [], rowCount: 0, truncated: false, sql, wrappedSql, error };
+}
+
+/**
+ * Serializes a value to JSON without throwing: bigint values become decimal strings and
+ * anything JSON.stringify still rejects (e.g. circular references) falls back to its type tag.
+ */
+function safeJson(value: unknown): string {
+  try {
+    return JSON.stringify(value, (_key, nested: unknown) =>
+      typeof nested === 'bigint' ? nested.toString() : nested,
+    );
+  } catch {
+    return Object.prototype.toString.call(value);
+  }
+}
+
+/**
+ * Renders one value as the text of a single markdown table cell.
+ *
+ * Arrays and plain objects are shown as JSON instead of "[object Object]"; pipes are escaped and
+ * line breaks collapsed to a space so a value can never add columns or rows to the table.
+ */
+function markdownCell(value: unknown): string {
+  let text: string;
+  if (Array.isArray(value)) {
+    text = safeJson(value);
+  } else if (typeof value === 'object' && value !== null) {
+    const proto: unknown = Object.getPrototypeOf(value);
+    // Only plain records are JSON-encoded; Date and other class instances keep their own string form.
+    text = proto === Object.prototype || proto === null ? safeJson(value) : String(value);
+  } else {
+    text = String(value ?? '');
+  }
+  return text.replace(/\|/g, '\\|').replace(/\r\n|\r|\n/g, ' ');
+}
+
+/**
+ * Renders a query result as a markdown table showing at most MAX_MARKDOWN_ROWS rows.
+ *
+ * @param headers Column names; when empty the visible rows are returned as JSON instead of a table.
+ * @param rows Result rows, one array of cell values per row.
+ * @param totalRows Total row count reported for the query; drives the "more rows" footer.
+ * @returns The markdown text.
+ */
+function markdownTable(headers: string[], rows: unknown[][], totalRows: number): string {
+  const visible = rows.slice(0, MAX_MARKDOWN_ROWS);
+  if (headers.length === 0) {
+    return rows.length === 0 ? 'Query returned no rows.' : safeJson(visible);
+  }
+  const lines = [
+    `| ${headers.map((header) => markdownCell(header)).join(' | ')} |`,
+    `| ${headers.map(() => '---').join(' | ')} |`,
+    ...visible.map((row) => `| ${row.map((value) => markdownCell(value)).join(' | ')} |`),
+  ];
+  if (totalRows > visible.length) {
+    lines.push(`... +${totalRows - visible.length} more rows`);
+  }
+  return lines.join('\n');
+}
+
+// NOTE(review): the type arguments of BaseTool were lost from the captured patch and are
+// reconstructed here from how the class uses its schema and output type -- confirm against BaseTool.
+/**
+ * Ingest tool that runs one read-only SQL probe against a warehouse connection and reports the
+ * rows, or the failure, as markdown plus a structured payload. Failures are returned, not thrown.
+ */
+export class SqlExecutionTool extends BaseTool<typeof sqlExecutionInputSchema, SqlExecutionStructured> {
+  readonly name = 'sql_execution';
+
+  constructor(private readonly connections: SlConnectionCatalogPort) {
+    super();
+  }
+
+  get description(): string {
+    return 'Run a single read-only SELECT or WITH probe against an allowed warehouse connection and return a capped markdown table or the warehouse error.';
+  }
+
+  get inputSchema() {
+    return sqlExecutionInputSchema;
+  }
+
+  /**
+   * Validates, wraps and executes the query.
+   *
+   * @param input Connection name, SQL text and row limit.
+   * @param context Tool context; `session.allowedConnectionNames`, when present, restricts the connections.
+   * @returns Markdown plus structured output; `structured.error` is set on every failure path.
+   */
+  async call(input: SqlExecutionInput, context: ToolContext): Promise<ToolOutput<SqlExecutionStructured>> {
+    const allowed = context.session?.allowedConnectionNames;
+    // The allow-list is only enforced when the session provides one.
+    if (allowed && !allowed.has(input.connectionName)) {
+      return {
+        markdown: `Connection "${input.connectionName}" is not available to this ingest stage.`,
+        structured: failureStructured(input.sql, '', 'connection_not_allowed'),
+      };
+    }
+
+    let sql: string;
+    let wrappedSql: string;
+    try {
+      sql = assertReadOnlySql(input.sql);
+      wrappedSql = limitSqlForExecution(sql, input.rowLimit);
+    } catch (error) {
+      const message = error instanceof Error ? error.message : String(error);
+      return { markdown: message, structured: failureStructured(input.sql, '', message) };
+    }
+
+    try {
+      const result = await this.connections.executeQuery(input.connectionName, wrappedSql);
+      const headers = result.headers ?? [];
+      const rows = result.rows ?? [];
+      const rowCount = result.totalRows ?? rows.length;
+      return {
+        markdown: markdownTable(headers, rows, rowCount),
+        structured: { headers, rows, rowCount, truncated: rowCount > rows.length, sql, wrappedSql },
+      };
+    } catch (error) {
+      const message = error instanceof Error ? error.message : String(error);
+      return {
+        markdown: `SQL execution failed: ${message}`,
+        structured: failureStructured(sql, wrappedSql, message),
+      };
+    }
+  }
+}