/**
 * Graph RAG retrieval pipeline.
 *
 * This is the core RAG pipeline that:
 * 1. Extracts concepts from the query
 * 2. Embeds concepts to find matching entities
 * 3. Traverses the knowledge graph from those entities
 * 4. Scores and filters edges
 * 5. Synthesizes an answer with the selected context
 *
 * Python reference: trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
 */

import type {
  EmbeddingsRequest,
  EmbeddingsResponse,
  GraphEmbeddingsRequest,
  GraphEmbeddingsResponse,
  FlowRequestor,
  PromptRequest,
  PromptResponse,
  Term,
  TextCompletionRequest,
  TextCompletionResponse,
  Triple,
  TriplesQueryRequest,
  TriplesQueryResponse,
} from "@trustgraph/base";

/** Tunable limits for {@link GraphRag}. Every field is optional; defaults are applied in the constructor. */
export interface GraphRagConfig {
  /** Maximum entities requested from the graph-embeddings service (default 50). */
  entityLimit?: number;
  /** Maximum triples requested per entity on each traversal hop (default 30). */
  tripleLimit?: number;
  /** Hard cap on the number of triples collected during traversal (default 1000). */
  maxSubgraphSize?: number;
  /** Number of traversal hops starting from the matched entities (default 2). */
  maxPathLength?: number;
  /** Maximum number of edges sent to the LLM for relevance scoring (default 50). */
  edgeScoreLimit?: number;
  /** Maximum number of edges kept after LLM scoring (default 25). */
  edgeLimit?: number;
  /**
   * Subgraphs with at most this many triples skip LLM scoring entirely and are
   * passed to synthesis unchanged (default 500).
   */
  edgeScoringThreshold?: number;
}

/** Request/response clients for the flow services the pipeline depends on. */
export interface GraphRagClients {
  llm: FlowRequestor;
  embeddings: FlowRequestor;
  graphEmbeddings: FlowRequestor;
  triples: FlowRequestor;
  prompt: FlowRequestor;
}

/**
 * Receives streamed answer text. Called once per non-empty chunk, and always
 * exactly once with `endOfStream === true` (possibly with empty `text`).
 */
export type ChunkCallback = (text: string, endOfStream: boolean) => Promise<void>;

/** Result of a Graph RAG query. */
export interface GraphRagResult {
  /** The synthesized answer text. */
  answer: string;
  /** The edges that were given to the LLM as context for the answer. */
  subgraph: Triple[];
}

/** A parsed `{ id, score }` entry from the edge-scoring LLM response. */
interface EdgeScore {
  /** Index into the list of edges that was sent for scoring. */
  index: number;
  score: number;
}

export class GraphRag {
  private readonly clients: GraphRagClients;
  private readonly config: Required<GraphRagConfig>;

  constructor(clients: GraphRagClients, config: GraphRagConfig = {}) {
    this.clients = clients;
    this.config = {
      entityLimit: config.entityLimit ?? 50,
      tripleLimit: config.tripleLimit ?? 30,
      maxSubgraphSize: config.maxSubgraphSize ?? 1000,
      maxPathLength: config.maxPathLength ?? 2,
      edgeScoreLimit: config.edgeScoreLimit ?? 50,
      edgeLimit: config.edgeLimit ?? 25,
      edgeScoringThreshold: config.edgeScoringThreshold ?? 500,
    };
  }

