mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
477 lines
14 KiB
TypeScript
477 lines
14 KiB
TypeScript
/**
 * Graph RAG retrieval pipeline.
 *
 * Python reference: trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
 */
|
|
|
|
import type {
|
|
EmbeddingsRequest,
|
|
EmbeddingsResponse,
|
|
EffectRequestOptions,
|
|
EffectRequestResponse,
|
|
GraphEmbeddingsRequest,
|
|
GraphEmbeddingsResponse,
|
|
PromptRequest,
|
|
PromptResponse,
|
|
Term,
|
|
TextCompletionRequest,
|
|
TextCompletionResponse,
|
|
TriplesQueryRequest,
|
|
TriplesQueryResponse,
|
|
} from "@trustgraph/base";
|
|
import { Triple, errorMessage } from "@trustgraph/base";
|
|
import { Array as A, Context, Effect, Layer, Match, Order } from "effect";
|
|
import * as O from "effect/Option";
|
|
import * as S from "effect/Schema";
|
|
|
|
/**
 * Optional tuning limits for a Graph RAG query.
 *
 * Every field may be omitted; `normalizeGraphRagConfig` fills in the defaults
 * (50 / 30 / 1000 / 2 / 50 / 25, in declaration order).
 */
export class GraphRagConfig extends S.Class<GraphRagConfig>("GraphRagConfig")(
  {
    /** `limit` sent with the graph-embeddings entity lookup. */
    entityLimit: S.optionalKey(S.Finite),
    /** `limit` sent with each triples query during traversal. */
    tripleLimit: S.optionalKey(S.Finite),
    /** Upper bound on the number of triples collected by traversal. */
    maxSubgraphSize: S.optionalKey(S.Finite),
    /** Number of traversal levels expanded from the seed entities. */
    maxPathLength: S.optionalKey(S.Finite),
    /** How many edges are handed to the LLM for scoring. */
    edgeScoreLimit: S.optionalKey(S.Finite),
    /** How many scored edges are kept for synthesis. */
    edgeLimit: S.optionalKey(S.Finite),
  },
  { description: "Graph RAG retrieval tuning limits." },
) {}
|
|
|
|
/**
 * Request/response clients the pipeline talks to. Each one is used through
 * its `request` method.
 */
export interface GraphRagClients {
  /** Text completion; used for concept extraction, edge scoring and synthesis. */
  llm: EffectRequestResponse<TextCompletionRequest, TextCompletionResponse>;
  /** Turns the extracted concepts into embedding vectors. */
  embeddings: EffectRequestResponse<EmbeddingsRequest, EmbeddingsResponse>;
  /** Looks up graph entities matching a set of vectors. */
  graphEmbeddings: EffectRequestResponse<GraphEmbeddingsRequest, GraphEmbeddingsResponse>;
  /** Queries triples by subject during graph traversal. */
  triples: EffectRequestResponse<TriplesQueryRequest, TriplesQueryResponse>;
  /** Renders named prompt templates into a system/prompt pair. */
  prompt: EffectRequestResponse<PromptRequest, PromptResponse>;
}
|
|
|
|
/**
 * Receives one piece of streamed answer text.
 *
 * @param text - The text carried by the chunk.
 * @param endOfStream - `true` when the chunk was flagged as the final one.
 * @returns An effect run before the stream continues; it may fail with
 *   `GraphRagEngineError`.
 */
export type ChunkCallback = (
  text: string,
  endOfStream: boolean,
) => Effect.Effect<void, GraphRagEngineError>;
|
|
|
|
/** Per-query options accepted by the Graph RAG entry points. */
export interface GraphRagQueryOptions {
  /**
   * Collection to search. The entity lookup falls back to `"default"` when
   * omitted; triples queries simply leave the field out.
   */
  readonly collection?: string;
  /**
   * Streaming flag.
   * NOTE(review): nothing in this module reads it - streaming is selected by
   * supplying `chunkCallback`. Confirm whether callers rely on it.
   */
  readonly streaming?: boolean;
  /** When provided, the synthesis step streams answer chunks to this callback. */
  readonly chunkCallback?: ChunkCallback;
}
|
|
|
|
/**
 * Outcome of a Graph RAG query: the synthesized answer text plus the triples
 * that were supplied to the synthesis prompt as context.
 */
