mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
saving
This commit is contained in:
parent
9e9307a2aa
commit
e26caa0b12
123 changed files with 3478 additions and 10078 deletions
110
ts/packages/flow/src/gateway/dispatch/manager.ts
Normal file
110
ts/packages/flow/src/gateway/dispatch/manager.ts
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Dispatcher manager — routes requests to backend services via pub/sub.
|
||||
*
|
||||
* Python reference: trustgraph-flow/trustgraph/gateway/dispatch/manager.py
|
||||
*/
|
||||
|
||||
import { NatsBackend, RequestResponse, type PubSubBackend } from "@trustgraph/base";
|
||||
import type { GatewayConfig } from "../server.js";
|
||||
|
||||
/**
 * Callback used by the streaming dispatch methods to hand one response
 * message back to the caller.
 *
 * @param response - The response payload; its shape is not constrained here.
 * @param complete - `true` when this message is the last one for the request.
 * @returns A promise that settles once the message has been handled.
 */
export type Responder = (
  response: unknown,
  complete: boolean,
) => Promise<void>;
|
||||
|
||||
/**
 * Dispatcher manager — routes gateway requests to backend services over
 * pub/sub.
 *
 * A `RequestResponse` client is created lazily per cache key
 * (`global:<kind>` or `flow:<flow>:<kind>`) and reused for later requests
 * that use the same key.
 */
export class DispatcherManager {
  private pubsub: PubSubBackend;

  // The cache stores the *promise* of a started requestor, not the requestor
  // itself. Concurrent first requests for one key therefore share a single
  // client instead of each creating one and leaking all but the last.
  private requestors = new Map<string, Promise<RequestResponse<unknown, unknown>>>();

  /**
   * @param config - Gateway configuration. `natsUrl` selects the NATS server
   *   and falls back to `nats://localhost:4222` when unset.
   */
  constructor(private readonly config: GatewayConfig) {
    this.pubsub = new NatsBackend(config.natsUrl ?? "nats://localhost:4222");
  }

  /** Start-up hook. Currently does nothing; requestors are created on demand. */
  async start(): Promise<void> {
    // Pre-create requestors for known global services
    // Flow-specific requestors are created on demand
  }

  /**
   * Stops every cached requestor, then closes the pub/sub backend.
   *
   * The backend is closed even when stopping a requestor throws, so the
   * connection is not left open after a partial shutdown.
   */
  async stop(): Promise<void> {
    try {
      for (const pending of this.requestors.values()) {
        let rr: RequestResponse<unknown, unknown>;
        try {
          rr = await pending;
        } catch {
          // This requestor never finished starting, so there is nothing to stop.
          continue;
        }
        await rr.stop();
      }
    } finally {
      await this.pubsub.close();
    }
  }

  /** Builds the request/response topic names used for a service kind. */
  private topicsFor(kind: string): { requestTopic: string; responseTopic: string } {
    return {
      requestTopic: `tg.flow.${kind}-request`,
      responseTopic: `tg.flow.${kind}-response`,
    };
  }

  /**
   * Returns the started requestor cached under `key`, creating it on first use.
   *
   * @param requestTopic - Topic requests are published to.
   * @param responseTopic - Topic responses are read from.
   * @param key - Cache key; also used to build the subscription name.
   */
  private getRequestor(
    requestTopic: string,
    responseTopic: string,
    key: string,
  ): Promise<RequestResponse<unknown, unknown>> {
    const cached = this.requestors.get(key);
    if (cached) {
      return cached;
    }

    const created = (async () => {
      const rr: RequestResponse<unknown, unknown> = new RequestResponse({
        pubsub: this.pubsub,
        requestTopic,
        responseTopic,
        subscription: `gateway-${key}`,
      });
      await rr.start();
      return rr;
    })();

    // Register before the start completes so later callers join this attempt.
    this.requestors.set(key, created);

    // A failed start must not poison the cache: drop the entry so the next
    // request retries. Callers still observe the rejection through `created`.
    created.catch(() => {
      if (this.requestors.get(key) === created) {
        this.requestors.delete(key);
      }
    });

    return created;
  }

  /**
   * Sends `request` through `rr` and forwards every response to `responder`
   * until one is flagged `complete`, `endOfStream` or `endOfSession`.
   */
  private async stream(
    rr: RequestResponse<unknown, unknown>,
    request: Record<string, unknown>,
    responder: Responder,
  ): Promise<void> {
    await rr.request(request, {
      recipient: async (response) => {
        const res = response as Record<string, unknown>;
        const complete = !!res.complete || !!res.endOfStream || !!res.endOfSession;
        await responder(res, complete);
        return complete;
      },
    });
  }

  /**
   * Sends one request to a global service and resolves with its response.
   *
   * @param kind - Service kind; selects the `tg.flow.<kind>-request/-response` topics.
   * @param request - Request payload.
   */
  async dispatchGlobalService(
    kind: string,
    request: Record<string, unknown>,
  ): Promise<unknown> {
    const { requestTopic, responseTopic } = this.topicsFor(kind);
    const rr = await this.getRequestor(requestTopic, responseTopic, `global:${kind}`);
    return rr.request(request);
  }