  /**
   * Runs the full Graph RAG pipeline for one question.
   *
   * @param queryText - The user's question.
   * @param options.collection - Collection to search; forwarded to graph-embeddings
   *   (defaulting to "default") and, when given, to triple queries.
   * @param options.user - User passed to the graph-embeddings service (default "default").
   * @param options.streaming - Currently unused; streaming is enabled whenever
   *   `chunkCallback` is provided.
   * @param options.chunkCallback - When set, the answer is streamed through it.
   * @returns The answer and the edges used as context.
   */
  async query(
    queryText: string,
    options?: {
      collection?: string;
      user?: string;
      streaming?: boolean;
      chunkCallback?: ChunkCallback;
    },
  ): Promise<GraphRagResult> {
    console.log(`[GraphRag] Query: "${queryText.slice(0, 80)}..."`);

    // Step 1: Extract concepts from the query via prompt + LLM
    const concepts = await this.extractConcepts(queryText);
    console.log(`[GraphRag] Step 1: extracted ${concepts.length} concepts: ${concepts.slice(0, 5).join(", ")}`);

    // Step 2: Embed concepts concurrently
    const vectors = await this.getVectors(concepts);
    console.log(`[GraphRag] Step 2: got ${vectors.length} vectors (dim=${vectors[0]?.length ?? 0})`);

    // Step 3: Find matching entities via graph embeddings
    const entities = await this.getEntities(vectors, options?.collection, options?.user);
    console.log(`[GraphRag] Step 3: found ${entities.length} matching entities`);

    // Step 4: Traverse the knowledge graph from entities
    const subgraph = await this.followEdges(entities, options?.collection);
    console.log(`[GraphRag] Step 4: traversed graph, ${subgraph.length} triples in subgraph`);

    // Step 5: Score and filter edges via LLM
    const scoredEdges = await this.scoreEdges(queryText, subgraph);
    console.log(`[GraphRag] Step 5: scored down to ${scoredEdges.length} edges`);

    // Step 6: Synthesize answer
    console.log(`[GraphRag] Step 6: synthesizing answer from ${scoredEdges.length} edges...`);
    const answer = await this.synthesize(queryText, scoredEdges, options?.chunkCallback);
    console.log(`[GraphRag] Step 6: done (${answer.length} chars)`);

    return { answer, subgraph: scoredEdges };
  }

  /** Renders a named prompt template through the prompt service. */
  private async renderPrompt(name: string, variables: Record<string, string>): Promise<PromptResponse> {
    const resp = await this.clients.prompt.request({ name, variables });
    return resp as PromptResponse;
  }

  /** Renders a prompt template and runs it through the LLM (non-streaming), returning the completion text. */
  private async complete(name: string, variables: Record<string, string>): Promise<string> {
    const { system, prompt } = await this.renderPrompt(name, variables);
    const resp = await this.clients.llm.request({ system, prompt });
    return (resp as TextCompletionResponse).response;
  }

  /**
   * Asks the LLM for concepts in the query, one per line. Blank lines and
   * duplicates are dropped. If nothing usable comes back, the trimmed query
   * itself is used so that entity lookup still has something to embed.
   */
  private async extractConcepts(query: string): Promise<string[]> {
    const text = await this.complete("extract-concepts", { query });
    const concepts = [
      ...new Set(
        text
          .split("\n")
          .map((c) => c.trim())
          .filter((c) => c.length > 0),
      ),
    ];
    if (concepts.length === 0 && query.trim().length > 0) {
      return [query.trim()];
    }
    return concepts;
  }

  /** Embeds all concepts in one request; returns one vector per concept. */
  private async getVectors(concepts: string[]): Promise<number[][]> {
    if (concepts.length === 0) return [];
    const resp = await this.clients.embeddings.request({ text: concepts });
    return (resp as EmbeddingsResponse).vectors;
  }

  /** Looks up graph entities whose embeddings match the given vectors. */
  private async getEntities(vectors: number[][], collection?: string, user?: string): Promise<Term[]> {
    if (vectors.length === 0) return [];
    const resp = await this.clients.graphEmbeddings.request({
      vectors,
      user: user ?? "default",
      collection: collection ?? "default",
      limit: this.config.entityLimit,
    });
    return (resp as GraphEmbeddingsResponse).entities;
  }

