diff --git a/packages/cli/src/knowledge.test.ts b/packages/cli/src/knowledge.test.ts index d3db8465..0e9ed1d5 100644 --- a/packages/cli/src/knowledge.test.ts +++ b/packages/cli/src/knowledge.test.ts @@ -2,6 +2,7 @@ import { mkdtemp, rm } from 'node:fs/promises'; import { tmpdir } from 'node:os'; import { join } from 'node:path'; import { initKtxProject } from '@ktx/context/project'; +import type { KtxEmbeddingPort } from '@ktx/context'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { runKtxKnowledge } from './knowledge.js'; @@ -26,6 +27,19 @@ function makeIo() { }; } +class FakeEmbeddingPort implements KtxEmbeddingPort { + readonly maxBatchSize = 16; + + async computeEmbedding(text: string): Promise { + const lower = text.toLowerCase(); + return lower.includes('revenue') || lower.includes('arr') ? [1, 0] : [0, 1]; + } + + async computeEmbeddingsBulk(texts: string[]): Promise { + return Promise.all(texts.map((text) => this.computeEmbedding(text))); + } +} + describe('runKtxKnowledge', () => { let tempDir: string; @@ -92,4 +106,39 @@ describe('runKtxKnowledge', () => { expect(searchIo.stderr()).toContain('No local wiki pages found'); expect(searchIo.stderr()).toContain('ktx wiki write'); }); + + it('uses configured embeddings for semantic wiki search', async () => { + const projectDir = join(tempDir, 'semantic-project'); + await initKtxProject({ projectDir, projectName: 'warehouse' }); + + await expect( + runKtxKnowledge( + { + command: 'write', + projectDir, + key: 'historic-sql/active-contract-arr-open-tickets', + scope: 'GLOBAL', + userId: 'local', + summary: 'Active Contract ARR Ranked by Open Support Ticket Count', + content: 'Accounts ranked by annual recurring contract value and support ticket load.', + tags: ['historic-sql'], + refs: [], + slRefs: [], + }, + makeIo().io, + ), + ).resolves.toBe(0); + + const searchIo = makeIo(); + await expect( + runKtxKnowledge( + { command: 'search', projectDir, query: 'revenue', userId: 'local' }, + searchIo.io, + { embeddingService: new FakeEmbeddingPort() }, + ), + ).resolves.toBe(0); + + expect(searchIo.stdout()).toContain('historic-sql/active-contract-arr-open-tickets'); + expect(searchIo.stderr()).toBe(''); + }); }); diff --git a/packages/cli/src/knowledge.ts b/packages/cli/src/knowledge.ts index 89afda8e..40cc5372 100644 --- a/packages/cli/src/knowledge.ts +++ b/packages/cli/src/knowledge.ts @@ -1,3 +1,8 @@ +import { + createLocalKtxEmbeddingProviderFromConfig, + KtxIngestEmbeddingPortAdapter, + type KtxEmbeddingPort, +} from '@ktx/context'; import { loadKtxProject } from '@ktx/context/project'; import { type LocalKnowledgeScope, @@ -29,7 +34,29 @@ interface KtxKnowledgeIo { stderr: { write(chunk: string): void }; } -export async function runKtxKnowledge(args: KtxKnowledgeArgs, io: KtxKnowledgeIo = process): Promise { +interface KtxKnowledgeDeps { + embeddingService?: KtxEmbeddingPort | null; + createEmbeddingProvider?: typeof createLocalKtxEmbeddingProviderFromConfig; +} + +function wikiSearchEmbeddingService( + project: Awaited>, + deps: KtxKnowledgeDeps, +): KtxEmbeddingPort | null { + if ('embeddingService' in deps) { + return deps.embeddingService ?? null; + } + const provider = (deps.createEmbeddingProvider ?? createLocalKtxEmbeddingProviderFromConfig)( + project.config.ingest.embeddings, + ); + return provider ? new KtxIngestEmbeddingPortAdapter(provider) : null; +} + +export async function runKtxKnowledge( + args: KtxKnowledgeArgs, + io: KtxKnowledgeIo = process, + deps: KtxKnowledgeDeps = {}, +): Promise { try { const project = await loadKtxProject({ projectDir: args.projectDir }); if (args.command === 'list') { @@ -51,7 +78,11 @@ export async function runKtxKnowledge(args: KtxKnowledgeArgs, io: KtxKnowledgeIo return 0; } if (args.command === 'search') { - const results = await searchLocalKnowledgePages(project, { query: args.query, userId: args.userId }); + const results = await searchLocalKnowledgePages(project, { + query: args.query, + userId: args.userId, + embeddingService: wikiSearchEmbeddingService(project, deps), + }); if (results.length === 0) { const pages = await listLocalKnowledgePages(project, { userId: args.userId }); if (pages.length === 0) {