trustgraph/ts/packages/flow/src/retrieval/graph-rag.ts
2026-06-11 08:34:59 -05:00

477 lines
14 KiB
TypeScript

/**
* Graph RAG retrieval pipeline.
*
* Python reference: trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
*/
import type {
EmbeddingsRequest,
EmbeddingsResponse,
EffectRequestOptions,
EffectRequestResponse,
GraphEmbeddingsRequest,
GraphEmbeddingsResponse,
PromptRequest,
PromptResponse,
Term,
TextCompletionRequest,
TextCompletionResponse,
TriplesQueryRequest,
TriplesQueryResponse,
} from "@trustgraph/base";
import { Triple, errorMessage } from "@trustgraph/base";
import { Array as A, Context, Effect, Layer, Match, Order } from "effect";
import * as O from "effect/Option";
import * as S from "effect/Schema";
/**
 * Caller-supplied tuning limits for a Graph RAG query.
 *
 * Every key is optional; `normalizeGraphRagConfig` substitutes a default for
 * any key that is left out.
 */
export class GraphRagConfig extends S.Class<GraphRagConfig>("GraphRagConfig")({
// Sent as `limit` on the graph-embeddings lookup (default 50).
entityLimit: S.optionalKey(S.Finite),
// Sent as `limit` on each per-entity triples query (default 30).
tripleLimit: S.optionalKey(S.Finite),
// Cap on the number of triples collected while walking the graph (default 1000).
maxSubgraphSize: S.optionalKey(S.Finite),
// Number of traversal levels followed from the seed entities (default 2).
maxPathLength: S.optionalKey(S.Finite),
// How many edges are handed to the LLM for scoring (default 50).
edgeScoreLimit: S.optionalKey(S.Finite),
// How many edges are kept after scoring (default 25).
edgeLimit: S.optionalKey(S.Finite),
}, { description: "Graph RAG retrieval tuning limits." }) {}
/**
 * Request/response clients used by the Graph RAG pipeline. Every call goes
 * through `requestClient`, which converts failures into `GraphRagEngineError`.
 */
export interface GraphRagClients {
/** Text completion; used for concept extraction, edge scoring and answer synthesis. */
llm: EffectRequestResponse<TextCompletionRequest, TextCompletionResponse>;
/** Turns the extracted concept strings into vectors. */
embeddings: EffectRequestResponse<EmbeddingsRequest, EmbeddingsResponse>;
/** Looks up graph entities for a set of vectors. */
graphEmbeddings: EffectRequestResponse<GraphEmbeddingsRequest, GraphEmbeddingsResponse>;
/** Queries triples by subject while the graph is being walked. */
triples: EffectRequestResponse<TriplesQueryRequest, TriplesQueryResponse>;
/**
 * Returns the `system`/`prompt` pair for a named template. The names requested
 * in this file are "extract-concepts", "kg-edge-scoring" and "graph-rag-synthesize".
 */
prompt: EffectRequestResponse<PromptRequest, PromptResponse>;
}
/**
 * Receiver for streamed answer text. `synthesize` invokes it with the text of
 * a streamed LLM response chunk and whether that chunk carried the
 * end-of-stream flag. A failure returned from the callback is a
 * `GraphRagEngineError`.
 */
export type ChunkCallback = (
text: string,
endOfStream: boolean,
) => Effect.Effect<void, GraphRagEngineError>;
/** Per-query options accepted by the Graph RAG `query` entry points. */
export interface GraphRagQueryOptions {
/**
 * Collection to search. The entity lookup falls back to "default" when this
 * is undefined; triples queries simply omit the field.
 */
readonly collection?: string;
/**
 * NOTE(review): this flag is not read by any code visible in this file;
 * streaming is switched on by the presence of `chunkCallback` instead.
 */
readonly streaming?: boolean;
/** When defined, answer synthesis requests a streaming completion and forwards chunks here. */
readonly chunkCallback?: ChunkCallback;
}
/** Outcome of a Graph RAG query. */
export class GraphRagResult extends S.Class<GraphRagResult>("GraphRagResult")({
// Text produced by the synthesis step.
answer: S.String,
// The edges that were handed to the synthesis step as context.
subgraph: S.Array(Triple),
}, { description: "Graph RAG answer with the supporting subgraph." }) {}
/**
 * `GraphRagConfig` with every limit present, as produced by
 * `normalizeGraphRagConfig`. The pipeline steps read their limits from this
 * shape so they never have to handle a missing value.
 */