export class GraphRagResult extends S.Class<GraphRagResult>("GraphRagResult")(
  {
    answer: S.String,
    subgraph: S.Array(Triple),
  },
  { description: "Graph RAG answer with the supporting subgraph." },
) {}
|
|
|
|
/**
 * `GraphRagConfig` after defaults have been applied: the same six limits,
 * all required.
 */
interface NormalizedGraphRagConfig {
  entityLimit: number;
  tripleLimit: number;
  maxSubgraphSize: number;
  maxPathLength: number;
  edgeScoreLimit: number;
  edgeLimit: number;
}
|
|
|
|
/**
 * Tagged error raised by every pipeline step.
 *
 * - `message`: human-readable text derived from the cause.
 * - `operation`: label of the step that failed (e.g. `"get-vectors"`).
 * - `cause`: the underlying failure, encoded as a defect including its stack.
 */
export class GraphRagEngineError extends S.TaggedErrorClass<GraphRagEngineError>()(
  "GraphRagEngineError",
  {
    message: S.String,
    operation: S.String,
    cause: S.Defect({ includeStack: true }),
  },
) {}
|
|
|
|
/** Service contract of the Graph RAG engine. */
export interface GraphRagEngineShape {
  /**
   * Runs the full retrieval pipeline for one question.
   *
   * @param clients - Backing request/response clients.
   * @param queryText - The user's question.
   * @param options - Optional collection / streaming settings.
   * @param config - Optional tuning limits; omitted fields get defaults.
   * @returns The answer and its supporting subgraph, or a `GraphRagEngineError`.
   */
  readonly query: (
    clients: GraphRagClients,
    queryText: string,
    options?: GraphRagQueryOptions,
    config?: GraphRagConfig,
  ) => Effect.Effect<GraphRagResult, GraphRagEngineError>;
}
|
|
|
|
/** Context service tag under which a `GraphRagEngineShape` implementation is registered. */
export class GraphRagEngine extends Context.Service<GraphRagEngine, GraphRagEngineShape>()(
  "@trustgraph/flow/retrieval/graph-rag/GraphRagEngine",
) {}
|
|
|
|
/**
 * Wraps an arbitrary failure in a `GraphRagEngineError`.
 *
 * @param operation - Label of the pipeline step that failed.
 * @param cause - The original failure; also used to derive `message`.
 */
const graphRagError = (operation: string, cause: unknown) => {
  const message = errorMessage(cause);
  return GraphRagEngineError.make({ operation, cause, message });
};
|
|
|
|
/**
 * Sends `request` through `requestor` and converts any failure into a
 * `GraphRagEngineError` labelled with `operation`.
 *
 * @param requestor - Client to call.
 * @param operation - Step label recorded on the error.
 * @param request - Payload to send.
 * @param options - Optional request options, forwarded unchanged.
 */
const requestClient = <TReq, TRes>(
  requestor: EffectRequestResponse<TReq, TRes>,
  operation: string,
  request: TReq,
  options?: EffectRequestOptions<TRes, GraphRagEngineError>,
): Effect.Effect<TRes, GraphRagEngineError> => {
  const pending = requestor.request(request, options);
  return Effect.mapError(pending, (cause) => graphRagError(operation, cause));
};
|
|
|
|
/**
 * Fills in the default for every limit the caller left unset.
 *
 * Defaults: entityLimit 50, tripleLimit 30, maxSubgraphSize 1000,
 * maxPathLength 2, edgeScoreLimit 50, edgeLimit 25. Explicit values
 * (including `0`) are kept as given.
 *
 * @param config - Partial limits; defaults to an empty configuration.
 * @returns A configuration with all six limits present.
 */