  /**
   * Sends one request to a flow-scoped service and resolves with its response.
   *
   * NOTE(review): `flow` only influences the cache key / subscription name;
   * the topics are the same as for the global variant — confirm this is the
   * intended topic layout.
   *
   * @param flow - Flow identifier.
   * @param kind - Service kind.
   * @param request - Request payload.
   */
  async dispatchFlowService(
    flow: string,
    kind: string,
    request: Record<string, unknown>,
  ): Promise<unknown> {
    const { requestTopic, responseTopic } = this.topicsFor(kind);
    const rr = await this.getRequestor(requestTopic, responseTopic, `flow:${flow}:${kind}`);
    return rr.request(request);
  }

  /**
   * Streaming variant of {@link dispatchGlobalService}: each response is
   * passed to `responder` together with its completion flag.
   */
  async dispatchGlobalServiceStreaming(
    kind: string,
    request: Record<string, unknown>,
    responder: Responder,
  ): Promise<void> {
    const { requestTopic, responseTopic } = this.topicsFor(kind);
    const rr = await this.getRequestor(requestTopic, responseTopic, `global:${kind}`);
    await this.stream(rr, request, responder);
  }

  /**
   * Streaming variant of {@link dispatchFlowService}: each response is passed
   * to `responder` together with its completion flag.
   */
  async dispatchFlowServiceStreaming(
    flow: string,
    kind: string,
    request: Record<string, unknown>,
    responder: Responder,
  ): Promise<void> {
    const { requestTopic, responseTopic } = this.topicsFor(kind);
    const rr = await this.getRequestor(requestTopic, responseTopic, `flow:${flow}:${kind}`);
    await this.stream(rr, request, responder);
  }
}
|
||||
86
ts/packages/flow/src/gateway/dispatch/mux.ts
Normal file
86
ts/packages/flow/src/gateway/dispatch/mux.ts
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* WebSocket multiplexer — handles concurrent requests over a single connection.
|
||||
*
|
||||
* Python reference: trustgraph-flow/trustgraph/gateway/dispatch/mux.py
|
||||
*/
|
||||
|
||||
import { AsyncQueue } from "@trustgraph/base";
|
||||
|
||||
/** Number of in-flight requests at which `Mux.run` pauses before taking more. */
const MAX_OUTSTANDING = 15;

/** Queue length at which `Mux.receive` starts dropping incoming requests. */
const MAX_QUEUE_SIZE = 10;
|
||||
|
||||
/** One request handled by the multiplexer. */
export interface MuxRequest {
  /** Correlation id; echoed back in every response sent for this request. */
  id: string;

  /** Name of the service the request is addressed to. */
  service: string;

  /** Optional flow identifier. */
  flow?: string;

  /** Request payload, handed to the handler unchanged. */
  request: Record<string, unknown>;
}
|
||||
|
||||
/**
 * Function that services a single multiplexed request.
 *
 * @param request - The request to process.
 * @param respond - Callback for sending a response; `complete` marks the
 *   last response for the request.
 * @returns A promise that settles when the handler is finished with the
 *   request.
 */
export type MuxHandler = (
  request: MuxRequest,
  respond: (response: unknown, complete: boolean) => Promise<void>,
) => Promise<void>;
|
||||
|
||||
/**
 * WebSocket multiplexer — queues incoming requests and processes several of
 * them concurrently, tagging every outgoing message with the request id.
 */
export class Mux {
  private queue = new AsyncQueue<MuxRequest>();
  private outstanding = 0;
  private running = true;

  /** @param handler - Function invoked for every dequeued request. */
  constructor(private readonly handler: MuxHandler) {}

  /**
   * Enqueues a request for processing.
   *
   * When the queue already holds `MAX_QUEUE_SIZE` entries the request is
   * dropped with a warning; no response is sent for it.
   */
  receive(request: MuxRequest): void {
    if (this.queue.length >= MAX_QUEUE_SIZE) {
      console.warn("[Mux] Queue full, dropping request:", request.id);
      return;
    }
    this.queue.push(request);
  }

  /**
   * Processing loop; runs until `stop()` has been called.
   *
   * @param send - Writes one serialized message to the connection.
   */
  async run(send: (data: string) => void): Promise<void> {
    while (this.running) {
      // Back off while the concurrency limit is reached.
      if (this.outstanding >= MAX_OUTSTANDING) {
        await sleep(50);
        continue;
      }

      let request: MuxRequest;
      try {
        request = await this.queue.pop(1000);
      } catch {
        // Timeout on queue pop — just loop
        continue;
      }

      this.outstanding++;

      // Fire and forget — processRequest handles its own errors and never
      // rejects, so this chain cannot become an unhandled rejection.
      void this.processRequest(request, send).finally(() => {
        this.outstanding--;
      });
    }
  }

  /** Asks the processing loop to exit on its next iteration. */
  stop(): void {
    this.running = false;
  }