  /**
   * Breadth-first traversal of outgoing edges, up to `maxPathLength` hops.
   *
   * Each frontier entity is queried as a subject. The objects of the returned
   * triples form the next frontier. The original `Term` objects are carried
   * forward, not re-parsed from strings, so non-http IRIs (e.g. `urn:`) and
   * quoted triples keep their type. Stops as soon as `maxSubgraphSize` triples
   * have been collected.
   */
  private async followEdges(entities: Term[], collection?: string): Promise<Triple[]> {
    const { maxPathLength, maxSubgraphSize, tripleLimit } = this.config;
    const visited = new Set<string>();
    const subgraph: Triple[] = [];

    let frontier = new Map<string, Term>();
    for (const entity of entities) {
      const key = termKey(entity);
      if (!frontier.has(key)) frontier.set(key, entity);
    }

    for (let depth = 0; depth < maxPathLength; depth++) {
      if (subgraph.length >= maxSubgraphSize) break;

      const pending = [...frontier].filter(([key]) => !visited.has(key));
      if (pending.length === 0) break;
      for (const [key] of pending) visited.add(key);

      // Query every pending entity at this depth concurrently.
      const results = await Promise.all(
        pending.map(([, term]) => {
          const request: TriplesQueryRequest = {
            s: term,
            limit: tripleLimit,
            ...(collection !== undefined ? { collection } : {}),
          };
          return this.clients.triples.request(request);
        }),
      );

      // Objects are only worth collecting if another hop will follow.
      const expand = depth < maxPathLength - 1;
      const next = new Map<string, Term>();
      for (const result of results) {
        for (const triple of (result as TriplesQueryResponse).triples) {
          subgraph.push(triple);
          if (subgraph.length >= maxSubgraphSize) return subgraph;
          if (expand) {
            const key = termKey(triple.o);
            if (!visited.has(key) && !next.has(key)) next.set(key, triple.o);
          }
        }
      }
      frontier = next;
    }

    return subgraph;
  }

  /**
   * Reduces the subgraph to the edges the LLM considers most relevant.
   *
   * If the subgraph has at most `edgeScoringThreshold` triples, it is returned
   * unchanged. Otherwise the first `edgeScoreLimit` edges are sent for scoring,
   * and up to `edgeLimit` distinct, in-range edges are kept in descending score
   * order. If no usable scores come back, the first `edgeLimit` triples are
   * returned instead.
   */
  private async scoreEdges(query: string, triples: Triple[]): Promise<Triple[]> {
    if (triples.length === 0) return [];

    if (triples.length <= this.config.edgeScoringThreshold) {
      console.log(`[GraphRag] Skipping edge scoring — ${triples.length} triples fits in context directly`);
      return triples;
    }

    // Only this prefix is shown to the LLM, so only its indices are valid ids.
    const candidates = triples.slice(0, this.config.edgeScoreLimit);
    const knowledge = JSON.stringify(
      candidates.map((t, i) => ({
        id: String(i),
        s: termToString(t.s),
        p: termToString(t.p),
        o: termToString(t.o),
      })),
      null,
      2,
    );

    const responseText = await this.complete("kg-edge-scoring", { query, knowledge });
    console.log(`[GraphRag] Edge scoring LLM response (first 500 chars): ${responseText.slice(0, 500)}`);

    // Array.prototype.sort is stable, so ties keep the LLM's order.
    const scores = parseEdgeScores(responseText).sort((a, b) => b.score - a.score);

    const picked = new Set<number>();
    const result: Triple[] = [];
    for (const { index } of scores) {
      if (result.length >= this.config.edgeLimit) break;
      const triple = candidates[index];
      // Skip hallucinated (out-of-range) ids and repeated ids.
      if (triple === undefined || picked.has(index)) continue;
      picked.add(index);
      result.push(triple);
    }

    console.log(`[GraphRag] Edge scoring: LLM returned ${scores.length} scores, kept ${result.length} distinct in-range edges`);

    // If scoring failed entirely, fall back to returning the first edgeLimit triples
    if (result.length === 0) {
      return triples.slice(0, this.config.edgeLimit);
    }
    return result;
  }

