trustgraph/ts/packages/flow/src/retrieval/graph-rag.ts
elpresidank ffd97375a8 saving
2026-05-12 08:06:58 -05:00

384 lines
12 KiB
TypeScript

/**
* Graph RAG retrieval pipeline.
*
* This is the core RAG pipeline that:
* 1. Extracts concepts from the query
* 2. Embeds the extracted concepts
* 3. Finds matching entities via graph embeddings
* 4. Traverses the knowledge graph from those entities
* 5. Scores and filters edges
* 6. Synthesizes an answer from the selected edges
*
* Python reference: trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
*/
import type {
EmbeddingsRequest,
EmbeddingsResponse,
GraphEmbeddingsRequest,
GraphEmbeddingsResponse,
FlowRequestor,
PromptRequest,
PromptResponse,
Term,
TextCompletionRequest,
TextCompletionResponse,
Triple,
TriplesQueryRequest,
TriplesQueryResponse,
} from "@trustgraph/base";
/**
* Tuning options for the Graph RAG pipeline.
*
* Every field is optional; the GraphRag constructor fills in the default
* noted on each field when it is omitted.
*/
export interface GraphRagConfig {
/** Limit passed to the graph-embeddings lookup when finding entities (default 50). */
entityLimit?: number;
/** Limit passed with each per-entity triples query during traversal (default 30). */
tripleLimit?: number;
/** Upper bound on the number of triples collected during traversal (default 1000). */
maxSubgraphSize?: number;
/** Number of traversal levels expanded from the matched entities (default 2). */
maxPathLength?: number;
/** Maximum number of edges sent to the LLM for relevance scoring (default 50). */
edgeScoreLimit?: number;
/** Number of top-scoring edges kept after LLM scoring (default 25). */
edgeLimit?: number;
}
/**
* Service requestors the Graph RAG pipeline depends on.
*/
export interface GraphRagClients {
/** Text completion; used for concept extraction, edge scoring and answer synthesis. */
llm: FlowRequestor<TextCompletionRequest, TextCompletionResponse>;
/** Turns the extracted concept strings into embedding vectors. */
embeddings: FlowRequestor<EmbeddingsRequest, EmbeddingsResponse>;
/** Looks up graph entities matching the concept vectors. */
graphEmbeddings: FlowRequestor<GraphEmbeddingsRequest, GraphEmbeddingsResponse>;
/** Triple store queries; used to fetch edges whose subject is a given entity. */
triples: FlowRequestor<TriplesQueryRequest, TriplesQueryResponse>;
/**
* Resolves a named prompt template plus variables; the response's `system`
* and `prompt` fields are forwarded to `llm`.
*/
prompt: FlowRequestor<PromptRequest, PromptResponse>;
}
/**
* Receives streamed answer text during synthesis. `text` is the chunk's text
* and `endOfStream` reflects the end-of-stream flag of the LLM response that
* carried it. The returned promise is awaited before the chunk is considered
* handled.
*/
export type ChunkCallback = (text: string, endOfStream: boolean) => Promise<void>;
/**
* Outcome of a Graph RAG query.
*/
export interface GraphRagResult {
/** The synthesized answer text. */
answer: string;
/** The triples that were handed to the LLM as context for the answer. */
subgraph: Triple[];
}
/**
 * Graph RAG query engine.
 *
 * For a natural-language query it extracts concepts, embeds them, looks up
 * matching graph entities, walks outgoing edges from those entities, narrows
 * the collected triples (LLM relevance scoring for large subgraphs) and asks
 * the LLM to synthesize an answer from the selected triples.
 */
export class GraphRag {
  /**
   * Subgraphs with at most this many triples are passed to synthesis as-is;
   * LLM edge scoring only kicks in above this size.
   */
  private static readonly DIRECT_CONTEXT_LIMIT = 500;

  private readonly clients: GraphRagClients;
  private readonly config: Required<GraphRagConfig>;

  /**
   * @param clients Service requestors used by the pipeline stages.
   * @param config Optional tuning values; omitted fields fall back to defaults.
   */
  constructor(
    clients: GraphRagClients,
    config: GraphRagConfig = {},
  ) {
    this.clients = clients;
    this.config = {
      entityLimit: config.entityLimit ?? 50,
      tripleLimit: config.tripleLimit ?? 30,
      maxSubgraphSize: config.maxSubgraphSize ?? 1000,
      maxPathLength: config.maxPathLength ?? 2,
      edgeScoreLimit: config.edgeScoreLimit ?? 50,
      edgeLimit: config.edgeLimit ?? 25,
    };
  }

  /**
   * Runs the full pipeline for one query.
   *
   * @param queryText The user's question.
   * @param options Optional settings: `collection` scopes the entity lookup
   *   and triple queries; `chunkCallback`, when given, makes synthesis stream
   *   its output through the callback. `streaming` is accepted for interface
   *   compatibility but is not consulted here.
   * @returns The answer plus the triples that were used as context.
   */
  async query(
    queryText: string,
    options?: {
      collection?: string;
      streaming?: boolean;
      chunkCallback?: ChunkCallback;
    },
  ): Promise<GraphRagResult> {
    const collection = options?.collection;
    console.log(`[GraphRag] Query: "${queryText.slice(0, 80)}..."`);

    // Step 1: Extract concepts from the query via prompt + LLM
    const concepts = await this.extractConcepts(queryText);
    console.log(`[GraphRag] Step 1: extracted ${concepts.length} concepts: ${concepts.slice(0, 5).join(", ")}`);

    // Step 2: Embed concepts
    const vectors = await this.getVectors(concepts);
    console.log(`[GraphRag] Step 2: got ${vectors.length} vectors (dim=${vectors[0]?.length ?? 0})`);

    // Step 3: Find matching entities via graph embeddings
    const entities = await this.getEntities(vectors, collection);
    console.log(`[GraphRag] Step 3: found ${entities.length} matching entities`);

    // Step 4: Traverse the knowledge graph from entities
    const subgraph = await this.followEdges(entities, collection);
    console.log(`[GraphRag] Step 4: traversed graph, ${subgraph.length} triples in subgraph`);

    // Step 5: Score and filter edges via LLM
    const scoredEdges = await this.scoreEdges(queryText, subgraph);
    console.log(`[GraphRag] Step 5: scored down to ${scoredEdges.length} edges`);

    // Step 6: Synthesize answer
    console.log(`[GraphRag] Step 6: synthesizing answer from ${scoredEdges.length} edges...`);
    const answer = await this.synthesize(queryText, scoredEdges, options?.chunkCallback);
    console.log(`[GraphRag] Step 6: done (${answer.length} chars)`);

    return { answer, subgraph: scoredEdges };
  }

  /**
   * Asks the LLM (via the "extract-concepts" prompt) for the concepts in the
   * query and returns them as trimmed, non-empty lines. When the LLM yields
   * nothing usable, the raw query is used as the single concept so retrieval
   * still has something to embed.
   */
  private async extractConcepts(query: string): Promise<string[]> {
    const promptResp = await this.clients.prompt.request({
      name: "extract-concepts",
      variables: { query },
    });
    const llmResp = await this.clients.llm.request({
      system: (promptResp as PromptResponse).system,
      prompt: (promptResp as PromptResponse).prompt,
    });

    // The LLM is expected to answer with one concept per line.
    const concepts = (llmResp as TextCompletionResponse).response
      .split("\n")
      .map((line) => line.trim())
      .filter((line) => line.length > 0);
    if (concepts.length > 0) return concepts;

    return query.trim().length > 0 ? [query] : [];
  }

  /** Embeds the concept strings; skips the service call when there are none. */
  private async getVectors(concepts: string[]): Promise<number[][]> {
    if (concepts.length === 0) return [];
    const resp = await this.clients.embeddings.request({ text: concepts });
    return (resp as EmbeddingsResponse).vectors;
  }

  /**
   * Finds graph entities matching the vectors, limited to `entityLimit`.
   * Skips the service call when there are no vectors to match.
   */
  private async getEntities(vectors: number[][], collection?: string): Promise<Term[]> {
    if (vectors.length === 0) return [];
    const resp = await this.clients.graphEmbeddings.request({
      vectors,
      user: "default",
      collection: collection ?? "default",
      limit: this.config.entityLimit,
    });
    return (resp as GraphEmbeddingsResponse).entities;
  }

  /**
   * Breadth-first expansion from the seed entities along outgoing edges
   * (entity as subject), up to `maxPathLength` levels and `maxSubgraphSize`
   * triples.
   *
   * @returns The collected triples in discovery order.
   */
  private async followEdges(entities: Term[], collection?: string): Promise<Triple[]> {
    const { maxPathLength, maxSubgraphSize, tripleLimit } = this.config;
    const subgraph: Triple[] = [];
    const visited = new Set<string>();

    // Keep the original Term for every key so that re-querying a node uses
    // its real term kind instead of one guessed back from its string form.
    const termsByKey = new Map<string, Term>();
    const remember = (term: Term): string => {
      const key = termToString(term);
      if (!termsByKey.has(key)) termsByKey.set(key, term);
      return key;
    };

    let frontier = new Set<string>(entities.map((entity) => remember(entity)));

    for (let depth = 0; depth < maxPathLength; depth++) {
      if (frontier.size === 0 || subgraph.length >= maxSubgraphSize) break;

      const pending = [...frontier].filter((key) => !visited.has(key));
      if (pending.length === 0) break;
      for (const key of pending) visited.add(key);

      // One subject query per frontier node, issued concurrently.
      const responses = await Promise.all(
        pending.map((key) => {
          const request: TriplesQueryRequest = {
            s: termsByKey.get(key) ?? stringToTerm(key),
            limit: tripleLimit,
            ...(collection !== undefined ? { collection } : {}),
          };
          return this.clients.triples.request(request);
        }),
      );

      // Objects only become the next frontier if another level will run.
      const expandFurther = depth < maxPathLength - 1;
      const next = new Set<string>();
      for (const response of responses) {
        for (const triple of (response as TriplesQueryResponse).triples) {
          subgraph.push(triple);
          if (expandFurther) {
            const objectKey = remember(triple.o);
            if (!visited.has(objectKey)) next.add(objectKey);
          }
          if (subgraph.length >= maxSubgraphSize) return subgraph;
        }
      }
      frontier = next;
    }

    return subgraph.slice(0, maxSubgraphSize);
  }

  /**
   * Narrows the subgraph to the edges most relevant to the query.
   *
   * Small subgraphs (up to DIRECT_CONTEXT_LIMIT triples) are returned
   * untouched. Otherwise the first `edgeScoreLimit` triples are sent to the
   * LLM ("kg-edge-scoring" prompt) and the `edgeLimit` best-scored ones are
   * returned. If no usable score comes back, the first `edgeLimit` triples
   * are returned instead.
   */
  private async scoreEdges(query: string, triples: Triple[]): Promise<Triple[]> {
    if (triples.length === 0) return [];

    if (triples.length <= GraphRag.DIRECT_CONTEXT_LIMIT) {
      console.log(`[GraphRag] Skipping edge scoring — ${triples.length} triples fits in context directly`);
      return triples;
    }

    const { edgeScoreLimit, edgeLimit } = this.config;

    // Only a bounded window of edges is sent, to avoid overflowing context.
    // Each edge's id is its index into `triples`.
    const toScore = triples.slice(0, edgeScoreLimit).map((t, i) => ({
      id: String(i),
      s: termToString(t.s),
      p: termToString(t.p),
      o: termToString(t.o),
    }));
    const knowledgeJson = JSON.stringify(toScore, null, 2);

    const promptResp = await this.clients.prompt.request({
      name: "kg-edge-scoring",
      variables: {
        query,
        knowledge: knowledgeJson,
      },
    });
    const llmResp = await this.clients.llm.request({
      system: (promptResp as PromptResponse).system,
      prompt: (promptResp as PromptResponse).prompt,
    });
    const responseText = (llmResp as TextCompletionResponse).response;
    console.log(`[GraphRag] Edge scoring LLM response (first 500 chars): ${responseText.slice(0, 500)}`);

    const scored = GraphRag.parseEdgeScores(responseText);
    scored.sort((a, b) => b.score - a.score);

    // Keep the best-scored ids that refer to an edge we actually sent,
    // ignoring repeats so no triple is returned twice.
    const seen = new Set<number>();
    const ranked: number[] = [];
    for (const { id } of scored) {
      const trimmed = id.trim();
      if (!/^\d+$/.test(trimmed)) continue;
      const idx = Number(trimmed);
      if (idx >= toScore.length || seen.has(idx)) continue;
      seen.add(idx);
      ranked.push(idx);
    }
    const topN = ranked.slice(0, edgeLimit);

    const result: Triple[] = [];
    for (const idx of topN) {
      const triple = triples[idx];
      if (triple !== undefined) result.push(triple);
    }
    console.log(`[GraphRag] Edge scoring: LLM returned ${scored.length} scores, keeping top ${topN.length}, mapped ${result.length} triples`);

    // If scoring failed entirely, fall back to returning the first edgeLimit triples
    if (result.length === 0) {
      return triples.slice(0, edgeLimit);
    }
    return result;
  }

  /**
   * Extracts `{ id, score }` entries from an LLM reply.
   *
   * Accepts a JSON array (optionally wrapped in a Markdown code fence or
   * surrounded by prose), a single JSON object, or one JSON object per line.
   * Ids may be strings or numbers; entries without a finite numeric score
   * are dropped. Never throws: unparseable input yields an empty list.
   */
  private static parseEdgeScores(responseText: string): Array<{ id: string; score: number }> {
    const entries: Array<{ id: string; score: number }> = [];
    const collect = (parsed: unknown): void => {
      for (const item of Array.isArray(parsed) ? parsed : [parsed]) {
        if (typeof item !== "object" || item === null) continue;
        const { id, score } = item as { id?: unknown; score?: unknown };
        if (typeof score !== "number" || !Number.isFinite(score)) continue;
        if (typeof id === "string") {
          entries.push({ id, score });
        } else if (typeof id === "number") {
          entries.push({ id: String(id), score });
        }
      }
    };

    // Drop Markdown fence markers such as ```json so the payload can parse.
    const text = responseText.replace(/```[a-zA-Z]*/g, "").trim();

    // Whole-document candidates: the text itself, then the outermost [...].
    const candidates = [text];
    const open = text.indexOf("[");
    const close = text.lastIndexOf("]");
    if (open !== -1 && close > open) {
      candidates.push(text.slice(open, close + 1));
    }
    for (const candidate of candidates) {
      let parsed: unknown;
      try {
        parsed = JSON.parse(candidate);
      } catch {
        // Not valid JSON as a whole; try the next candidate.
        continue;
      }
      collect(parsed);
      return entries;
    }

    // Last resort: one JSON object per line (a trailing comma is tolerated).
    for (const line of text.split("\n")) {
      const trimmed = line.trim().replace(/,$/, "");
      if (trimmed.length === 0) continue;
      try {
        collect(JSON.parse(trimmed));
      } catch {
        // Best-effort parsing: lines that are not JSON are skipped.
      }
    }
    return entries;
  }

  /**
   * Produces the final answer from the selected edges using the
   * "graph-rag-synthesize" prompt.
   *
   * @param query The user's question.
   * @param edges Triples rendered as `s -> p -> o` lines of context.
   * @param chunkCallback When given, the LLM is called in streaming mode and
   *   each chunk is forwarded; the end-of-stream chunk is always forwarded,
   *   even if its text is empty, so the consumer can finish.
   * @returns The complete answer text.
   */
  private async synthesize(
    query: string,
    edges: Triple[],
    chunkCallback?: ChunkCallback,
  ): Promise<string> {
    const context = edges
      .map((t) => `${termToString(t.s)} -> ${termToString(t.p)} -> ${termToString(t.o)}`)
      .join("\n");
    const promptResp = await this.clients.prompt.request({
      name: "graph-rag-synthesize",
      variables: { query, context },
    });

    if (chunkCallback !== undefined) {
      // Streaming response
      let fullText = "";
      await this.clients.llm.request(
        {
          system: (promptResp as PromptResponse).system,
          prompt: (promptResp as PromptResponse).prompt,
          streaming: true,
        },
        {
          recipient: async (resp) => {
            const chunk = resp as TextCompletionResponse;
            const text = chunk.response ?? "";
            const done = chunk.endOfStream === true;
            fullText += text;
            // Empty intermediate chunks are noise, but the final chunk must
            // reach the callback even when it carries no text.
            if (text.length > 0 || done) {
              await chunkCallback(text, done);
            }
            return done;
          },
        },
      );
      return fullText;
    }

    const resp = await this.clients.llm.request({
      system: (promptResp as PromptResponse).system,
      prompt: (promptResp as PromptResponse).prompt,
    });
    return (resp as TextCompletionResponse).response;
  }
}
/**
 * Renders a term as plain text: an IRI as its IRI, a literal as its value,
 * a blank node as `_:<id>`, and a quoted triple as `(s p o)` with each
 * component rendered recursively.
 */
function termToString(term: Term): string {
  switch (term.type) {
    case "IRI": {
      const { iri } = term;
      return iri;
    }
    case "LITERAL": {
      const { value } = term;
      return value;
    }
    case "BLANK": {
      return "_:" + `${term.id}`;
    }
    case "TRIPLE": {
      const { s, p, o } = term.triple;
      const parts = [s, p, o].map((part) => termToString(part));
      return `(${parts.join(" ")})`;
    }
  }
}
/**
 * Builds a term from its plain-text form: strings starting with `http://`
 * or `https://` become IRIs, strings starting with `_:` become blank nodes
 * (the prefix is removed from the id), and everything else is a literal.
 */
function stringToTerm(value: string): Term {
  const looksLikeIri = ["http://", "https://"].some((scheme) => value.startsWith(scheme));
  if (looksLikeIri) {
    return { type: "IRI", iri: value };
  }

  const blankPrefix = "_:";
  if (value.startsWith(blankPrefix)) {
    return { type: "BLANK", id: value.substring(blankPrefix.length) };
  }

  return { type: "LITERAL", value };
}