export function normalizeGraphRagConfig(config: GraphRagConfig = {}): NormalizedGraphRagConfig {
  // `??` (not `||`) so that an explicit 0 is respected.
  const orDefault = (value: number | undefined, fallback: number): number => value ?? fallback;

  return {
    entityLimit: orDefault(config.entityLimit, 50),
    tripleLimit: orDefault(config.tripleLimit, 30),
    maxSubgraphSize: orDefault(config.maxSubgraphSize, 1000),
    maxPathLength: orDefault(config.maxPathLength, 2),
    edgeScoreLimit: orDefault(config.edgeScoreLimit, 50),
    edgeLimit: orDefault(config.edgeLimit, 25),
  };
}
|
|
|
|
/**
 * Builds the engine implementation. Its single `query` method is a traced
 * wrapper (span name `GraphRagEngine.query`) that delegates to the pipeline.
 */
export function makeGraphRagEngine(): GraphRagEngineShape {
  const query = Effect.fn("GraphRagEngine.query")(
    (
      clients: GraphRagClients,
      queryText: string,
      options?: GraphRagQueryOptions,
      config?: GraphRagConfig,
    ) => queryGraphRag(clients, queryText, options, config),
  );

  return { query };
}
|
|
|
|
/** Layer that provides `GraphRagEngine` with the implementation from `makeGraphRagEngine`. */
export const GraphRagLive: Layer.Layer<GraphRagEngine> = Layer.succeed(
  GraphRagEngine,
  GraphRagEngine.of(makeGraphRagEngine()),
);
|
|
|
|
/**
 * Convenience facade over the engine with clients and configuration already
 * bound, so callers only pass the question and per-query options.
 */
export interface GraphRag {
  /**
   * @param queryText - The user's question.
   * @param options - Optional collection / streaming settings.
   * @returns The answer and its supporting subgraph, or a `GraphRagEngineError`.
   */
  readonly query: (
    queryText: string,
    options?: GraphRagQueryOptions,
  ) => Effect.Effect<GraphRagResult, GraphRagEngineError>;
}
|
|
|
|
/**
 * Creates a `GraphRag` facade bound to the given clients and configuration.
 *
 * @param clients - Backing request/response clients used for every query.
 * @param config - Tuning limits applied to every query; defaults to empty.
 */
export function makeGraphRag(
  clients: GraphRagClients,
  config: GraphRagConfig = {},
): GraphRag {
  const boundEngine = makeGraphRagEngine();

  const facade: GraphRag = {
    query(queryText, options) {
      return boundEngine.query(clients, queryText, options, config);
    },
  };
  return facade;
}
|
|
|
|
/**
 * Runs the six pipeline steps in order, logging progress after each:
 * concept extraction, embedding, entity lookup, graph traversal, edge
 * scoring, and answer synthesis.
 *
 * @param clients - Backing request/response clients.
 * @param queryText - The user's question.
 * @param options - Optional collection and chunk callback.
 * @param rawConfig - Optional tuning limits; defaults are applied first.
 * @returns The answer together with the edges that were used as context.
 */