  /**
   * Runs the handler for one request. Handler failures are reported to the
   * client as an `internal` error message; this method itself never rejects.
   */
  private async processRequest(
    request: MuxRequest,
    send: (data: string) => void,
  ): Promise<void> {
    try {
      await this.handler(request, async (response, complete) => {
        send(JSON.stringify({ id: request.id, response, complete }));
      });
    } catch (err) {
      try {
        send(
          JSON.stringify({
            id: request.id,
            error: { type: "internal", message: String(err) },
            complete: true,
          }),
        );
      } catch (sendErr) {
        // `send` itself failed (e.g. the connection is gone). Nothing more
        // can be delivered; log instead of rejecting a promise nobody awaits.
        console.error("[Mux] Failed to deliver error response:", request.id, sendErr);
      }
    }
  }
}
|
||||
|
||||
/**
 * Returns a promise that resolves once a `setTimeout` of `ms` milliseconds
 * has fired.
 */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((done) => {
    setTimeout(done, ms);
  });
}
|
||||
3
ts/packages/flow/src/gateway/index.ts
Normal file
3
ts/packages/flow/src/gateway/index.ts
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
export { createGateway, run, type GatewayConfig } from "./server.js";
|
||||
export { DispatcherManager } from "./dispatch/manager.js";
|
||||
export { Mux, type MuxRequest, type MuxHandler } from "./dispatch/mux.js";
|
||||
136
ts/packages/flow/src/gateway/server.ts
Normal file
136
ts/packages/flow/src/gateway/server.ts
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* API Gateway — HTTP + WebSocket server.
|
||||
*
|
||||
* Replaces the Python aiohttp gateway with Fastify.
|
||||
*
|
||||
* Python reference: trustgraph-flow/trustgraph/gateway/service.py
|
||||
*/
|
||||
|
||||
import Fastify from "fastify";
|
||||
import websocketPlugin from "@fastify/websocket";
|
||||
import { DispatcherManager } from "./dispatch/manager.js";
|
||||
|
||||
/** Settings for the API gateway. */
export interface GatewayConfig {
  /** TCP port the HTTP/WebSocket server listens on. */
  port: number;

  /** Metrics port. Not read by `createGateway` itself. */
  metricsPort: number;

  /**
   * Shared secret. When set, REST requests must carry
   * `Authorization: Bearer <secret>` and WebSocket clients must pass it as
   * the `token` query parameter. When unset, no authentication is applied.
   */
  secret?: string;

  /** NATS server URL handed to the dispatcher. */
  natsUrl?: string;
}
|
||||
|
||||
/**
 * Builds the API gateway: a Fastify app with the WebSocket plugin, an
 * authentication hook, two REST dispatch routes, the WebSocket endpoint and
 * the metrics endpoint.
 *
 * @param config - Gateway settings (listen port, optional secret, NATS URL).
 * @returns `start` to begin listening on `config.port` and `stop` to close
 *   the server and then the dispatcher.
 */
export async function createGateway(config: GatewayConfig) {
  const app = Fastify({ logger: true });
  await app.register(websocketPlugin);

  const dispatcher = new DispatcherManager(config);
  await dispatcher.start();

  // Authentication middleware
  app.addHook("onRequest", async (request, reply) => {
    // `request.url` includes the query string, so match on the path only.
    // Otherwise `/api/v1/socket?token=...` would miss the exemption below and
    // be rejected for lacking a Bearer header.
    const path = request.url.split("?", 1)[0];
    if (path === "/api/v1/metrics") return;
    if (path === "/api/v1/socket") return; // Socket auth via query param

    if (config.secret) {
      const auth = request.headers.authorization;
      if (!auth || auth !== `Bearer ${config.secret}`) {
        reply.code(401).send({ error: "Unauthorized" });
        // Returning the reply tells Fastify the request has been answered.
        return reply;
      }
    }
  });

  // REST endpoint: POST /api/v1/:kind
  app.post<{ Params: { kind: string } }>("/api/v1/:kind", async (request, reply) => {
    const { kind } = request.params;
    const body = request.body as Record<string, unknown>;

    try {
      const result = await dispatcher.dispatchGlobalService(kind, body);
      return result;
    } catch (err) {
      return reply.code(500).send({ error: { type: "internal", message: String(err) } });
    }
  });

  // REST endpoint: POST /api/v1/flow/:flow/service/:kind
  app.post<{ Params: { flow: string; kind: string } }>(
    "/api/v1/flow/:flow/service/:kind",
    async (request, reply) => {
      const { flow, kind } = request.params;
      const body = request.body as Record<string, unknown>;

      try {
        const result = await dispatcher.dispatchFlowService(flow, kind, body);
        return result;
      } catch (err) {
        return reply.code(500).send({ error: { type: "internal", message: String(err) } });
      }
    },
  );

  // WebSocket endpoint: /api/v1/socket
  app.get("/api/v1/socket", { websocket: true }, (socket, request) => {
    // Auth via query param
    const url = new URL(request.url, `http://${request.headers.host}`);
    const token = url.searchParams.get("token");
    if (config.secret && token !== config.secret) {
      socket.close(4001, "Unauthorized");
      return;
    }

    socket.on("message", async (data) => {
      // Captured as soon as the frame is decoded so the error path can echo
      // it without parsing the frame again. A second parse would throw for
      // malformed input and escape this async listener as an unhandled
      // rejection.
      let id: unknown;

      try {
        const msg = JSON.parse(data.toString());
        if (msg === null || typeof msg !== "object") {
          throw new TypeError("request must be a JSON object");
        }
        id = msg.id;
        const { service, flow, request: req } = msg;

        const responder = async (response: unknown, complete: boolean) => {
          socket.send(JSON.stringify({ id, response, complete }));
        };

        if (flow) {
          await dispatcher.dispatchFlowServiceStreaming(flow, service, req, responder);
        } else {
          await dispatcher.dispatchGlobalServiceStreaming(service, req, responder);
        }
      } catch (err) {
        // `id` stays undefined (and is omitted from the JSON) when the frame
        // could not be decoded.
        socket.send(
          JSON.stringify({
            id,
            error: { type: "internal", message: String(err) },
            complete: true,
          }),
        );
      }
    });

    socket.on("close", () => {
      // Cleanup
    });
  });

  // Metrics endpoint
  app.get("/api/v1/metrics", async () => {
    const { registry } = await import("@trustgraph/base");
    return registry.metrics();
  });

  return {
    start: () => app.listen({ port: config.port, host: "0.0.0.0" }),
    stop: async () => {
      await app.close();
      await dispatcher.stop();
    },
  };
}
|
||||
|
||||
/**
 * Entry point: reads the gateway settings from the environment
 * (`GATEWAY_PORT`, default 8088; `METRICS_PORT`, default 8000;
 * `GATEWAY_SECRET`; `NATS_URL`), creates the gateway and starts it.
 */