  /**
   * Produces the final answer from the selected edges.
   *
   * With a `chunkCallback`, the LLM is called in streaming mode and every
   * non-empty chunk is forwarded. The end-of-stream message is always
   * forwarded too, even when its text is empty, so the consumer learns that
   * the stream is finished. Returns the full concatenated answer in both modes.
   */
  private async synthesize(query: string, edges: Triple[], chunkCallback?: ChunkCallback): Promise<string> {
    const context = edges
      .map((t) => `${termToString(t.s)} -> ${termToString(t.p)} -> ${termToString(t.o)}`)
      .join("\n");
    const variables = { query, context };

    if (chunkCallback === undefined) {
      return this.complete("graph-rag-synthesize", variables);
    }

    const { system, prompt } = await this.renderPrompt("graph-rag-synthesize", variables);
    let fullText = "";
    await this.clients.llm.request(
      { system, prompt, streaming: true },
      {
        recipient: async (resp: unknown) => {
          const r = resp as TextCompletionResponse;
          const chunk = r.response ?? "";
          const done = r.endOfStream === true;
          if (chunk.length > 0 || done) {
            fullText += chunk;
            await chunkCallback(chunk, done);
          }
          return done;
        },
      },
    );
    return fullText;
  }
}

/**
 * Extracts `{ id, score }` entries from an LLM edge-scoring response.
 *
 * Accepted forms:
 * - a JSON array, optionally wrapped in ``` / ```json fences or surrounding prose;
 * - one JSON object per line (trailing commas tolerated).
 *
 * Ids may be integers or digit-only strings. Entries without a finite numeric
 * score are dropped.
 */
function parseEdgeScores(text: string): EdgeScore[] {
  const cleaned = text.replace(/```(?:json)?/gi, "").trim();

  const attempts = [cleaned];
  const start = cleaned.indexOf("[");
  const end = cleaned.lastIndexOf("]");
  if (start !== -1 && end > start) attempts.push(cleaned.slice(start, end + 1));

  for (const attempt of attempts) {
    const parsed = tryParseJson(attempt);
    if (Array.isArray(parsed)) {
      return parsed.map(toEdgeScore).filter((s): s is EdgeScore => s !== undefined);
    }
  }

  // Fall back to parsing line-by-line JSON objects
  const scores: EdgeScore[] = [];
  for (const line of cleaned.split("\n")) {
    const trimmed = line.trim().replace(/,$/, "");
    if (trimmed.length === 0) continue;
    const score = toEdgeScore(tryParseJson(trimmed));
    if (score !== undefined) scores.push(score);
  }
  return scores;
}

/** Parses JSON, returning `undefined` instead of throwing on malformed input. */
function tryParseJson(text: string): unknown {
  try {
    return JSON.parse(text) as unknown;
  } catch {
    return undefined;
  }
}

/** Validates one parsed scoring entry; returns `undefined` if it is unusable. */
function toEdgeScore(value: unknown): EdgeScore | undefined {
  if (typeof value !== "object" || value === null) return undefined;
  const { id, score } = value as { id?: unknown; score?: unknown };
  if (typeof score !== "number" || !Number.isFinite(score)) return undefined;

  let index: number;
  if (typeof id === "number") {
    index = id;
  } else if (typeof id === "string" && /^\s*\d+\s*$/.test(id)) {
    index = Number.parseInt(id, 10);
  } else {
    return undefined;
  }
  return Number.isInteger(index) ? { index, score } : undefined;
}

/** Identity key for a term: its type plus display string, so IRI `x` and literal `x` stay distinct. */
function termKey(term: Term): string {
  return `${term.type}|${termToString(term)}`;
}

/** Human-readable rendering of a term, used for prompts and logging. */
function termToString(term: Term): string {
  switch (term.type) {
    case "IRI":
      return term.iri;
    case "LITERAL":
      return term.value;
    case "BLANK":
      return `_:${term.id}`;
    case "TRIPLE":
      return `(${termToString(term.triple.s)} ${termToString(term.triple.p)} ${termToString(term.triple.o)})`;
  }
}

/**
 * Heuristic inverse of {@link termToString}.
 *
 * This is lossy: only http(s) strings become IRIs, `_:` strings become blank
 * nodes, and everything else becomes a literal. Prefer keeping the original
 * `Term` when one is available.
 */
function stringToTerm(value: string): Term {
  if (value.startsWith("http://") || value.startsWith("https://")) {
    return { type: "IRI", iri: value };
  }
  if (value.startsWith("_:")) {
    return { type: "BLANK", id: value.slice(2) };
  }
  return { type: "LITERAL", value };
}