const queryGraphRag = Effect.fn("GraphRagEngine.queryGraphRag")(function* (
  clients: GraphRagClients,
  queryText: string,
  options?: GraphRagQueryOptions,
  rawConfig?: GraphRagConfig,
) {
  const limits = normalizeGraphRagConfig(rawConfig);
  const collection = options?.collection;
  const preview = queryText.slice(0, 80);
  yield* Effect.log(`[GraphRag] Query: "${preview}..."`);

  // Step 1: ask the LLM for the key concepts in the question.
  const concepts = yield* extractConcepts(clients, queryText);
  const conceptSample = concepts.slice(0, 5).join(", ");
  yield* Effect.log(`[GraphRag] Step 1: extracted ${concepts.length} concepts: ${conceptSample}`);

  // Step 2: embed the concepts.
  const vectors = yield* getVectors(clients, concepts);
  const dimension = vectors[0]?.length ?? 0;
  yield* Effect.log(`[GraphRag] Step 2: got ${vectors.length} vectors (dim=${dimension})`);

  // Step 3: find graph entities matching those vectors.
  const seeds = yield* getEntities(clients, limits, vectors, collection);
  yield* Effect.log(`[GraphRag] Step 3: found ${seeds.length} matching entities`);

  // Step 4: expand outwards from the seed entities.
  const traversed = yield* followEdges(clients, limits, seeds, collection);
  yield* Effect.log(`[GraphRag] Step 4: traversed graph, ${traversed.length} triples in subgraph`);

  // Step 5: reduce the subgraph to the most relevant edges.
  const rankedEdges = yield* scoreEdges(clients, limits, queryText, traversed);
  yield* Effect.log(`[GraphRag] Step 5: scored down to ${rankedEdges.length} edges`);

  // Step 6: produce the answer (streamed when a chunk callback is supplied).
  yield* Effect.log(`[GraphRag] Step 6: synthesizing answer from ${rankedEdges.length} edges...`);
  const answerText = yield* synthesize(clients, queryText, rankedEdges, options?.chunkCallback);
  yield* Effect.log(`[GraphRag] Step 6: done (${answerText.length} chars)`);

  return { answer: answerText, subgraph: rankedEdges };
});
|
|
|
|
/**
 * Asks the LLM for the key concepts in `query`, one per line.
 *
 * The `extract-concepts` prompt template is rendered first, then sent to the
 * LLM. The reply is split on newlines; each line is trimmed and blank lines
 * are dropped.
 *
 * @param clients - Provides the `prompt` and `llm` clients.
 * @param query - The user's question.
 * @returns The extracted concepts, or `[query]` when the reply contained no
 *   usable lines.
 */
const extractConcepts = Effect.fn("GraphRagEngine.extractConcepts")(function* (
  clients: GraphRagClients,
  query: string,
) {
  const promptResp = yield* requestClient(
    clients.prompt,
    "extract-concepts-prompt",
    {
      name: "extract-concepts",
      variables: { query },
    },
  );

  const llmResp = yield* requestClient(
    clients.llm,
    "extract-concepts-llm",
    {
      system: promptResp.system,
      prompt: promptResp.prompt,
    },
  );

  const concepts = llmResp.response
    .split("\n")
    .map((concept) => concept.trim())
    .filter((concept) => concept.length > 0);

  // An empty/whitespace-only reply would otherwise send an empty text list to
  // the embeddings client and leave retrieval with nothing to search for, so
  // fall back to embedding the question itself.
  return concepts.length > 0 ? concepts : [query];
});
|
|
|
|
/**
 * Embeds the extracted concepts.
 *
 * @param clients - Provides the `embeddings` client.
 * @param concepts - Texts to embed, sent in a single request.
 * @returns The `vectors` field of the embeddings response.
 */
const getVectors = Effect.fn("GraphRagEngine.getVectors")(function* (
  clients: GraphRagClients,
  concepts: string[],
) {
  const embedded = yield* requestClient(
    clients.embeddings,
    "get-vectors",
    { text: concepts },
  );
  return embedded.vectors;
});
|
|
|
|
/**
 * Looks up graph entities that match the given vectors.
 *
 * The request always uses user `"default"`, and collection `"default"` when
 * none is supplied; the result count is limited by `config.entityLimit`.
 *
 * @param clients - Provides the `graphEmbeddings` client.
 * @param config - Normalized limits.
 * @param vectors - Concept embeddings to match against.
 * @param collection - Optional collection name.
 * @returns The `entities` field of the graph-embeddings response.
 */