export async function run(): Promise<void> {
  const env = process.env;

  const port = parseInt(env.GATEWAY_PORT ?? "8088", 10);
  const metricsPort = parseInt(env.METRICS_PORT ?? "8000", 10);

  const config: GatewayConfig = {
    port,
    metricsPort,
    secret: env.GATEWAY_SECRET,
    natsUrl: env.NATS_URL,
  };

  const gateway = await createGateway(config);
  await gateway.start();

  console.log(`[Gateway] Listening on port ${config.port}`);
}
|
||||
7
ts/packages/flow/src/index.ts
Normal file
7
ts/packages/flow/src/index.ts
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
// @trustgraph/flow — processing services
|
||||
|
||||
export { createGateway, type GatewayConfig } from "./gateway/index.js";
|
||||
export { OpenAIProcessor } from "./model/text-completion/openai.js";
|
||||
export { ClaudeProcessor } from "./model/text-completion/claude.js";
|
||||
export { GraphRag, type GraphRagConfig, type GraphRagClients } from "./retrieval/graph-rag.js";
|
||||
export { DocumentRag, type DocumentRagClients } from "./retrieval/document-rag.js";
|
||||
129
ts/packages/flow/src/model/text-completion/claude.ts
Normal file
129
ts/packages/flow/src/model/text-completion/claude.ts
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Anthropic Claude text completion service.
|
||||
*
|
||||
* Python reference: trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
|
||||
*/
|
||||
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import { LlmService, type ProcessorConfig, type LlmResult, type LlmChunk, TooManyRequestsError } from "@trustgraph/base";
|
||||
|
||||
/**
 * Text-completion service backed by the Anthropic Messages API.
 *
 * Rate-limit errors from the SDK are rethrown as `TooManyRequestsError`; all
 * other errors propagate unchanged.
 */
export class ClaudeProcessor extends LlmService {
  private client: Anthropic;
  private defaultModel: string;
  private defaultTemperature: number;
  private maxOutput: number;

  /**
   * @param config - Processor configuration plus optional overrides:
   *   `model` (default `claude-sonnet-4-20250514`), `temperature`
   *   (default 0.0), `maxOutput` (default 8192) and `apiKey` (falls back to
   *   the `CLAUDE_KEY` environment variable).
   * @throws Error when no API key is available from either source.
   */
  constructor(config: ProcessorConfig & {
    model?: string;
    apiKey?: string;
    temperature?: number;
    maxOutput?: number;
  }) {
    super(config);

    this.defaultModel = config.model ?? "claude-sonnet-4-20250514";
    this.defaultTemperature = config.temperature ?? 0.0;
    this.maxOutput = config.maxOutput ?? 8192;

    const apiKey = config.apiKey ?? process.env.CLAUDE_KEY;
    if (!apiKey) throw new Error("Claude API key not specified");

    this.client = new Anthropic({ apiKey });

    console.log("[Claude] LLM service initialized");
  }

  /** Maps SDK rate-limit errors to `TooManyRequestsError`; returns others as-is. */
  private translateError(err: unknown): unknown {
    return err instanceof Anthropic.RateLimitError ? new TooManyRequestsError() : err;
  }

  /**
   * Runs one non-streaming completion.
   *
   * @param system - System prompt.
   * @param prompt - User prompt.
   * @param model - Model override; defaults to the configured model.
   * @param temperature - Temperature override; defaults to the configured one.
   * @returns The generated text with token usage and the model name used.
   * @throws TooManyRequestsError when the API reports a rate limit.
   */
  async generateContent(
    system: string,
    prompt: string,
    model?: string,
    temperature?: number,
  ): Promise<LlmResult> {
    const modelName = model ?? this.defaultModel;
    const temp = temperature ?? this.defaultTemperature;

    try {
      const response = await this.client.messages.create({
        model: modelName,
        max_tokens: this.maxOutput,
        temperature: temp,
        system,
        messages: [
          { role: "user", content: prompt },
        ],
      });

      // Collect every text block rather than only inspecting the first one:
      // `content` may be empty, and a non-text block can precede the text.
      let text = "";
      for (const block of response.content) {
        if (block.type === "text") {
          text += block.text;
        }
      }

      return {
        text,
        inToken: response.usage.input_tokens,
        outToken: response.usage.output_tokens,
        model: modelName,
      };
    } catch (err) {
      throw this.translateError(err);
    }
  }