interface NormalizedGraphRagConfig {
entityLimit: number;
tripleLimit: number;
maxSubgraphSize: number;
maxPathLength: number;
edgeScoreLimit: number;
edgeLimit: number;
}
/**
 * The single error type surfaced by the Graph RAG pipeline. Instances are
 * built by `graphRagError`.
 */
export class GraphRagEngineError extends S.TaggedErrorClass<GraphRagEngineError>()(
"GraphRagEngineError",
{
// Human-readable text; `graphRagError` derives it from the cause via `errorMessage`.
message: S.String,
// Label of the pipeline step that failed, e.g. "get-vectors".
operation: S.String,
// The original failure value, kept for diagnostics.
cause: S.Defect({ includeStack: true }),
},
) {}
/**
 * Service contract of the Graph RAG engine. The engine receives its clients
 * and configuration on every call rather than holding them itself.
 */
export interface GraphRagEngineShape {
/**
 * Answers `queryText` using the supplied clients.
 *
 * @param clients - Backends used by the pipeline steps.
 * @param queryText - The user question.
 * @param options - Optional collection / streaming settings.
 * @param config - Optional tuning limits; missing keys get defaults.
 * @returns The answer plus the edges used as context, or a `GraphRagEngineError`.
 */
readonly query: (
clients: GraphRagClients,
queryText: string,
options?: GraphRagQueryOptions,
config?: GraphRagConfig,
) => Effect.Effect<GraphRagResult, GraphRagEngineError>;
}
/**
 * Context service tag for the Graph RAG engine; the string below is its
 * service key. `GraphRagLive` provides an implementation for it.
 */
export class GraphRagEngine extends Context.Service<GraphRagEngine, GraphRagEngineShape>()(
"@trustgraph/flow/retrieval/graph-rag/GraphRagEngine",
) {}
/**
 * Builds a `GraphRagEngineError` for a failed pipeline step.
 *
 * @param operation - Label of the step that failed.
 * @param cause - The underlying failure; its text becomes the error message.
 */
const graphRagError = (operation: string, cause: unknown) => {
  const message = errorMessage(cause);
  return GraphRagEngineError.make({ operation, cause, message });
};
/**
 * Sends `request` through `requestor` and re-labels any failure as a
 * `GraphRagEngineError` tagged with `operation`.
 *
 * @param requestor - Client to call.
 * @param operation - Step label recorded on the error if the call fails.
 * @param request - Payload handed to the client.
 * @param options - Optional request options, forwarded untouched.
 */
const requestClient = <TReq, TRes>(
  requestor: EffectRequestResponse<TReq, TRes>,
  operation: string,
  request: TReq,
  options?: EffectRequestOptions<TRes, GraphRagEngineError>,
): Effect.Effect<TRes, GraphRagEngineError> => {
  const pending = requestor.request(request, options);
  return pending.pipe(
    Effect.mapError((failure) => graphRagError(operation, failure)),
  );
};
/**
 * Fills in the default for every limit missing from `config`.
 *
 * Defaults: entityLimit 50, tripleLimit 30, maxSubgraphSize 1000,
 * maxPathLength 2, edgeScoreLimit 50, edgeLimit 25. Supplied values,
 * including 0, are kept as given.
 *
 * @param config - Partial limits; defaults to an empty object.
 * @returns A config with all six limits present.
 */