const getEntities = Effect.fn("GraphRagEngine.getEntities")(function* (
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  vectors: number[][],
  collection?: string,
) {
  const targetCollection = collection ?? "default";
  const lookup = yield* requestClient(
    clients.graphEmbeddings,
    "get-entities",
    {
      vectors,
      user: "default",
      collection: targetCollection,
      limit: config.entityLimit,
    },
  );
  return lookup.entities;
});
|
|
|
|
/**
 * Breadth-first expansion of the graph starting from `entities`.
 *
 * At each level every not-yet-visited frontier term is queried as a triple
 * subject (at most `config.tripleLimit` triples per term); the objects of the
 * returned triples form the next level. Expansion stops after
 * `config.maxPathLength` levels, when the frontier is exhausted, or once
 * `config.maxSubgraphSize` triples have been collected.
 *
 * @param clients - Provides the `triples` client.
 * @param config - Normalized limits.
 * @param entities - Seed terms to start from.
 * @param collection - Optional collection; omitted from requests when undefined.
 * @returns The collected triples, at most `config.maxSubgraphSize` of them.
 */
const followEdges = Effect.fn("GraphRagEngine.followEdges")(function* (
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  entities: Term[],
  collection?: string,
) {
  const visited = new Set<string>();
  const subgraph: Triple[] = [];

  // The frontier is keyed by the term's string form (for de-duplication) but
  // keeps the original Term as the value. Re-parsing the string would lose the
  // term's type: IRIs that are not http(s) and quoted triples would come back
  // as literals, and URL-looking literals would come back as IRIs.
  let currentLevel = new Map<string, Term>();
  for (const entity of entities) {
    const key = termToString(entity);
    if (!currentLevel.has(key)) {
      currentLevel.set(key, entity);
    }
  }

  for (let depth = 0; depth < config.maxPathLength; depth++) {
    if (currentLevel.size === 0 || subgraph.length >= config.maxSubgraphSize) {
      break;
    }

    const unvisited = [...currentLevel].filter(([key]) => !visited.has(key));
    if (unvisited.length === 0) break;

    const queries = unvisited.map(([, term]) => {
      const request: TriplesQueryRequest = {
        s: term,
        limit: config.tripleLimit,
        ...(collection !== undefined ? { collection } : {}),
      };
      return requestClient(clients.triples, "follow-edges-query", request);
    });

    const results = yield* Effect.all(queries);
    const nextLevel = new Map<string, Term>();
    // Objects found on the last level are never queried, so skip collecting them.
    const expandFurther = depth < config.maxPathLength - 1;

    for (const result of results) {
      for (const triple of result.triples) {
        subgraph.push(triple);

        if (expandFurther) {
          const objKey = termToString(triple.o);
          if (!visited.has(objKey) && !nextLevel.has(objKey)) {
            nextLevel.set(objKey, triple.o);
          }
        }

        if (subgraph.length >= config.maxSubgraphSize) {
          return subgraph;
        }
      }
    }

    for (const key of currentLevel.keys()) {
      visited.add(key);
    }
    currentLevel = nextLevel;
  }

  return subgraph.slice(0, config.maxSubgraphSize);
});
|
|
|
|
/**
 * Narrows the traversed subgraph to the edges most relevant to `query`.
 *
 * - No triples: returns an empty list.
 * - 500 triples or fewer: returns them unchanged, with no LLM call.
 * - Otherwise: the first `config.edgeScoreLimit` triples are described as
 *   `{ id, s, p, o }` (id = index), scored by the LLM via the
 *   `kg-edge-scoring` prompt, sorted by descending score, and the top
 *   `config.edgeLimit` entries are mapped back to triples. If none of them
 *   maps to a triple, the first `config.edgeLimit` triples are returned.
 *
 * @param clients - Provides the `prompt` and `llm` clients.
 * @param config - Normalized limits.
 * @param query - The user's question.
 * @param triples - Subgraph produced by traversal.
 * @returns The selected triples.
 */