  override supportsStreaming(): boolean {
    return true;
  }

  /**
   * Runs one streaming completion.
   *
   * Yields a non-final chunk per text delta (token counts `null`), then one
   * final chunk with empty text carrying the token usage.
   *
   * @throws TooManyRequestsError when the API reports a rate limit.
   */
  async *generateContentStream(
    system: string,
    prompt: string,
    model?: string,
    temperature?: number,
  ): AsyncGenerator<LlmChunk> {
    const modelName = model ?? this.defaultModel;
    const temp = temperature ?? this.defaultTemperature;

    try {
      const stream = this.client.messages.stream({
        model: modelName,
        max_tokens: this.maxOutput,
        temperature: temp,
        system,
        messages: [
          { role: "user", content: prompt },
        ],
      });

      for await (const event of stream) {
        if (event.type === "content_block_delta" && event.delta.type === "text_delta") {
          yield {
            text: event.delta.text,
            inToken: null,
            outToken: null,
            model: modelName,
            isFinal: false,
          };
        }
      }

      const finalMessage = await stream.finalMessage();
      yield {
        text: "",
        inToken: finalMessage.usage.input_tokens,
        outToken: finalMessage.usage.output_tokens,
        model: modelName,
        isFinal: true,
      };
    } catch (err) {
      throw this.translateError(err);
    }
  }
}
|
||||
|
||||
/** Entry point: launches the Claude processor under the `text-completion` id. */
export async function run(): Promise<void> {
  const serviceId = "text-completion";
  await ClaudeProcessor.launch(serviceId);
}
|
||||
138
ts/packages/flow/src/model/text-completion/openai.ts
Normal file
138
ts/packages/flow/src/model/text-completion/openai.ts
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
/**
|
||||
* OpenAI text completion service.
|
||||
*
|
||||
* Python reference: trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
|
||||
*/
|
||||
|
||||
import OpenAI from "openai";
|
||||
import { LlmService, type ProcessorConfig, type LlmResult, type LlmChunk, TooManyRequestsError } from "@trustgraph/base";
|
||||
|
||||
/**
 * Text-completion service backed by the OpenAI Chat Completions API.
 *
 * Rate-limit errors from the SDK are rethrown as `TooManyRequestsError`; all
 * other errors propagate unchanged.
 */
export class OpenAIProcessor extends LlmService {
  private client: OpenAI;
  private defaultModel: string;
  private defaultTemperature: number;
  private maxOutput: number;

  /**
   * @param config - Processor configuration plus optional overrides:
   *   `model` (default `gpt-4o`), `temperature` (default 0.0), `maxOutput`
   *   (default 4096), `apiKey` (falls back to `OPENAI_TOKEN`) and `baseUrl`
   *   (falls back to `OPENAI_BASE_URL`).
   * @throws Error when no API key is available from either source.
   */
  constructor(config: ProcessorConfig & {
    model?: string;
    apiKey?: string;
    baseUrl?: string;
    temperature?: number;
    maxOutput?: number;
  }) {
    super(config);

    this.defaultModel = config.model ?? "gpt-4o";
    this.defaultTemperature = config.temperature ?? 0.0;
    this.maxOutput = config.maxOutput ?? 4096;

    const apiKey = config.apiKey ?? process.env.OPENAI_TOKEN;
    if (!apiKey) throw new Error("OpenAI API key not specified");

    this.client = new OpenAI({
      apiKey,
      baseURL: config.baseUrl ?? process.env.OPENAI_BASE_URL,
    });

    console.log("[OpenAI] LLM service initialized");
  }

  /** Maps SDK rate-limit errors to `TooManyRequestsError`; returns others as-is. */
  private translateError(err: unknown): unknown {
    return err instanceof OpenAI.RateLimitError ? new TooManyRequestsError() : err;
  }

  /**
   * Runs one non-streaming completion.
   *
   * @param system - System prompt.
   * @param prompt - User prompt.
   * @param model - Model override; defaults to the configured model.
   * @param temperature - Temperature override; defaults to the configured one.
   * @returns The generated text (empty when the response carries none) with
   *   token usage (0 when not reported) and the model name used.
   * @throws TooManyRequestsError when the API reports a rate limit.
   */
  async generateContent(
    system: string,
    prompt: string,
    model?: string,
    temperature?: number,
  ): Promise<LlmResult> {
    const modelName = model ?? this.defaultModel;
    const temp = temperature ?? this.defaultTemperature;

    try {
      const resp = await this.client.chat.completions.create({
        model: modelName,
        messages: [
          { role: "system", content: system },
          { role: "user", content: prompt },
        ],
        temperature: temp,
        max_completion_tokens: this.maxOutput,
      });

      return {
        // Optional chaining: an empty `choices` array yields empty text
        // instead of a TypeError.
        text: resp.choices[0]?.message?.content ?? "",
        inToken: resp.usage?.prompt_tokens ?? 0,
        outToken: resp.usage?.completion_tokens ?? 0,
        model: modelName,
      };
    } catch (err) {
      throw this.translateError(err);
    }
  }