export function normalizeGraphRagConfig(config: GraphRagConfig = {}): NormalizedGraphRagConfig {
  const {
    entityLimit,
    tripleLimit,
    maxSubgraphSize,
    maxPathLength,
    edgeScoreLimit,
    edgeLimit,
  } = config;
  const normalized: NormalizedGraphRagConfig = {
    entityLimit: entityLimit ?? 50,
    tripleLimit: tripleLimit ?? 30,
    maxSubgraphSize: maxSubgraphSize ?? 1000,
    maxPathLength: maxPathLength ?? 2,
    edgeScoreLimit: edgeScoreLimit ?? 50,
    edgeLimit: edgeLimit ?? 25,
  };
  return normalized;
}
/**
 * Creates the engine implementation. Its `query` is a traced wrapper
 * (span name "GraphRagEngine.query") that delegates to `queryGraphRag`.
 */
export function makeGraphRagEngine(): GraphRagEngineShape {
  const query = Effect.fn("GraphRagEngine.query")((
    clients: GraphRagClients,
    queryText: string,
    options?: GraphRagQueryOptions,
    config?: GraphRagConfig,
  ) => {
    return queryGraphRag(clients, queryText, options, config);
  });
  return { query };
}
/**
 * Layer that provides the `GraphRagEngine` service with the implementation
 * returned by `makeGraphRagEngine()`. It has no requirements of its own.
 */
export const GraphRagLive: Layer.Layer<GraphRagEngine> = Layer.succeed(
GraphRagEngine,
GraphRagEngine.of(makeGraphRagEngine()),
);
/**
 * Convenience facade over the engine with clients and config already bound
 * (see `makeGraphRag`), so callers only pass the question and per-query options.
 */
export interface GraphRag {
/**
 * @param queryText - The user question.
 * @param options - Optional collection / streaming settings.
 * @returns The answer plus the edges used as context, or a `GraphRagEngineError`.
 */
readonly query: (
queryText: string,
options?: GraphRagQueryOptions,
) => Effect.Effect<GraphRagResult, GraphRagEngineError>;
}
/**
 * Binds `clients` and `config` to a fresh engine and returns a facade whose
 * `query` only needs the question and per-query options.
 *
 * @param clients - Backends used for every query made through the facade.
 * @param config - Tuning limits applied to every query; defaults to `{}`.
 */
export function makeGraphRag(
  clients: GraphRagClients,
  config: GraphRagConfig = {},
): GraphRag {
  const engine = makeGraphRagEngine();
  const query: GraphRag["query"] = (queryText, options) => {
    return engine.query(clients, queryText, options, config);
  };
  return { query };
}
/**
 * Runs the whole Graph RAG pipeline for one question, logging progress after
 * each step:
 *
 *  1. extract concepts from the question,
 *  2. embed the concepts,
 *  3. look up matching graph entities,
 *  4. walk the graph from those entities,
 *  5. score / trim the collected edges,
 *  6. synthesize the answer from the remaining edges.
 *
 * Only `collection` and `chunkCallback` are read from `options` here.
 *
 * @returns The answer text and the edges that were used as synthesis context.
 */