const scoreEdges = Effect.fn("GraphRagEngine.scoreEdges")(function* (
  clients: GraphRagClients,
  config: NormalizedGraphRagConfig,
  query: string,
  triples: Triple[],
) {
  if (triples.length === 0) return [];

  if (triples.length <= 500) {
    yield* Effect.log(`[GraphRag] Skipping edge scoring - ${triples.length} triples fits in context directly`);
    return triples;
  }

  // Only describe the edges that are actually sent for scoring.
  const toScore = triples.slice(0, config.edgeScoreLimit).map((triple, index) => ({
    id: String(index),
    s: termToString(triple.s),
    p: termToString(triple.p),
    o: termToString(triple.o),
  }));
  const knowledgeJson = yield* S.encodeUnknownEffect(S.UnknownFromJsonString)(toScore).pipe(
    Effect.mapError((cause) => graphRagError("edge-score-encode", cause)),
  );

  const promptResp = yield* requestClient(
    clients.prompt,
    "edge-score-prompt",
    {
      name: "kg-edge-scoring",
      variables: {
        query,
        knowledge: knowledgeJson,
      },
    },
  );

  const llmResp = yield* requestClient(
    clients.llm,
    "edge-score-llm",
    {
      system: promptResp.system,
      prompt: promptResp.prompt,
    },
  );

  yield* Effect.log(`[GraphRag] Edge scoring LLM response (first 500 chars): ${llmResp.response.slice(0, 500)}`);

  // Highest score first.
  const scored = A.sort(
    parseScoredEdges(llmResp.response),
    Order.make<typeof ScoredEdge.Type>((left, right) => Order.Number(right.score, left.score)),
  );
  const topN = scored.slice(0, config.edgeLimit);

  const result: Triple[] = [];
  const usedIndexes = new Set<number>();
  for (const entry of topN) {
    const idx = Number.parseInt(entry.id, 10);
    // Accept only ids that were really offered to the LLM, and each one once:
    // an id beyond the scored slice refers to an edge the model never saw, and
    // a repeated id would duplicate a triple in the synthesis context.
    if (Number.isNaN(idx) || idx < 0 || idx >= toScore.length || usedIndexes.has(idx)) {
      continue;
    }
    const triple = triples[idx];
    if (triple === undefined) continue;
    usedIndexes.add(idx);
    result.push(triple);
  }

  yield* Effect.log(`[GraphRag] Edge scoring: LLM returned ${scored.length} scores, keeping top ${topN.length}, mapped ${result.length} triples`);

  // Unparseable or unusable scoring output: fall back to the leading triples.
  if (result.length === 0) {
    return triples.slice(0, config.edgeLimit);
  }

  return result;
});
|
|
|
|
/**
 * Produces the final answer from the selected edges.
 *
 * Each edge is rendered as `s -> p -> o` (one per line) and passed, together
 * with the question, to the `graph-rag-synthesize` prompt template. Without a
 * callback the LLM is called once and its response returned. With a callback
 * the LLM is called in streaming mode: every non-empty chunk is appended to
 * the answer and forwarded to the callback. Empty chunks are not forwarded,
 * even when they carry the end-of-stream flag.
 *
 * @param clients - Provides the `prompt` and `llm` clients.
 * @param query - The user's question.
 * @param edges - Triples used as context.
 * @param chunkCallback - Optional receiver for streamed text.
 * @returns The complete answer text.
 */
const synthesize = Effect.fn("GraphRagEngine.synthesize")(function* (
  clients: GraphRagClients,
  query: string,
  edges: Triple[],
  chunkCallback?: ChunkCallback,
) {
  const renderEdge = (edge: Triple): string =>
    `${termToString(edge.s)} -> ${termToString(edge.p)} -> ${termToString(edge.o)}`;
  const context = edges.map((edge) => renderEdge(edge)).join("\n");

  const template = yield* requestClient(
    clients.prompt,
    "synthesize-prompt",
    {
      name: "graph-rag-synthesize",
      variables: { query, context },
    },
  );

  // Non-streaming path: a single request/response.
  if (chunkCallback === undefined) {
    const completion = yield* requestClient(
      clients.llm,
      "synthesize-llm",
      {
        system: template.system,
        prompt: template.prompt,
      },
    );
    return completion.response;
  }

  // Streaming path: collect chunks while relaying them to the caller.
  const emit = chunkCallback;
  const pieces: string[] = [];
  yield* requestClient(
    clients.llm,
    "synthesize-stream",
    {
      system: template.system,
      prompt: template.prompt,
      streaming: true,
    },
    {
      // The recipient's boolean result mirrors the chunk's end-of-stream flag.
      recipient: (message) => {
        const isFinal = message.endOfStream === true;
        if (message.response.length === 0) {
          return Effect.succeed(isFinal);
        }
        pieces.push(message.response);
        return emit(message.response, isFinal).pipe(Effect.as(isFinal));
      },
    },
  );
  return pieces.join("");
});
|
|
|
|
/** One scoring entry from the LLM: the edge id (as a string) and a finite score. */
const ScoredEdge = S.Struct({
  id: S.String,
  score: S.Finite,
});

/** JSON text containing an array of scoring entries. */
const ScoredEdgesFromJson = S.fromJsonString(S.Array(ScoredEdge));

/** JSON text containing a single scoring entry. */
const ScoredEdgeFromJson = S.fromJsonString(ScoredEdge);

// Option-returning decoders: malformed input yields `None` instead of failing.
const decodeScoredEdges = S.decodeUnknownOption(ScoredEdgesFromJson);
const decodeScoredEdge = S.decodeUnknownOption(ScoredEdgeFromJson);
|
|
|
|
/**
 * Extracts scoring entries from the LLM's reply.
 *
 * The whole reply is first tried as a JSON array of entries. If that fails,
 * each non-blank line is tried as a single JSON entry, and lines that do not
 * decode are skipped.
 *
 * @param responseText - Raw LLM response.
 * @returns The decoded entries in the order they appeared (possibly empty).
 */
function parseScoredEdges(responseText: string): Array<typeof ScoredEdge.Type> {
  const wholeDocument = decodeScoredEdges(responseText);
  if (O.isSome(wholeDocument)) {
    return Array.from(wholeDocument.value);
  }

  // Fall back to one JSON object per line.
  return responseText
    .split("\n")
    .map((line) => line.trim())
    .filter((line) => line.length > 0)
    .flatMap((line) => {
      const decoded = decodeScoredEdge(line);
      return O.isSome(decoded) ? [decoded.value] : [];
    });
}
|
|
|
|
/**
 * Renders a term as plain text.
 *
 * - IRI: the IRI itself
 * - LITERAL: the literal's value
 * - BLANK: `_:` followed by the node id
 * - TRIPLE: `(s p o)` with each part rendered recursively
 *
 * @param term - The term to render.
 */
export function termToString(term: Term): string {
  const render = Match.type<Term>().pipe(
    Match.discriminatorsExhaustive("type")({
      IRI: ({ iri }) => iri,
      LITERAL: ({ value }) => value,
      BLANK: ({ id }) => `_:${id}`,
      TRIPLE: ({ triple: quoted }) => {
        const subject = termToString(quoted.s);
        const predicate = termToString(quoted.p);
        const object = termToString(quoted.o);
        return `(${subject} ${predicate} ${object})`;
      },
    }),
  );
  return render(term);
}
|
|
|
|
/**
 * Builds a term from its plain-text form.
 *
 * - Starts with `http://` or `https://` (case-sensitive): an IRI term.
 * - Starts with `_:`: a blank node whose id is the remainder.
 * - Anything else (including the empty string): a literal with that value.
 *
 * @param value - Text to interpret.
 */
export function stringToTerm(value: string): Term {
  const looksLikeHttpIri = /^https?:\/\//.test(value);
  if (looksLikeHttpIri) {
    return { type: "IRI", iri: value };
  }

  const blankPrefix = "_:";
  return value.startsWith(blankPrefix)
    ? { type: "BLANK", id: value.slice(blankPrefix.length) }
    : { type: "LITERAL", value };
}
|