  override supportsStreaming(): boolean {
    return true;
  }

  /**
   * Runs one streaming completion.
   *
   * Yields a non-final chunk per content delta (token counts `null`), then
   * one final chunk with empty text carrying the usage reported on the
   * stream (0 when none was reported).
   *
   * @throws TooManyRequestsError when the API reports a rate limit.
   */
  async *generateContentStream(
    system: string,
    prompt: string,
    model?: string,
    temperature?: number,
  ): AsyncGenerator<LlmChunk> {
    const modelName = model ?? this.defaultModel;
    const temp = temperature ?? this.defaultTemperature;

    try {
      const stream = await this.client.chat.completions.create({
        model: modelName,
        messages: [
          { role: "system", content: system },
          { role: "user", content: prompt },
        ],
        temperature: temp,
        max_completion_tokens: this.maxOutput,
        stream: true,
        stream_options: { include_usage: true },
      });

      let totalInputTokens = 0;
      let totalOutputTokens = 0;

      for await (const chunk of stream) {
        if (chunk.choices?.[0]?.delta?.content) {
          yield {
            text: chunk.choices[0].delta.content,
            inToken: null,
            outToken: null,
            model: modelName,
            isFinal: false,
          };
        }

        if (chunk.usage) {
          totalInputTokens = chunk.usage.prompt_tokens;
          totalOutputTokens = chunk.usage.completion_tokens;
        }
      }

      yield {
        text: "",
        inToken: totalInputTokens,
        outToken: totalOutputTokens,
        model: modelName,
        isFinal: true,
      };
    } catch (err) {
      throw this.translateError(err);
    }
  }
}
|
||||
|
||||
/** Entry point: launches the OpenAI processor under the `text-completion` id. */
export async function run(): Promise<void> {
  const serviceId = "text-completion";
  await OpenAIProcessor.launch(serviceId);
}
|
||||
66
ts/packages/flow/src/retrieval/document-rag.ts
Normal file
66
ts/packages/flow/src/retrieval/document-rag.ts
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Document RAG retrieval pipeline.
|
||||
*
|
||||
* Simpler than Graph RAG — embeds the query, finds similar document chunks,
|
||||
* and synthesizes an answer from the chunk content.
|
||||
*
|
||||
* Python reference: trustgraph-flow/trustgraph/retrieval/document_rag/
|
||||
*/
|
||||
|
||||
import type {
|
||||
RequestResponse,
|
||||
TextCompletionRequest,
|
||||
TextCompletionResponse,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
PromptRequest,
|
||||
PromptResponse,
|
||||
} from "@trustgraph/base";
|
||||
|
||||
/** Request/response clients the Document RAG pipeline depends on. */
export interface DocumentRagClients {
  /** Text-completion client used to produce the final answer. */
  llm: RequestResponse<TextCompletionRequest, TextCompletionResponse>;

  /** Embeddings client used to embed the query text. */
  embeddings: RequestResponse<EmbeddingsRequest, EmbeddingsResponse>;

  /** Document-embedding query client; request and response are untyped here. */
  docEmbeddings: RequestResponse<unknown, unknown>; // Doc embedding query

  /** Prompt client used to render the synthesis prompt. */
  prompt: RequestResponse<PromptRequest, PromptResponse>;
}
|
||||
|
||||
/**
 * Callback signature for receiving answer text piece by piece.
 *
 * @param text - One piece of text.
 * @param endOfStream - `true` on the last call.
 */
export type ChunkCallback = (
  text: string,
  endOfStream: boolean,
) => Promise<void>;
|
||||
|
||||
/**
 * Document RAG pipeline: embeds the query, looks up similar document chunks
 * and asks the LLM to synthesize an answer from their content.
 */
export class DocumentRag {
  /** @param clients - Service clients used by the pipeline steps. */
  constructor(private readonly clients: DocumentRagClients) {}

