/**
 * Graph RAG retrieval pipeline.
 *
 * Python reference: trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
 *
 * Pipeline stages, in order: concept extraction -> concept embeddings ->
 * entity lookup -> graph traversal -> edge scoring -> answer synthesis.
 *
 * NOTE(review): several type references in this copy carry no type arguments
 * (e.g. `Promise`, `Effect.Effect`, `EffectRequestResponse`, `new Set()`), and
 * `requestClient` mentions `TReq` without declaring it. The type text is left
 * exactly as found — confirm the intended type arguments before compiling.
 */
import type {
  EmbeddingsRequest,
  EmbeddingsResponse,
  EffectRequestOptions,
  EffectRequestResponse,
  GraphEmbeddingsRequest,
  GraphEmbeddingsResponse,
  PromptRequest,
  PromptResponse,
  Term,
  TextCompletionRequest,
  TextCompletionResponse,
  Triple,
  TriplesQueryRequest,
  TriplesQueryResponse,
} from "@trustgraph/base";
import { errorMessage } from "@trustgraph/base";
import { Context, Effect, Layer } from "effect";
import * as O from "effect/Option";
import * as S from "effect/Schema";

/**
 * Caller-facing tuning knobs. Every field is optional; missing fields are
 * filled in by {@link normalizeGraphRagConfig}.
 */
export interface GraphRagConfig {
  /** `limit` sent with the graph-embeddings entity lookup (default 50). */
  entityLimit?: number;
  /** `limit` sent with each per-entity triples query (default 30). */
  tripleLimit?: number;
  /** Traversal stops once the subgraph holds this many triples (default 1000). */
  maxSubgraphSize?: number;
  /** Number of traversal levels followed from the seed entities (default 2). */
  maxPathLength?: number;
  /** Number of edges (from the front of the subgraph) sent to the LLM for scoring (default 50). */
  edgeScoreLimit?: number;
  /** Number of top-scored edges kept after scoring (default 25). */
  edgeLimit?: number;
}

/** Request/response clients the pipeline talks to. */
export interface GraphRagClients {
  /** Text completion; used for concept extraction, edge scoring and synthesis. */
  llm: EffectRequestResponse;
  /** Turns the extracted concepts into vectors. */
  embeddings: EffectRequestResponse;
  /** Maps vectors to graph entities. */
  graphEmbeddings: EffectRequestResponse;
  /** Triple store queries used during traversal. */
  triples: EffectRequestResponse;
  /** Resolves named prompt templates into `system` / `prompt` text. */
  prompt: EffectRequestResponse;
}

/** Receives each non-empty streamed chunk of the synthesized answer. */
export type ChunkCallback = (text: string, endOfStream: boolean) => Promise;

/** Per-query options. */
export interface GraphRagQueryOptions {
  /**
   * Collection to query. The entity lookup falls back to `"default"` when it
   * is undefined; triples queries omit the field entirely in that case.
   */
  readonly collection?: string;
  /** NOTE(review): not read anywhere in this file; streaming is keyed off `chunkCallback`. */
  readonly streaming?: boolean;
  /** When provided, synthesis runs in streaming mode and chunks are forwarded here. */
  readonly chunkCallback?: ChunkCallback;
}

/** Outcome of a Graph RAG query. */
export interface GraphRagResult {
  /** The synthesized answer text. */
  answer: string;
  /** The edges that were handed to synthesis (i.e. the output of edge scoring). */
  subgraph: Triple[];
}

/** {@link GraphRagConfig} with every default applied. */
interface NormalizedGraphRagConfig {
  entityLimit: number;
  tripleLimit: number;
  maxSubgraphSize: number;
  maxPathLength: number;
  edgeScoreLimit: number;
  edgeLimit: number;
}

/** Tagged error raised by every failing pipeline step. */
export class GraphRagEngineError extends S.TaggedErrorClass()(
  "GraphRagEngineError",
  {
    message: S.String,
    // Name of the pipeline step that failed (e.g. "get-vectors").
    operation: S.String,
    cause: S.DefectWithStack,
  },
) {}