const queryGraphRag = Effect.fn("GraphRagEngine.queryGraphRag")(function* (
  clients: GraphRagClients,
  queryText: string,
  options?: GraphRagQueryOptions,
  rawConfig?: GraphRagConfig,
) {
  const limits = normalizeGraphRagConfig(rawConfig);
  const queryPreview = queryText.slice(0, 80);
  yield* Effect.log(`[GraphRag] Query: "${queryPreview}..."`);

  // Step 1: concepts
  const concepts = yield* extractConcepts(clients, queryText);
  const conceptPreview = concepts.slice(0, 5).join(", ");
  yield* Effect.log(`[GraphRag] Step 1: extracted ${concepts.length} concepts: ${conceptPreview}`);

  // Step 2: vectors
  const vectors = yield* getVectors(clients, concepts);
  const dimension = vectors[0]?.length ?? 0;
  yield* Effect.log(`[GraphRag] Step 2: got ${vectors.length} vectors (dim=${dimension})`);

  // Step 3: seed entities
  const seeds = yield* getEntities(clients, limits, vectors, options?.collection);
  yield* Effect.log(`[GraphRag] Step 3: found ${seeds.length} matching entities`);

  // Step 4: graph walk
  const walked = yield* followEdges(clients, limits, seeds, options?.collection);
  yield* Effect.log(`[GraphRag] Step 4: traversed graph, ${walked.length} triples in subgraph`);

  // Step 5: scoring
  const keptEdges = yield* scoreEdges(clients, limits, queryText, walked);
  yield* Effect.log(`[GraphRag] Step 5: scored down to ${keptEdges.length} edges`);

  // Step 6: answer
  yield* Effect.log(`[GraphRag] Step 6: synthesizing answer from ${keptEdges.length} edges...`);
  const answer = yield* synthesize(clients, queryText, keptEdges, options?.chunkCallback);
  yield* Effect.log(`[GraphRag] Step 6: done (${answer.length} chars)`);

  return { answer, subgraph: keptEdges };
});
/**
 * Step 1: asks the LLM to break the question into search concepts.
 *
 * The "extract-concepts" prompt template is fetched first, then its
 * `system`/`prompt` pair is sent to the LLM. The reply is read as one concept
 * per line; lines are trimmed and blank lines dropped.
 *
 * @param clients - Provides the `prompt` and `llm` clients.
 * @param query - The user question.
 * @returns A non-empty list of concept strings. When the LLM reply contains
 *   no usable line, the list holds the original question instead.
 */
const extractConcepts = Effect.fn("GraphRagEngine.extractConcepts")(function* (
  clients: GraphRagClients,
  query: string,
) {
  const promptResp = yield* requestClient(
    clients.prompt,
    "extract-concepts-prompt",
    {
      name: "extract-concepts",
      variables: { query },
    },
  );
  const llmResp = yield* requestClient(
    clients.llm,
    "extract-concepts-llm",
    {
      system: promptResp.system,
      prompt: promptResp.prompt,
    },
  );
  const concepts = llmResp.response
    .split("\n")
    .map((concept) => concept.trim())
    .filter((concept) => concept.length > 0);
  // An empty list would send an empty `text` array to the embeddings client
  // and leave retrieval with nothing to match on, so use the raw question.
  return concepts.length > 0 ? concepts : [query];
});
/**
 * Step 2: sends all concepts to the embeddings client in one request and
 * returns the `vectors` field of the response.
 */
const getVectors = Effect.fn("GraphRagEngine.getVectors")(function* (
  clients: GraphRagClients,
  concepts: string[],
) {
  const { vectors } = yield* requestClient(
    clients.embeddings,
    "get-vectors",
    { text: concepts },
  );
  return vectors;
});
/**
 * Step 3: looks up graph entities for the concept vectors.
 *
 * The request always uses user "default"; the collection falls back to
 * "default" when none is given, and `config.entityLimit` is sent as the limit.
 *
 * @returns The `entities` field of the graph-embeddings response.
 */
const getEntities = Effect.fn("GraphRagEngine.getEntities")(function* (
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  vectors: number[][],
  collection?: string,
) {
  const { entities } = yield* requestClient(
    clients.graphEmbeddings,
    "get-entities",
    {
      vectors,
      user: "default",
      collection: collection ?? "default",
      limit: config.entityLimit,
    },
  );
  return entities;
});
/**
 * Step 4: walks the graph level by level, starting from the seed entities.
 *
 * At each level every not-yet-visited entity is queried as a triple subject
 * (at most `config.tripleLimit` triples each). All returned triples are
 * collected; the objects of those triples form the next level, except on the
 * last level where no further expansion is recorded. The walk stops after
 * `config.maxPathLength` levels, when a level is empty, or as soon as
 * `config.maxSubgraphSize` triples have been collected.
 *
 * Entities are tracked by their `termToString` form and turned back into
 * terms with `stringToTerm` for querying.
 *
 * NOTE(review): `Effect.all` is called without a concurrency option — confirm
 * whether the per-entity lookups of a level are meant to run in parallel.
 *
 * @returns The collected triples, never more than `config.maxSubgraphSize`.
 */
const followEdges = Effect.fn("GraphRagEngine.followEdges")(function* (
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  entities: Term[],
  collection?: string,
) {
  const seen = new Set<string>();
  const collected: Triple[] = [];
  let frontier = new Set<string>(entities.map((seed) => termToString(seed)));

  let depth = 0;
  while (depth < config.maxPathLength) {
    const budgetSpent = collected.length >= config.maxSubgraphSize;
    if (frontier.size === 0 || budgetSpent) {
      break;
    }

    const pending: string[] = [];
    for (const key of frontier) {
      if (!seen.has(key)) pending.push(key);
    }
    if (pending.length === 0) break;

    const lookups = pending.map((key) => {
      const subject = stringToTerm(key);
      // The collection key is left off entirely when no collection was given.
      const request: TriplesQueryRequest = {
        s: subject,
        limit: config.tripleLimit,
        ...(collection !== undefined ? { collection } : {}),
      };
      return requestClient(clients.triples, "follow-edges-query", request);
    });
    const responses = yield* Effect.all(lookups);

    const upcoming = new Set<string>();
    // Objects are only queued while another level is still going to run.
    const expandFurther = depth < config.maxPathLength - 1;
    for (const response of responses) {
      for (const triple of response.triples) {
        collected.push(triple);
        if (expandFurther) {
          const target = termToString(triple.o);
          if (!seen.has(target)) {
            upcoming.add(target);
          }
        }
        if (collected.length >= config.maxSubgraphSize) {
          return collected;
        }
      }
    }

    frontier.forEach((key) => {
      seen.add(key);
    });
    frontier = upcoming;
    depth += 1;
  }

  return collected.slice(0, config.maxSubgraphSize);
});
/**
 * Step 5: trims a large subgraph to the edges the LLM rates as most relevant.
 *
 * - No triples: returns an empty list.
 * - 500 triples or fewer: returns them unchanged, without calling the LLM.
 * - Otherwise the first `config.edgeScoreLimit` triples are described as
 *   `{ id, s, p, o }` records (id = position in `triples`), sent through the
 *   "kg-edge-scoring" prompt, and the reply is parsed into `{ id, score }`
 *   entries. Entries are ordered by descending score; ids that are not plain
 *   non-negative integers, that point outside the scored window, or that
 *   repeat an earlier entry are discarded *before* the best
 *   `config.edgeLimit` are taken, so bad ids neither use up a slot nor
 *   produce duplicate edges.
 * - If no scored id can be resolved, the first `config.edgeLimit` triples are
 *   returned as a fallback.
 *
 * @param clients - Provides the `prompt` and `llm` clients.
 * @param config - Supplies `edgeScoreLimit` and `edgeLimit`.
 * @param query - The user question, passed to the scoring prompt.
 * @param triples - Edges collected by the graph walk.
 * @returns The edges to use as synthesis context.
 */
const scoreEdges = Effect.fn("GraphRagEngine.scoreEdges")(function* (
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  query: string,
  triples: Triple[],
) {
  if (triples.length === 0) return [];
  if (triples.length <= 500) {
    yield* Effect.log(`[GraphRag] Skipping edge scoring - ${triples.length} triples fits in context directly`);
    return triples;
  }
  // Only the window that is actually sent to the LLM needs describing.
  const toScore = triples.slice(0, config.edgeScoreLimit).map((triple, index) => ({
    id: String(index),
    s: termToString(triple.s),
    p: termToString(triple.p),
    o: termToString(triple.o),
  }));
  const knowledgeJson = yield* S.encodeUnknownEffect(S.UnknownFromJsonString)(toScore).pipe(
    Effect.mapError((cause) => graphRagError("edge-score-encode", cause)),
  );
  const promptResp = yield* requestClient(
    clients.prompt,
    "edge-score-prompt",
    {
      name: "kg-edge-scoring",
      variables: {
        query,
        knowledge: knowledgeJson,
      },
    },
  );
  const llmResp = yield* requestClient(
    clients.llm,
    "edge-score-llm",
    {
      system: promptResp.system,
      prompt: promptResp.prompt,
    },
  );
  yield* Effect.log(`[GraphRag] Edge scoring LLM response (first 500 chars): ${llmResp.response.slice(0, 500)}`);
  // Highest score first.
  const scored = A.sort(
    parseScoredEdges(llmResp.response),
    Order.make<typeof ScoredEdge.Type>((left, right) => Order.Number(right.score, left.score)),
  );
  // Resolve ids to positions, dropping anything the LLM was never shown.
  // Because `scored` is already ordered, the first occurrence of an id is its
  // best-scored one.
  const usableIndices: number[] = [];
  const seenIndices = new Set<number>();
  for (const entry of scored) {
    const id = entry.id.trim();
    if (!/^\d+$/.test(id)) continue;
    const idx = Number(id);
    if (idx >= toScore.length || seenIndices.has(idx)) continue;
    seenIndices.add(idx);
    usableIndices.push(idx);
  }
  const topN = usableIndices.slice(0, config.edgeLimit);
  const result: Triple[] = [];
  for (const idx of topN) {
    const triple = triples[idx];
    if (triple !== undefined) {
      result.push(triple);
    }
  }
  yield* Effect.log(`[GraphRag] Edge scoring: LLM returned ${scored.length} scores, keeping top ${topN.length}, mapped ${result.length} triples`);
  if (result.length === 0) {
    // Nothing usable came back; fall back to the first edges unscored.
    return triples.slice(0, config.edgeLimit);
  }
  return result;
});
/**
 * Step 6: produces the final answer from the selected edges.
 *
 * The edges are rendered as one "subject -> predicate -> object" line each
 * and passed, together with the question, to the "graph-rag-synthesize"
 * prompt template. The resulting `system`/`prompt` pair is then sent to the LLM.
 *
 * With a `chunkCallback` the completion is requested with `streaming: true`.
 * Each chunk's text is appended to the accumulated answer and forwarded to
 * the callback along with its end-of-stream flag. Empty chunks are skipped,
 * except the terminal one, which is still forwarded (as an empty string) so
 * the callback always observes the end of the stream.
 *
 * @param clients - Provides the `prompt` and `llm` clients.
 * @param query - The user question.
 * @param edges - Edges to present to the LLM as context.
 * @param chunkCallback - Optional receiver for streamed text.
 * @returns The full answer text.
 */
const synthesize = Effect.fn("GraphRagEngine.synthesize")(function* (
  clients: GraphRagClients,
  query: string,
  edges: Triple[],
  chunkCallback?: ChunkCallback,
) {
  const context = edges
    .map((triple) => `${termToString(triple.s)} -> ${termToString(triple.p)} -> ${termToString(triple.o)}`)
    .join("\n");
  const promptResp = yield* requestClient(
    clients.prompt,
    "synthesize-prompt",
    {
      name: "graph-rag-synthesize",
      variables: { query, context },
    },
  );
  if (chunkCallback !== undefined) {
    let fullText = "";
    yield* requestClient(
      clients.llm,
      "synthesize-stream",
      {
        system: promptResp.system,
        prompt: promptResp.prompt,
        streaming: true,
      },
      {
        // The boolean result mirrors the chunk's end-of-stream flag —
        // presumably it tells the requestor the exchange is complete
        // (confirm against EffectRequestOptions).
        recipient: (resp) => {
          const isFinal = resp.endOfStream === true;
          if (resp.response.length === 0 && !isFinal) {
            // Empty, non-terminal chunk: nothing to forward.
            return Effect.succeed(false);
          }
          fullText += resp.response;
          // The terminal chunk is forwarded even when it carries no text;
          // otherwise the callback would never see endOfStream = true.
          return chunkCallback(resp.response, isFinal).pipe(
            Effect.as(isFinal),
          );
        },
      },
    );
    return fullText;
  }
  const resp = yield* requestClient(
    clients.llm,
    "synthesize-llm",
    {
      system: promptResp.system,
      prompt: promptResp.prompt,
    },
  );
  return resp.response;
});
// One relevance rating from the edge-scoring LLM reply: `id` is the edge id
// that was sent in the scoring request, `score` its finite numeric rating.
const ScoredEdge = S.Struct({
id: S.String,
score: S.Finite,
});
// JSON-string schemas: a whole array of ratings, and a single rating.
const ScoredEdgesFromJson = S.Array(ScoredEdge).pipe(S.fromJsonString);
const ScoredEdgeFromJson = ScoredEdge.pipe(S.fromJsonString);
// Option-returning decoders, so `parseScoredEdges` can test the result with
// `O.isSome` instead of handling a decode failure.
const decodeScoredEdges = S.decodeUnknownOption(ScoredEdgesFromJson);
const decodeScoredEdge = S.decodeUnknownOption(ScoredEdgeFromJson);
/**
 * Extracts `{ id, score }` ratings from the edge-scoring LLM reply.
 *
 * Two layouts are accepted:
 *  - a single JSON array of ratings, or
 *  - one JSON rating object per line; lines that do not decode are skipped.
 *
 * Before decoding, markdown code-fence lines (those starting with three
 * backticks) and blank lines are removed. In the per-line fallback a trailing
 * comma is stripped, so the elements of a pretty-printed array surrounded by
 * extra text are still recovered.
 *
 * @param responseText - Raw LLM reply.
 * @returns The decoded ratings in the order they appear; empty when nothing decodes.
 */
function parseScoredEdges(responseText: string): Array<typeof ScoredEdge.Type> {
  const lines = responseText
    .split("\n")
    .map((line) => line.trim())
    .filter((line) => line.length > 0 && !line.startsWith("```"));

  // First choice: the whole (de-fenced) reply is one JSON array.
  const parsedArray = decodeScoredEdges(lines.join("\n"));
  if (O.isSome(parsedArray)) {
    return Array.from(parsedArray.value);
  }

  // Fallback: treat every remaining line as its own JSON object.
  const scored: Array<typeof ScoredEdge.Type> = [];
  for (const line of lines) {
    const candidate = line.endsWith(",") ? line.slice(0, -1).trimEnd() : line;
    const parsedLine = decodeScoredEdge(candidate);
    if (O.isSome(parsedLine)) {
      scored.push(parsedLine.value);
    }
  }
  return scored;
}
/**
 * Renders a term as plain text:
 *
 *  - IRI     → the IRI itself
 *  - LITERAL → the literal value
 *  - BLANK   → `_:` followed by the blank-node id
 *  - TRIPLE  → `(s p o)` with each part rendered recursively
 *
 * @param term - The term to render.
 */
export function termToString(term: Term): string {
  const render = Match.type<Term>().pipe(
    Match.discriminatorsExhaustive("type")({
      IRI: ({ iri }) => iri,
      LITERAL: ({ value }) => value,
      BLANK: ({ id }) => `_:${id}`,
      TRIPLE: ({ triple }) => {
        const subject = termToString(triple.s);
        const predicate = termToString(triple.p);
        const object = termToString(triple.o);
        return `(${subject} ${predicate} ${object})`;
      },
    }),
  );
  return render(term);
}
/**
 * Turns a plain string into a term by looking at its prefix:
 *
 *  - `http://` or `https://` (case-sensitive) → IRI
 *  - `_:` → BLANK, with the prefix removed from the id
 *  - anything else → LITERAL
 *
 * @param value - Text to classify.
 */
export function stringToTerm(value: string): Term {
  const hasHttpScheme = /^https?:\/\//.test(value);
  if (hasHttpScheme) {
    return { type: "IRI", iri: value };
  }
  const blankPrefix = "_:";
  return value.startsWith(blankPrefix)
    ? { type: "BLANK", id: value.slice(blankPrefix.length) }
    : { type: "LITERAL", value };
}