  /**
   * Answers `queryText` from the most similar document chunks.
   *
   * @param queryText - The user's question.
   * @param options - Optional settings. When `chunkCallback` is given the LLM
   *   is asked to stream and every piece of the answer is forwarded to the
   *   callback, mirroring `GraphRag`. `collection` and `streaming` are
   *   accepted but not used by this implementation.
   * @returns The complete answer text.
   */
  async query(
    queryText: string,
    options?: {
      collection?: string;
      streaming?: boolean;
      chunkCallback?: ChunkCallback;
    },
  ): Promise<string> {
    // Step 1: Embed the query
    const embResp = await this.clients.embeddings.request({ text: [queryText] });
    const vectors = (embResp as EmbeddingsResponse).vectors;

    // Step 2: Find similar document chunks
    const docResp = await this.clients.docEmbeddings.request({ vectors, limit: 10 });
    const found = docResp as
      | { chunks?: Array<{ content: string; document: string }> }
      | null
      | undefined;

    // Step 3: Build context from chunks (none when the response has no list)
    const context = (found?.chunks ?? [])
      .map((c) => c.content)
      .join("\n\n---\n\n");

    // Step 4: Synthesize answer
    const promptResp = await this.clients.prompt.request({
      name: "document-rag-synthesize",
      variables: { query: queryText, context },
    });
    const { system, prompt } = promptResp as PromptResponse;

    const chunkCallback = options?.chunkCallback;
    if (chunkCallback) {
      // Streaming response
      let fullText = "";
      await this.clients.llm.request(
        { system, prompt, streaming: true },
        {
          recipient: async (resp) => {
            const r = resp as TextCompletionResponse;
            const text = r.response ?? "";
            const done = !!r.endOfStream;
            fullText += text;
            // Forward the terminating message even when it carries no text,
            // so the callback always learns that the stream has ended.
            if (text || done) {
              await chunkCallback(text, done);
            }
            return done;
          },
        },
      );
      return fullText;
    }

    const resp = await this.clients.llm.request({ system, prompt });
    return (resp as TextCompletionResponse).response;
  }
}
|
||||
207
ts/packages/flow/src/retrieval/graph-rag.ts
Normal file
207
ts/packages/flow/src/retrieval/graph-rag.ts
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
/**
|
||||
* Graph RAG retrieval pipeline.
|
||||
*
|
||||
* This is the core RAG pipeline that:
|
||||
* 1. Extracts concepts from the query
|
||||
* 2. Embeds concepts to find matching entities
|
||||
* 3. Traverses the knowledge graph from those entities
|
||||
* 4. Scores and filters edges
|
||||
* 5. Synthesizes an answer with the selected context
|
||||
*
|
||||
* Python reference: trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
|
||||
*/
|
||||
|
||||
import type {
|
||||
RequestResponse,
|
||||
TextCompletionRequest,
|
||||
TextCompletionResponse,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
GraphEmbeddingsRequest,
|
||||
GraphEmbeddingsResponse,
|
||||
TriplesQueryRequest,
|
||||
TriplesQueryResponse,
|
||||
PromptRequest,
|
||||
PromptResponse,
|
||||
Term,
|
||||
Triple,
|
||||
} from "@trustgraph/base";
|
||||
|
||||
/** Tuning knobs for the Graph RAG pipeline; every field is optional. */
export interface GraphRagConfig {
  /** `limit` sent with the graph-embeddings entity lookup (default 50). */
  entityLimit?: number;

  /** `limit` sent with each per-entity triples query (default 30). */
  tripleLimit?: number;

  /** Maximum number of triples kept after traversal (default 1000). */
  maxSubgraphSize?: number;

  /** Default 2. Not read by the current implementation (multi-hop is a TODO). */
  maxPathLength?: number;

  /** Default 30. Not read by the current implementation. */
  edgeScoreLimit?: number;

  /** Number of edges kept for answer synthesis (default 25). */
  edgeLimit?: number;
}
|
||||
|
||||
/** Request/response clients the Graph RAG pipeline depends on. */
export interface GraphRagClients {
  /** Text-completion client (concept extraction and answer synthesis). */
  llm: RequestResponse<TextCompletionRequest, TextCompletionResponse>;

  /** Embeddings client used to embed the extracted concepts. */
  embeddings: RequestResponse<EmbeddingsRequest, EmbeddingsResponse>;

  /** Graph-embeddings client used to find entities matching the vectors. */
  graphEmbeddings: RequestResponse<GraphEmbeddingsRequest, GraphEmbeddingsResponse>;

  /** Triples query client used to fetch edges for each entity. */
  triples: RequestResponse<TriplesQueryRequest, TriplesQueryResponse>;

  /** Prompt client used to render the named prompt templates. */
  prompt: RequestResponse<PromptRequest, PromptResponse>;
}
|
||||
|
||||
/**
 * Callback that receives the synthesized answer piece by piece when
 * streaming.
 *
 * @param text - One piece of answer text.
 * @param endOfStream - `true` on the last call.
 */
export type ChunkCallback = (
  text: string,
  endOfStream: boolean,
) => Promise<void>;
|
||||
|
||||
/**
 * Graph RAG pipeline: extracts concepts from the query, embeds them, finds
 * matching entities, collects their triples and asks the LLM to synthesize
 * an answer from the selected edges.
 */
export class GraphRag {
  private config: Required<GraphRagConfig>;

  /**
   * @param clients - Service clients used by the pipeline steps.
   * @param config - Optional limits; unset fields fall back to
   *   entityLimit 50, tripleLimit 30, maxSubgraphSize 1000, maxPathLength 2,
   *   edgeScoreLimit 30 and edgeLimit 25.
   */
  constructor(
    private readonly clients: GraphRagClients,
    config: GraphRagConfig = {},
  ) {
    this.config = {
      entityLimit: config.entityLimit ?? 50,
      tripleLimit: config.tripleLimit ?? 30,
      maxSubgraphSize: config.maxSubgraphSize ?? 1000,
      maxPathLength: config.maxPathLength ?? 2,
      edgeScoreLimit: config.edgeScoreLimit ?? 30,
      edgeLimit: config.edgeLimit ?? 25,
    };
  }