/** Service shape: a single `query` entry point that takes its clients explicitly. */
export interface GraphRagEngineShape {
  readonly query: (
    clients: GraphRagClients,
    queryText: string,
    options?: GraphRagQueryOptions,
    config?: GraphRagConfig,
  ) => Effect.Effect;
}

/** Context tag for the Graph RAG engine service. */
export class GraphRagEngine extends Context.Service()(
  "@trustgraph/flow/retrieval/graph-rag/GraphRagEngine",
) {}

/** Wraps an arbitrary failure into a {@link GraphRagEngineError} labelled with `operation`. */
const graphRagError = (operation: string, cause: unknown) =>
  GraphRagEngineError.make({
    operation,
    cause,
    message: errorMessage(cause),
  });

/**
 * Sends `request` through `requestor` and re-labels any failure as a
 * {@link GraphRagEngineError} carrying `operation`.
 */
const requestClient = (
  requestor: EffectRequestResponse,
  operation: string,
  request: TReq,
  options?: EffectRequestOptions,
): Effect.Effect =>
  requestor.request(request, options).pipe(
    Effect.mapError((cause) => graphRagError(operation, cause)),
  );

/**
 * Fills in defaults for every unset config field.
 *
 * Uses `??`, so an explicit `0` is kept rather than replaced by the default.
 */
export function normalizeGraphRagConfig(config: GraphRagConfig = {}): NormalizedGraphRagConfig {
  return {
    entityLimit: config.entityLimit ?? 50,
    tripleLimit: config.tripleLimit ?? 30,
    maxSubgraphSize: config.maxSubgraphSize ?? 1000,
    maxPathLength: config.maxPathLength ?? 2,
    edgeScoreLimit: config.edgeScoreLimit ?? 50,
    edgeLimit: config.edgeLimit ?? 25,
  };
}

/** Builds the engine; `query` is wrapped in `Effect.fn` under the name "GraphRagEngine.query". */
export function makeGraphRagEngine(): GraphRagEngineShape {
  return {
    query: Effect.fn("GraphRagEngine.query")((
      clients: GraphRagClients,
      queryText: string,
      options?: GraphRagQueryOptions,
      config?: GraphRagConfig,
    ) =>
      queryGraphRag(clients, queryText, options, config),
    ),
  };
}

/** Layer providing {@link GraphRagEngine} backed by {@link makeGraphRagEngine}. */
export const GraphRagLive: Layer.Layer = Layer.succeed(
  GraphRagEngine,
  GraphRagEngine.of(makeGraphRagEngine()),
);

/** Promise-based facade with clients and config bound up front. */
export interface GraphRag {
  readonly query: (
    queryText: string,
    options?: GraphRagQueryOptions,
  ) => Promise;
}

/**
 * Creates a Promise-based {@link GraphRag} bound to `clients` and `config`.
 *
 * Each call runs the engine effect with `Effect.runPromise`.
 */
export function makeGraphRag(
  clients: GraphRagClients,
  config: GraphRagConfig = {},
): GraphRag {
  const engine = makeGraphRagEngine();
  return {
    query: (queryText, options) =>
      Effect.runPromise(engine.query(clients, queryText, options, config)),
  };
}

/**
 * Runs the full pipeline for one query and logs progress after each step.
 *
 * @returns the synthesized answer together with the edges given to synthesis.
 */
function queryGraphRag(
  clients: GraphRagClients,
  queryText: string,
  options?: GraphRagQueryOptions,
  rawConfig?: GraphRagConfig,
): Effect.Effect {
  return Effect.gen(function* () {
    const config = normalizeGraphRagConfig(rawConfig);
    yield* Effect.log(`[GraphRag] Query: "${queryText.slice(0, 80)}..."`);

    const concepts = yield* extractConcepts(clients, queryText);
    yield* Effect.log(`[GraphRag] Step 1: extracted ${concepts.length} concepts: ${concepts.slice(0, 5).join(", ")}`);

    const vectors = yield* getVectors(clients, concepts);
    yield* Effect.log(`[GraphRag] Step 2: got ${vectors.length} vectors (dim=${vectors[0]?.length ?? 0})`);

    const entities = yield* getEntities(clients, config, vectors, options?.collection);
    yield* Effect.log(`[GraphRag] Step 3: found ${entities.length} matching entities`);

    const subgraph = yield* followEdges(clients, config, entities, options?.collection);
    yield* Effect.log(`[GraphRag] Step 4: traversed graph, ${subgraph.length} triples in subgraph`);

    const scoredEdges = yield* scoreEdges(clients, config, queryText, subgraph);
    yield* Effect.log(`[GraphRag] Step 5: scored down to ${scoredEdges.length} edges`);

    yield* Effect.log(`[GraphRag] Step 6: synthesizing answer from ${scoredEdges.length} edges...`);
    const answer = yield* synthesize(
      clients,
      queryText,
      scoredEdges,
      options?.chunkCallback,
    );
    yield* Effect.log(`[GraphRag] Step 6: done (${answer.length} chars)`);

    return { answer, subgraph: scoredEdges };
  });
}

/**
 * Asks the LLM (via the "extract-concepts" prompt) for the key concepts in
 * `query`, one per line.
 *
 * @returns the trimmed, non-empty lines of the reply, or `[query]` when the
 *          reply contains none.
 */
function extractConcepts(clients: GraphRagClients, query: string): Effect.Effect {
  return Effect.gen(function* () {
    const promptResp = yield* requestClient(
      clients.prompt,
      "extract-concepts-prompt",
      {
        name: "extract-concepts",
        variables: { query },
      },
    );
    const llmResp = yield* requestClient(
      clients.llm,
      "extract-concepts-llm",
      {
        system: promptResp.system,
        prompt: promptResp.prompt,
      },
    );

    const concepts = llmResp.response
      .split("\n")
      .map((concept) => concept.trim())
      .filter((concept) => concept.length > 0);

    // An empty concept list would make every later step operate on nothing
    // (no vectors -> no entities -> empty context), so embed the raw query
    // instead.
    return concepts.length > 0 ? concepts : [query];
  });
}

/** Requests embedding vectors for `concepts` and returns the response's `vectors`. */
function getVectors(clients: GraphRagClients, concepts: string[]): Effect.Effect {
  return Effect.gen(function* () {
    const resp = yield* requestClient(clients.embeddings, "get-vectors", { text: concepts });
    return resp.vectors;
  });
}

/**
 * Looks up graph entities matching `vectors`, capped at `config.entityLimit`.
 *
 * The request always uses user `"default"`; `collection` falls back to
 * `"default"` when undefined.
 */
function getEntities(
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  vectors: number[][],
  collection?: string,
): Effect.Effect {
  return Effect.gen(function* () {
    const resp = yield* requestClient(
      clients.graphEmbeddings,
      "get-entities",
      {
        vectors,
        user: "default",
        collection: collection ?? "default",
        limit: config.entityLimit,
      },
    );
    return resp.entities;
  });
}

/**
 * Level-by-level traversal starting from `entities`.
 *
 * At each level every not-yet-visited entity is queried as a triple subject
 * (`limit: config.tripleLimit`); the objects of the returned triples form the
 * next level. Traversal stops after `config.maxPathLength` levels, when a
 * level is empty, or once `config.maxSubgraphSize` triples are collected.
 *
 * @returns the collected triples, at most `config.maxSubgraphSize` of them.
 *          Triples are not de-duplicated.
 */
function followEdges(
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  entities: Term[],
  collection?: string,
): Effect.Effect {
  return Effect.gen(function* () {
    // Entities are tracked by their string form (see termToString).
    const visited = new Set();
    const subgraph: Triple[] = [];
    let currentLevel = new Set(
      entities.map((entity) => termToString(entity)),
    );

    for (let depth = 0; depth < config.maxPathLength; depth++) {
      if (currentLevel.size === 0 || subgraph.length >= config.maxSubgraphSize) {
        break;
      }

      const unvisited = [...currentLevel].filter((entity) => !visited.has(entity));
      if (unvisited.length === 0) break;

      const queries = unvisited.map((entityStr) => {
        const term = stringToTerm(entityStr);
        const request: TriplesQueryRequest = {
          s: term,
          limit: config.tripleLimit,
          // Only send `collection` when the caller supplied one.
          ...(collection !== undefined ? { collection } : {}),
        };
        return requestClient(clients.triples, "follow-edges-query", request);
      });
      const results = yield* Effect.all(queries);

      const nextLevel = new Set();
      for (const result of results) {
        for (const triple of result.triples) {
          subgraph.push(triple);

          // Objects are only queued when another level will actually run.
          if (depth < config.maxPathLength - 1) {
            const objStr = termToString(triple.o);
            if (!visited.has(objStr)) {
              nextLevel.add(objStr);
            }
          }

          if (subgraph.length >= config.maxSubgraphSize) {
            return subgraph;
          }
        }
      }

      // Members of this level are marked visited only after it is processed;
      // any that slipped into nextLevel are dropped by the `unvisited` filter
      // at the top of the next iteration.
      for (const entity of currentLevel) {
        visited.add(entity);
      }
      currentLevel = nextLevel;
    }

    return subgraph.slice(0, config.maxSubgraphSize);
  });
}

/**
 * Reduces `triples` to the edges most relevant to `query`.
 *
 * - No triples: returns `[]`.
 * - At most 500 triples: returns them unchanged (no LLM call).
 * - Otherwise: the first `config.edgeScoreLimit` edges are sent to the LLM
 *   through the "kg-edge-scoring" prompt, and the `config.edgeLimit`
 *   highest-scored entries are mapped back to triples. If none can be mapped,
 *   the first `config.edgeLimit` triples are returned.
 */
function scoreEdges(
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  query: string,
  triples: Triple[],
): Effect.Effect {
  return Effect.gen(function* () {
    if (triples.length === 0) return [];
    if (triples.length <= 500) {
      yield* Effect.log(`[GraphRag] Skipping edge scoring - ${triples.length} triples fits in context directly`);
      return triples;
    }

    // Each edge's id is its index into `triples`, as a string.
    const edgeDescriptions = triples.map((triple, index) => ({
      id: String(index),
      s: termToString(triple.s),
      p: termToString(triple.p),
      o: termToString(triple.o),
    }));
    const toScore = edgeDescriptions.slice(0, config.edgeScoreLimit);

    const knowledgeJson = yield* S.encodeUnknownEffect(S.UnknownFromJsonString)(toScore).pipe(
      Effect.mapError((cause) => graphRagError("edge-score-encode", cause)),
    );
    const promptResp = yield* requestClient(
      clients.prompt,
      "edge-score-prompt",
      {
        name: "kg-edge-scoring",
        variables: {
          query,
          knowledge: knowledgeJson,
        },
      },
    );
    const llmResp = yield* requestClient(
      clients.llm,
      "edge-score-llm",
      {
        system: promptResp.system,
        prompt: promptResp.prompt,
      },
    );
    yield* Effect.log(`[GraphRag] Edge scoring LLM response (first 500 chars): ${llmResp.response.slice(0, 500)}`);

    const scored = parseScoredEdges(llmResp.response);
    scored.sort((a, b) => b.score - a.score);
    const topN = scored.slice(0, config.edgeLimit);

    const result: Triple[] = [];
    const seen = new Set<number>();
    for (const entry of topN) {
      const idx = Number.parseInt(entry.id, 10);
      // Only ids of edges that were actually sent for scoring are valid; an
      // id outside that window refers to an edge the LLM never saw.
      if (Number.isNaN(idx) || idx < 0 || idx >= toScore.length) continue;
      // The LLM may repeat an id; keep each edge once.
      if (seen.has(idx)) continue;
      seen.add(idx);
      result.push(triples[idx]);
    }
    yield* Effect.log(`[GraphRag] Edge scoring: LLM returned ${scored.length} scores, keeping top ${topN.length}, mapped ${result.length} triples`);

    if (result.length === 0) {
      // Scoring produced nothing usable: fall back to the leading edges.
      return triples.slice(0, config.edgeLimit);
    }
    return result;
  });
}

/**
 * Produces the final answer from `edges` using the "graph-rag-synthesize"
 * prompt. Edges are rendered one per line as `s -> p -> o`.
 *
 * With a `chunkCallback` the LLM request is made with `streaming: true`,
 * every non-empty chunk is forwarded to the callback, and the concatenation
 * of those chunks is returned. Without one, a single request is made and its
 * `response` is returned.
 */
function synthesize(
  clients: GraphRagClients,
  query: string,
  edges: Triple[],
  chunkCallback?: ChunkCallback,
): Effect.Effect {
  return Effect.gen(function* () {
    const context = edges
      .map((triple) => `${termToString(triple.s)} -> ${termToString(triple.p)} -> ${termToString(triple.o)}`)
      .join("\n");

    const promptResp = yield* requestClient(
      clients.prompt,
      "synthesize-prompt",
      {
        name: "graph-rag-synthesize",
        variables: { query, context },
      },
    );

    if (chunkCallback !== undefined) {
      let fullText = "";
      yield* requestClient(
        clients.llm,
        "synthesize-stream",
        {
          system: promptResp.system,
          prompt: promptResp.prompt,
          streaming: true,
        },
        {
          // The recipient's boolean result is `resp.endOfStream === true`,
          // presumably telling the requestor when the stream is complete —
          // verify against EffectRequestOptions.
          recipient: (resp) => {
            if (resp.response.length === 0) {
              // NOTE(review): empty chunks are not forwarded, so if the final
              // message is empty the callback never sees endOfStream === true.
              // Confirm callers do not rely on that signal.
              return Effect.succeed(resp.endOfStream === true);
            }
            fullText += resp.response;
            return Effect.tryPromise({
              try: () =>
                chunkCallback(resp.response, resp.endOfStream === true).then(() => resp.endOfStream === true),
              catch: (cause) => graphRagError("synthesize-stream-callback", cause),
            });
          },
        },
      );
      return fullText;
    }

    const resp = yield* requestClient(
      clients.llm,
      "synthesize-llm",
      {
        system: promptResp.system,
        prompt: promptResp.prompt,
      },
    );
    return resp.response;
  });
}

/** One scored edge as returned by the edge-scoring LLM call. */
const ScoredEdge = S.Struct({
  id: S.String,
  score: S.Number,
});
const ScoredEdgesFromJson = S.Array(ScoredEdge).pipe(S.fromJsonString);
const ScoredEdgeFromJson = ScoredEdge.pipe(S.fromJsonString);
const decodeScoredEdges = S.decodeUnknownOption(ScoredEdgesFromJson);
const decodeScoredEdge = S.decodeUnknownOption(ScoredEdgeFromJson);

/**
 * Parses the edge-scoring reply.
 *
 * First tries to decode the whole text as a JSON array of scored edges. If
 * that fails, decodes each non-blank line as a single scored edge and silently
 * skips lines that do not decode.
 *
 * @returns a fresh array, so callers may sort it in place.
 */
function parseScoredEdges(responseText: string): Array {
  const parsedArray = decodeScoredEdges(responseText);
  if (O.isSome(parsedArray)) {
    return Array.from(parsedArray.value);
  }

  const scored: Array = [];
  for (const line of responseText.split("\n")) {
    const trimmed = line.trim();
    if (trimmed.length === 0) continue;
    const parsedLine = decodeScoredEdge(trimmed);
    if (O.isSome(parsedLine)) {
      scored.push(parsedLine.value);
    }
  }
  return scored;
}

/**
 * Renders a term as a string: IRIs and literals as their raw value, blank
 * nodes as `_:<id>`, and quoted triples as `(s p o)` (recursively).
 */
export function termToString(term: Term): string {
  switch (term.type) {
    case "IRI":
      return term.iri;
    case "LITERAL":
      return term.value;
    case "BLANK":
      return `_:${term.id}`;
    case "TRIPLE":
      return `(${termToString(term.triple.s)} ${termToString(term.triple.p)} ${termToString(term.triple.o)})`;
  }
}

/**
 * Rebuilds a term from its string form using prefixes only: `http://` or
 * `https://` gives an IRI, `_:` gives a blank node, and anything else a
 * literal.
 *
 * This is not a full inverse of {@link termToString}: IRIs with other schemes
 * and quoted triples come back as literals.
 */
export function stringToTerm(value: string): Term {
  if (value.startsWith("http://") || value.startsWith("https://")) {
    return { type: "IRI", iri: value };
  }
  if (value.startsWith("_:")) {
    return { type: "BLANK", id: value.slice(2) };
  }
  return { type: "LITERAL", value };
}