  /**
   * Runs the full pipeline for one query.
   *
   * @param queryText - The user's question.
   * @param options - Optional settings. When `chunkCallback` is given the
   *   answer is streamed to it. `collection` and `streaming` are accepted but
   *   not used by this implementation.
   * @returns The complete answer text.
   */
  async query(
    queryText: string,
    options?: {
      collection?: string;
      streaming?: boolean;
      chunkCallback?: ChunkCallback;
    },
  ): Promise<string> {
    // Step 1: Extract concepts from the query via prompt + LLM
    const concepts = await this.extractConcepts(queryText);

    // Step 2: Embed concepts
    const vectors = await this.getVectors(concepts);

    // Step 3: Find matching entities via graph embeddings
    const entities = await this.getEntities(vectors);

    // Step 4: Traverse the knowledge graph from entities
    const subgraph = await this.followEdges(entities);

    // Step 5: Score and filter edges
    const scoredEdges = await this.scoreEdges(queryText, subgraph);

    // Step 6: Synthesize answer
    return this.synthesize(queryText, scoredEdges, options?.chunkCallback);
  }

  /** Asks the LLM for concepts; the reply is split into trimmed, non-empty lines. */
  private async extractConcepts(query: string): Promise<string[]> {
    const promptResp = await this.clients.prompt.request({
      name: "extract-concepts",
      variables: { query },
    });

    const llmResp = await this.clients.llm.request({
      system: (promptResp as PromptResponse).system,
      prompt: (promptResp as PromptResponse).prompt,
    });

    // Parse concepts from LLM response (newline-separated)
    return (llmResp as TextCompletionResponse).response
      .split("\n")
      .map((c) => c.trim())
      .filter(Boolean);
  }

  /** Embeds all concepts with a single embeddings request. */
  private async getVectors(concepts: string[]): Promise<number[][]> {
    const resp = await this.clients.embeddings.request({ text: concepts });
    return (resp as EmbeddingsResponse).vectors;
  }

  /** Looks up at most `entityLimit` entities matching the vectors. */
  private async getEntities(vectors: number[][]): Promise<Term[]> {
    const resp = await this.clients.graphEmbeddings.request({
      vectors,
      limit: this.config.entityLimit,
    });
    return (resp as GraphEmbeddingsResponse).entities;
  }

  /**
   * Fetches, in parallel, the triples that have each entity as subject and
   * returns at most `maxSubgraphSize` of them.
   */
  private async followEdges(entities: Term[]): Promise<Triple[]> {
    // Batch triple queries for all entities
    const results = await Promise.all(
      entities.map((entity) =>
        this.clients.triples.request({ s: entity, limit: this.config.tripleLimit }),
      ),
    );

    const allTriples: Triple[] = [];
    for (const result of results) {
      allTriples.push(...(result as TriplesQueryResponse).triples);
    }

    // TODO: Multi-hop traversal up to maxPathLength
    return allTriples.slice(0, this.config.maxSubgraphSize);
  }

  /** Placeholder scoring: keeps the first `edgeLimit` triples. */
  private async scoreEdges(query: string, triples: Triple[]): Promise<Triple[]> {
    // TODO: LLM-based edge scoring and filtering
    // For now, return top N edges
    return triples.slice(0, this.config.edgeLimit);
  }

  /**
   * Renders the synthesis prompt from the edges and asks the LLM for the
   * answer, streaming it through `chunkCallback` when one is supplied.
   */
  private async synthesize(
    query: string,
    edges: Triple[],
    chunkCallback?: ChunkCallback,
  ): Promise<string> {
    // Format edges as context
    const context = edges
      .map((t) => `${termToString(t.s)} -> ${termToString(t.p)} -> ${termToString(t.o)}`)
      .join("\n");

    const promptResp = await this.clients.prompt.request({
      name: "graph-rag-synthesize",
      variables: { query, context },
    });
    const { system, prompt } = promptResp as PromptResponse;

    if (chunkCallback) {
      // Streaming response
      let fullText = "";
      await this.clients.llm.request(
        { system, prompt, streaming: true },
        {
          recipient: async (resp) => {
            const r = resp as TextCompletionResponse;
            const text = r.response ?? "";
            const done = !!r.endOfStream;
            fullText += text;
            // The terminating message can carry no text; forward it anyway so
            // the callback is always told that the stream has ended.
            if (text || done) {
              await chunkCallback(text, done);
            }
            return done;
          },
        },
      );
      return fullText;
    }

    const resp = await this.clients.llm.request({ system, prompt });
    return (resp as TextCompletionResponse).response;
  }
}
|
||||
|
||||
/**
 * Renders a term as text: an IRI as its `iri`, a literal as its `value`, a
 * blank node as `_:<id>`, and a quoted triple as
 * `(<subject> <predicate> <object>)` with each part rendered recursively.
 */
function termToString(term: Term): string {
  switch (term.type) {
    case "IRI": {
      return term.iri;
    }
    case "LITERAL": {
      return term.value;
    }
    case "BLANK": {
      return "_:" + term.id;
    }
    case "TRIPLE": {
      const { s, p, o } = term.triple;
      const parts = [s, p, o].map((part) => termToString(part));
      return "(" + parts.join(" ") + ")";
    }
  }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue