diff --git a/ts/packages/base/src/__tests__/nats-backend.test.ts b/ts/packages/base/src/__tests__/nats-backend.test.ts index aedb6756..a009c62a 100644 --- a/ts/packages/base/src/__tests__/nats-backend.test.ts +++ b/ts/packages/base/src/__tests__/nats-backend.test.ts @@ -1,6 +1,6 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { Effect } from "effect"; -import { makeNatsBackend } from "../backend/nats.js"; +import { makeNatsBackend, makeNatsBackendScoped } from "../backend/nats.js"; const natsMock = vi.hoisted(() => { const S = require("effect/Schema"); @@ -252,4 +252,17 @@ describe("NATS backend", () => { operation: "negative-acknowledge:tg.test.topic", }); }); + + it("drains the connection when a scoped NATS backend closes", async () => { + await Effect.runPromise( + Effect.scoped( + Effect.gen(function* () { + const backend = yield* makeNatsBackendScoped("nats://test"); + yield* backend.createProducer({ topic: "tg.test.topic" }); + }), + ), + ); + + expect(natsMock.drain).toHaveBeenCalledTimes(1); + }); }); diff --git a/ts/packages/base/src/backend/index.ts b/ts/packages/base/src/backend/index.ts index 0e664f77..b2ff7f56 100644 --- a/ts/packages/base/src/backend/index.ts +++ b/ts/packages/base/src/backend/index.ts @@ -9,7 +9,7 @@ export type { InitialPosition, } from "./types.js"; -export { makeNatsBackend } from "./nats.js"; +export { makeNatsBackend, makeNatsBackendScoped } from "./nats.js"; export { PubSub, NatsPubSubLive, diff --git a/ts/packages/base/src/backend/nats.ts b/ts/packages/base/src/backend/nats.ts index 29d4cae7..2c41ae61 100644 --- a/ts/packages/base/src/backend/nats.ts +++ b/ts/packages/base/src/backend/nats.ts @@ -384,3 +384,17 @@ export function makeNatsBackend(url = "nats://localhost:4222"): PubSubBackend { }), }; } + +export const makeNatsBackendScoped = (url = "nats://localhost:4222") => + Effect.acquireRelease( + Effect.sync(() => makeNatsBackend(url)), + (backend) => + backend.close.pipe( + Effect.catch((error) => + Effect.logError("[NatsBackend] Failed to close scoped backend", { + error: error.message, + operation: error.operation, + }) + ), + ), + ); diff --git a/ts/packages/base/src/backend/pubsub.ts b/ts/packages/base/src/backend/pubsub.ts index 40f59171..8d3d0a45 100644 --- a/ts/packages/base/src/backend/pubsub.ts +++ b/ts/packages/base/src/backend/pubsub.ts @@ -14,7 +14,7 @@ import type { CreateProducerOptions, PubSubBackend, } from "./types.js"; -import { makeNatsBackend } from "./nats.js"; +import { makeNatsBackendScoped } from "./nats.js"; import type { PubSubError } from "../errors.js"; export interface PubSubService { @@ -49,9 +49,9 @@ export function makePubSubService(backend: PubSubBackend): PubSubService { export function pubSubLayer(backend: PubSubBackend): Layer.Layer { return Layer.effect(PubSub)( - Effect.gen(function* () { - const service = makePubSubService(backend); - yield* Effect.addFinalizer(() => + Effect.acquireRelease( + Effect.sync(() => makePubSubService(backend)), + (service) => service.close.pipe( Effect.catch((error) => Effect.logError("[PubSub] Failed to close backend", { @@ -59,32 +59,28 @@ export function pubSubLayer(backend: PubSubBackend): Layer.Layer { operation: error.operation, }), ), - ), - ); - return PubSub.of(service); - }), + ) + ).pipe( + Effect.map(PubSub.of), + ), ); } export function makeNatsPubSubLayer(url = "nats://localhost:4222"): Layer.Layer { - return pubSubLayer(makeNatsBackend(url)); + return Layer.effect(PubSub)( + makeNatsBackendScoped(url).pipe( + Effect.map(makePubSubService), + Effect.map(PubSub.of), + ), + ); } export const NatsPubSubLive = Layer.effect(PubSub)( Effect.gen(function* () { const natsUrl = O.getOrUndefined(yield* Config.string("NATS_URL").pipe(Config.option)); const pulsarHost = O.getOrUndefined(yield* Config.string("PULSAR_HOST").pipe(Config.option)); - const service = makePubSubService(makeNatsBackend(natsUrl ?? pulsarHost ?? "nats://localhost:4222")); - yield* Effect.addFinalizer(() => - service.close.pipe( - Effect.catch((error) => - Effect.logError("[PubSub] Failed to close NATS backend", { - error: error.message, - operation: error.operation, - }), - ), - ), - ); + const backend = yield* makeNatsBackendScoped(natsUrl ?? pulsarHost ?? "nats://localhost:4222"); + const service = makePubSubService(backend); return PubSub.of(service); }), ); diff --git a/ts/packages/cli/src/__tests__/library.test.ts b/ts/packages/cli/src/__tests__/library.test.ts index b65f4565..043e9312 100644 --- a/ts/packages/cli/src/__tests__/library.test.ts +++ b/ts/packages/cli/src/__tests__/library.test.ts @@ -2,6 +2,25 @@ import { describe, expect, it } from "vitest"; import { guessMimeType } from "../commands/library.js"; describe("library CLI helpers", () => { + it("keeps library load -t assigned to title while token remains long-only", () => { + const result = Bun.spawnSync({ + cmd: ["bun", "src/index.ts", "library", "load", "--help"], + cwd: new URL("../../", import.meta.url).pathname, + stdout: "pipe", + stderr: "pipe", + }); + + const stdout = result.stdout.toString(); + const stderr = result.stderr.toString(); + + expect(result.exitCode).toBe(0); + expect(stderr).toBe(""); + expect(stdout).toContain("--title"); + expect(stdout).toContain("-t"); + expect(stdout).toContain("--token"); + expect(stdout).not.toContain("-t, --token"); + }); + it("detects known MIME types through the Match-backed extension mapper", () => { expect(guessMimeType("paper.pdf")).toBe("application/pdf"); expect(guessMimeType("notes.TXT")).toBe("text/plain"); diff --git a/ts/packages/cli/src/commands/util.ts b/ts/packages/cli/src/commands/util.ts index 18d28904..97c1b921 100644 --- a/ts/packages/cli/src/commands/util.ts +++ b/ts/packages/cli/src/commands/util.ts @@ -30,7 +30,6 @@ export const rootCommand = Command.make("tg").pipe( Flag.withDefault("cli"), ), token: Flag.string("token").pipe( - Flag.withAlias("t"), Flag.withDescription("Authentication token"), Flag.optional, ), diff --git a/ts/packages/flow/src/__tests__/gateway-routes.test.ts b/ts/packages/flow/src/__tests__/gateway-routes.test.ts new file mode 100644 index 00000000..e2180785 --- /dev/null +++ b/ts/packages/flow/src/__tests__/gateway-routes.test.ts @@ -0,0 +1,214 @@ +import { describe, expect, it } from "@effect/vitest"; +import { Effect, Exit, Scope } from "effect"; +import { HttpRouter, HttpServerResponse } from "effect/unstable/http"; +import type { DispatcherManager } from "../gateway/dispatch/manager.js"; +import type { GatewayRpcServer } from "../gateway/rpc-server.js"; +import { makeGatewayRoutes, type GatewayConfig } from "../gateway/server.js"; + +interface DispatchCall { + readonly scope: "global" | "flow"; + readonly flow?: string; + readonly kind: string; + readonly request: Record; +} + +interface PublishCall { + readonly topic: string; + readonly message: unknown; + readonly id?: string; +} + +interface RouteCalls { + readonly dispatches: Array; + readonly publishes: Array; + rpcCalls: number; +} + +const makeFakeDispatcher = ( + calls: RouteCalls, + response: unknown = { ok: true }, +): DispatcherManager => + ({ + start: Effect.void, + stop: Effect.void, + dispatchGlobalService: (kind, request) => + Effect.sync(() => { + calls.dispatches.push({ scope: "global", kind, request }); + return response; + }), + dispatchGlobalServiceStreaming: () => Effect.void, + dispatchFlowService: (flow, kind, request) => + Effect.sync(() => { + calls.dispatches.push({ scope: "flow", flow, kind, request }); + return response; + }), + dispatchFlowServiceStreaming: () => Effect.void, + publishToTopic: (topic, message, id) => + Effect.sync(() => { + calls.publishes.push(id === undefined ? { topic, message } : { topic, message, id }); + }), + }) as DispatcherManager; + +const makeRequest = ( + path: string, + body: unknown, + headers: Record = {}, +) => + new Request(`http://localhost${path}`, { + method: "POST", + headers: { + "content-type": "application/json", + ...headers, + }, + body: typeof body === "string" ? body : JSON.stringify(body), + }); + +const makeRouteHarness = ( + config: GatewayConfig = { port: 0, metricsPort: 0, secret: "secret" }, + dispatchResponse?: unknown, +) => + Effect.gen(function* () { + const rpcScope = yield* Scope.make(); + yield* Effect.addFinalizer(() => Scope.close(rpcScope, Exit.void)); + + const calls: RouteCalls = { + dispatches: [], + publishes: [], + rpcCalls: 0, + }; + const dispatcher = makeFakeDispatcher(calls, dispatchResponse); + const rpcServer: GatewayRpcServer = { + httpEffect: Effect.sync(() => { + calls.rpcCalls += 1; + return HttpServerResponse.text("rpc ok"); + }), + }; + const { handler, dispose } = HttpRouter.toWebHandler( + makeGatewayRoutes(config, dispatcher, rpcServer, rpcScope), + { disableLogger: true }, + ); + yield* Effect.addFinalizer(() => Effect.promise(() => dispose())); + + return { calls, handler }; + }); + +describe("gateway HTTP routes", () => { + it.effect( + "rejects protected dispatch routes without bearer auth before dispatching", + Effect.fnUntraced(function* () { + const { calls, handler } = yield* makeRouteHarness(); + + const response = yield* Effect.promise(() => + handler(makeRequest("/api/v1/config", { operation: "get" })) + ); + + expect(response.status).toBe(401); + expect(calls.dispatches).toEqual([]); + }), + ); + + it.effect( + "returns bad request for malformed and non-object JSON bodies", + Effect.fnUntraced(function* () { + const { calls, handler } = yield* makeRouteHarness(); + const headers = { authorization: "Bearer secret" }; + + const malformed = yield* Effect.promise(() => + handler(makeRequest("/api/v1/config", "{", headers)) + ); + const malformedBody = yield* Effect.promise(() => malformed.json()) as { error?: { message?: string } }; + + const arrayBody = yield* Effect.promise(() => + handler(makeRequest("/api/v1/config", [], headers)) + ); + const arrayBodyJson = yield* Effect.promise(() => arrayBody.json()) as { error?: { message?: string } }; + + expect(malformed.status).toBe(400); + expect(malformedBody.error?.message).toBe("request body must be valid JSON"); + expect(arrayBody.status).toBe(400); + expect(arrayBodyJson.error?.message).toBe("request body must be a JSON object"); + expect(calls.dispatches).toEqual([]); + }), + ); + + it.effect( + "dispatches global and flow POST routes with parsed request objects", + Effect.fnUntraced(function* () { + const { calls, handler } = yield* makeRouteHarness(undefined, { answer: 42 }); + const headers = { authorization: "Bearer secret" }; + + const globalResponse = yield* Effect.promise(() => + handler(makeRequest("/api/v1/config", { operation: "get", key: "demo" }, headers)) + ); + const flowResponse = yield* Effect.promise(() => + handler(makeRequest("/api/v1/flow/demo/service/text-completion", { prompt: "hi" }, headers)) + ); + + expect(globalResponse.status).toBe(200); + expect(flowResponse.status).toBe(200); + expect(calls.dispatches).toEqual([ + { + scope: "global", + kind: "config", + request: { operation: "get", key: "demo" }, + }, + { + scope: "flow", + flow: "demo", + kind: "text-completion", + request: { prompt: "hi" }, + }, + ]); + }), + ); + + it.effect( + "publishes flow load requests with generated metadata", + Effect.fnUntraced(function* () { + const { calls, handler } = yield* makeRouteHarness(); + + const response = yield* Effect.promise(() => + handler( + makeRequest( + "/api/v1/flow/demo/load", + { documentId: "doc-1", user: "alice", collection: "docs" }, + { authorization: "Bearer secret" }, + ), + ) + ); + const body = yield* Effect.promise(() => response.json()) as { status?: string; documentId?: string; flow?: string }; + + expect(response.status).toBe(200); + expect(body).toEqual({ status: "processing", documentId: "doc-1", flow: "demo" }); + expect(calls.publishes).toHaveLength(1); + expect(calls.publishes[0]?.topic).toBe("tg.flow.document"); + expect(calls.publishes[0]?.message).toMatchObject({ + documentId: "doc-1", + metadata: { + root: "doc-1", + user: "alice", + collection: "docs", + }, + }); + }), + ); + + it.effect( + "checks the RPC route token before delegating to the native websocket effect", + Effect.fnUntraced(function* () { + const { calls, handler } = yield* makeRouteHarness(); + + const rejected = yield* Effect.promise(() => + handler(new Request("http://localhost/api/v1/rpc?token=wrong")) + ); + const accepted = yield* Effect.promise(() => + handler(new Request("http://localhost/api/v1/rpc?token=secret")) + ); + + expect(rejected.status).toBe(401); + expect(accepted.status).toBe(200); + expect(yield* Effect.promise(() => accepted.text())).toBe("rpc ok"); + expect(calls.rpcCalls).toBe(1); + }), + ); +}); diff --git a/ts/packages/flow/src/__tests__/gateway-rpc-protocol.test.ts b/ts/packages/flow/src/__tests__/gateway-rpc-protocol.test.ts deleted file mode 100644 index f7525b70..00000000 --- a/ts/packages/flow/src/__tests__/gateway-rpc-protocol.test.ts +++ /dev/null @@ -1,167 +0,0 @@ -import { Effect, Queue } from "effect"; -import * as O from "effect/Option"; -import * as RpcMessage from "effect/unstable/rpc/RpcMessage"; -import * as RpcSerialization from "effect/unstable/rpc/RpcSerialization"; -import * as Socket from "effect/unstable/socket/Socket"; -import { describe, expect, it } from "vitest"; -import { makeSocketRpcProtocol } from "../gateway/rpc-protocol.js"; - -interface ReceivedMessage { - readonly clientId: number; - readonly message: RpcMessage.FromClientEncoded; -} - -interface ProtocolRunResult { - readonly messages: ReadonlyArray; - readonly writes: ReadonlyArray; - readonly clientIds: ReadonlyArray; -} - -const jsonFrame = (value: unknown): string => `${JSON.stringify(value)}\n`; - -const optionToArray = (value: O.Option): Array => - O.match(value, { - onNone: () => [], - onSome: (item) => [item], - }); - -const runProtocolFrames = ( - frames: ReadonlyArray, - headers?: ReadonlyArray<[string, string]>, - sendResponse?: RpcMessage.FromServerEncoded, -): Promise => - Effect.runPromise( - Effect.scoped( - Effect.gen(function* () { - const received = yield* Queue.unbounded(); - const writes: Array = []; - const { onSocket, protocol } = yield* makeSocketRpcProtocol; - - yield* protocol.run((clientId, message) => - Queue.offer(received, { clientId, message }).pipe(Effect.asVoid) - ).pipe(Effect.forkScoped); - yield* Effect.yieldNow; - - const socket = Socket.make({ - writer: Effect.succeed((chunk) => - Effect.sync(() => { - writes.push(chunk); - }) - ), - runRaw: (handler) => - Effect.forEach(frames, (frame) => - Effect.suspend(() => { - const result = handler(frame); - return result === undefined ? Effect.void : result; - }), { discard: true }), - }); - - yield* onSocket(socket, headers); - yield* Effect.yieldNow; - const clientIds = yield* protocol.clientIds; - if (sendResponse !== undefined) { - yield* protocol.send(0, sendResponse); - } - - const first = yield* Queue.poll(received); - const second = yield* Queue.poll(received); - const third = yield* Queue.poll(received); - - return { - messages: [ - ...optionToArray(first), - ...optionToArray(second), - ...optionToArray(third), - ], - writes, - clientIds: Array.from(clientIds), - }; - }).pipe( - Effect.provideService(RpcSerialization.RpcSerialization, RpcSerialization.ndjson), - ), - ), - ); - -describe("gateway RPC socket protocol", () => { - it("validates client request frames and prepends websocket headers", async () => { - const result = await runProtocolFrames([ - jsonFrame({ - _tag: "Request", - id: "1", - tag: "Dispatch", - payload: { - scope: "global", - service: "config", - request: {}, - }, - headers: [["rpc", "client"]], - }), - ], [["socket", "header"]]); - - expect(result.writes).toEqual([]); - expect(result.messages).toHaveLength(1); - - const received = result.messages[0]; - expect(received).toBeDefined(); - if (received === undefined) return; - - expect(received.clientId).toBe(0); - expect(received.message._tag).toBe("Request"); - if (received.message._tag !== "Request") return; - - expect(received.message.id).toBe("1"); - expect(received.message.tag).toBe("Dispatch"); - expect(received.message.headers).toEqual([ - ["socket", "header"], - ["rpc", "client"], - ]); - }); - - it("validates client control frames without mutating them", async () => { - const result = await runProtocolFrames([ - jsonFrame({ - _tag: "Ping", - }), - jsonFrame({ - _tag: "Ack", - requestId: "1", - }), - ]); - - expect(result.writes).toEqual([]); - expect(result.messages.map(({ message }) => message._tag)).toEqual(["Ping", "Ack"]); - expect(result.messages[1]?.message).toEqual({ - _tag: "Ack", - requestId: "1", - }); - }); - - it("rejects server response envelopes received from the socket", async () => { - const result = await runProtocolFrames([ - jsonFrame({ - _tag: "Exit", - requestId: "1", - exit: { - _tag: "Success", - value: { ok: true }, - }, - }), - ]); - - expect(result.messages).toEqual([]); - expect(result.writes).toHaveLength(1); - expect(String(result.writes[0])).toContain("\"_tag\":\"Defect\""); - }); - - it("sends server responses through the registered client", async () => { - const result = await runProtocolFrames( - [], - undefined, - RpcMessage.ResponseDefectEncoded("server-boom"), - ); - - expect(result.clientIds).toEqual([0]); - expect(result.writes).toHaveLength(1); - expect(String(result.writes[0])).toContain("server-boom"); - }); -}); diff --git a/ts/packages/flow/src/gateway/rpc-protocol.ts b/ts/packages/flow/src/gateway/rpc-protocol.ts deleted file mode 100644 index f2a70c45..00000000 --- a/ts/packages/flow/src/gateway/rpc-protocol.ts +++ /dev/null @@ -1,147 +0,0 @@ -import { Effect, Queue, Scope } from "effect"; -import * as MutableHashMap from "effect/MutableHashMap"; -import * as MutableHashSet from "effect/MutableHashSet"; -import * as O from "effect/Option"; -import * as S from "effect/Schema"; -import * as RpcMessage from "effect/unstable/rpc/RpcMessage"; -import * as RpcSerialization from "effect/unstable/rpc/RpcSerialization"; -import * as RpcServer from "effect/unstable/rpc/RpcServer"; -import * as Socket from "effect/unstable/socket/Socket"; - -export class RpcProtocolDecodeError extends S.TaggedErrorClass()( - "RpcProtocolDecodeError", - { - message: S.String, - cause: S.Unknown, - }, -) {} - -const HeaderSchema = S.mutable(S.Tuple([S.String, S.String])); -const FromClientEncodedSchema = S.Union([ - S.TaggedStruct("Request", { - id: S.String, - tag: S.String, - payload: S.Unknown, - headers: S.Array(HeaderSchema), - traceId: S.optionalKey(S.String), - spanId: S.optionalKey(S.String), - sampled: S.optionalKey(S.Boolean), - }), - S.TaggedStruct("Ack", { - requestId: S.String, - }), - S.TaggedStruct("Interrupt", { - requestId: S.String, - }), - S.TaggedStruct("Ping", {}), - S.TaggedStruct("Eof", {}), -]).pipe(S.toTaggedUnion("_tag")); -const FromClientEncodedMessagesSchema = S.Array(FromClientEncodedSchema); -const decodeFromClientMessages = S.decodeUnknownEffect(FromClientEncodedMessagesSchema); - -export const makeSocketRpcProtocol = Effect.gen(function* () { - const serialization = yield* RpcSerialization.RpcSerialization; - const disconnects = yield* Queue.make(); - - let nextClientId = 0; - const clients = MutableHashMap.empty Effect.Effect; - }>(); - const clientIds = MutableHashSet.empty(); - - let writeRequest!: ( - clientId: number, - message: RpcMessage.FromClientEncoded, - ) => Effect.Effect; - - const onSocket = function* ( - socket: Socket.Socket, - headers?: ReadonlyArray<[string, string]>, - ) { - const scope = yield* Effect.scope; - const parser = serialization.makeUnsafe(); - const clientId = nextClientId++; - - yield* Scope.addFinalizerExit(scope, () => { - MutableHashMap.remove(clients, clientId); - MutableHashSet.remove(clientIds, clientId); - return Queue.offer(disconnects, clientId); - }); - - const writeRaw = yield* socket.writer; - const writeDefect = (cause: unknown) => - Effect.sync(() => parser.encode(RpcMessage.ResponseDefectEncoded(cause))).pipe( - Effect.flatMap((encoded) => encoded === undefined ? Effect.void : writeRaw(encoded)), - ); - const write = (response: RpcMessage.FromServerEncoded) => - Effect.sync(() => parser.encode(response)).pipe( - Effect.flatMap((encoded) => - encoded === undefined ? Effect.void : Effect.orDie(writeRaw(encoded)), - ), - Effect.catchDefect((cause: unknown) => - writeDefect(cause).pipe(Effect.orDie), - ), - ); - - MutableHashMap.set(clients, clientId, { write }); - MutableHashSet.add(clientIds, clientId); - - yield* socket.runRaw((data) => - Effect.try({ - try: () => parser.decode(data), - catch: (cause) => - RpcProtocolDecodeError.make({ - message: "Failed to decode RPC socket frame", - cause, - }), - }).pipe( - Effect.flatMap((raw) => - decodeFromClientMessages(raw).pipe( - Effect.mapError((cause) => - RpcProtocolDecodeError.make({ - message: "RPC socket frame did not contain valid client messages", - cause, - }) - ), - ) - ), - Effect.flatMap((decoded) => - Effect.forEach(decoded, (message) => { - if (message._tag === "Request" && headers !== undefined) { - return writeRequest(clientId, { - ...message, - headers: headers.concat(message.headers), - }); - } - return writeRequest(clientId, message); - }, { discard: true }), - ), - Effect.catch((cause) => writeDefect(cause)), - Effect.catchDefect((cause: unknown) => writeDefect(cause)), - ) - ).pipe( - Effect.catchReason("SocketError", "SocketCloseError", () => Effect.void), - Effect.orDie, - ); - }; - - const protocol = yield* RpcServer.Protocol.make((writeRequest_) => { - writeRequest = writeRequest_; - return Effect.succeed({ - disconnects, - send: (clientId, response) => - O.match(MutableHashMap.get(clients, clientId), { - onNone: () => Effect.void, - onSome: (client) => Effect.orDie(client.write(response)), - }), - end: () => Effect.void, - clientIds: Effect.sync(() => new Set(clientIds)), - initialMessage: Effect.succeedNone, - supportsAck: true, - supportsTransferables: false, - supportsSpanPropagation: true, - }); - }); - - return { onSocket, protocol } as const; -}); diff --git a/ts/packages/flow/src/gateway/rpc-server.ts b/ts/packages/flow/src/gateway/rpc-server.ts index 38dfeb09..154a5fd7 100644 --- a/ts/packages/flow/src/gateway/rpc-server.ts +++ b/ts/packages/flow/src/gateway/rpc-server.ts @@ -1,23 +1,23 @@ import { Cause, Effect, Layer, Queue, Scope } from "effect"; +import { HttpServerRequest, HttpServerResponse } from "effect/unstable/http"; import * as RpcSerialization from "effect/unstable/rpc/RpcSerialization"; import * as RpcServer from "effect/unstable/rpc/RpcServer"; -import type * as Socket from "effect/unstable/socket/Socket"; import { errorMessage } from "@trustgraph/base"; import type { DispatcherManager, DispatcherStreamError } from "./dispatch/manager.js"; import { DispatchError, DispatchPayload, DispatchStreamChunk, TrustGraphRpcs } from "./rpc-contract.js"; -import { makeSocketRpcProtocol } from "./rpc-protocol.js"; export interface GatewayRpcServer { - readonly onSocket: ( - socket: Socket.Socket, - headers?: ReadonlyArray<[string, string]>, - ) => Effect.Effect; + readonly httpEffect: Effect.Effect< + HttpServerResponse.HttpServerResponse, + never, + Scope.Scope | HttpServerRequest.HttpServerRequest + >; } export const makeGatewayRpcServer = Effect.fn("makeGatewayRpcServer")(function* ( dispatcher: DispatcherManager, ) { - const { onSocket, protocol } = yield* makeSocketRpcProtocol; + const { httpEffect, protocol } = yield* RpcServer.makeProtocolWithHttpEffectWebsocket; const serverLayer = RpcServer.layer(TrustGraphRpcs, { disableFatalDefects: true, @@ -30,9 +30,7 @@ export const makeGatewayRpcServer = Effect.fn("makeGatewayRpcServer")(function* yield* Layer.launch(serverLayer).pipe(Effect.forkScoped); return { - onSocket: Effect.fn("GatewayRpc.onSocket")(function* (socket, headers) { - yield* onSocket(socket, headers); - }), + httpEffect, } satisfies GatewayRpcServer; }); diff --git a/ts/packages/flow/src/gateway/server.ts b/ts/packages/flow/src/gateway/server.ts index 7714270d..a382b2c3 100644 --- a/ts/packages/flow/src/gateway/server.ts +++ b/ts/packages/flow/src/gateway/server.ts @@ -54,10 +54,25 @@ const dispatchResult = (result: unknown) => { const readJsonRecord = Effect.gen(function* () { const request = yield* HttpServerRequest.HttpServerRequest; - const body = yield* request.json; - return isRecord(body) ? body : {}; + const body = yield* request.json.pipe( + Effect.mapError(() => "request body must be valid JSON"), + ); + if (!isRecord(body)) { + return yield* Effect.fail("request body must be a JSON object"); + } + return body; }); +const withJsonRecord = ( + use: (body: Record) => Effect.Effect, +): Effect.Effect => + readJsonRecord.pipe( + Effect.matchEffect({ + onFailure: (message) => Effect.succeed(badRequest(message)), + onSuccess: use, + }), + ); + const bearerAuthResponse = (config: GatewayConfig) => Effect.gen(function* () { if (config.secret === undefined || config.secret.length === 0) return null; @@ -98,26 +113,25 @@ const workbenchDispatch = ( ) => withBearerAuth( config, - Effect.gen(function* () { - const body = yield* readJsonRecord.pipe( - Effect.catch(() => Effect.succeed>({})), - ); - const service = typeof body.service === "string" ? body.service : undefined; - const payload = isRecord(body.request) ? body.request : undefined; - if (service === undefined || service.length === 0 || payload === undefined) { - return badRequest("service and request are required"); - } + withJsonRecord((body) => + Effect.gen(function* () { + const service = typeof body.service === "string" ? body.service : undefined; + const payload = isRecord(body.request) ? body.request : undefined; + if (service === undefined || service.length === 0 || payload === undefined) { + return badRequest("service and request are required"); + } - const dispatch = body.scope === "flow" - ? dispatcher.dispatchFlowService( - typeof body.flow === "string" ? body.flow : "default", - service, - payload, - ) - : dispatcher.dispatchGlobalService(service, payload); + const dispatch = body.scope === "flow" + ? dispatcher.dispatchFlowService( + typeof body.flow === "string" ? body.flow : "default", + service, + payload, + ) + : dispatcher.dispatchGlobalService(service, payload); - return yield* withDispatchError(dispatch, "workbench-dispatch"); - }), + return yield* withDispatchError(dispatch, "workbench-dispatch"); + }) + ), ); const globalDispatch = ( @@ -128,12 +142,11 @@ const globalDispatch = ( config, Effect.gen(function* () { const params = yield* HttpRouter.params; - const body = yield* readJsonRecord.pipe( - Effect.catch(() => Effect.succeed>({})), - ); - return yield* withDispatchError( - dispatcher.dispatchGlobalService(params.kind ?? "", body), - "global-dispatch", + return yield* withJsonRecord((body) => + withDispatchError( + dispatcher.dispatchGlobalService(params.kind ?? "", body), + "global-dispatch", + ) ); }), ); @@ -146,12 +159,11 @@ const flowDispatch = ( config, Effect.gen(function* () { const params = yield* HttpRouter.params; - const body = yield* readJsonRecord.pipe( - Effect.catch(() => Effect.succeed>({})), - ); - return yield* withDispatchError( - dispatcher.dispatchFlowService(params.flow ?? "default", params.kind ?? "", body), - "flow-dispatch", + return yield* withJsonRecord((body) => + withDispatchError( + dispatcher.dispatchFlowService(params.flow ?? "default", params.kind ?? "", body), + "flow-dispatch", + ) ); }), ); @@ -164,33 +176,34 @@ const flowLoad = ( config, Effect.gen(function* () { const params = yield* HttpRouter.params; - const body = yield* readJsonRecord.pipe( - Effect.catch(() => Effect.succeed>({})), + return yield* withJsonRecord((body) => + Effect.gen(function* () { + const documentId = typeof body.documentId === "string" ? body.documentId : undefined; + if (documentId === undefined || documentId.length === 0) { + return badRequest("documentId is required"); + } + + const user = typeof body.user === "string" ? body.user : "default"; + const collection = typeof body.collection === "string" ? body.collection : "default"; + const timestamp = yield* Clock.currentTimeMillis; + const suffix = yield* Random.nextIntBetween(0, 36 ** 6, { halfOpen: true }); + const metadata = { + id: `load-${timestamp}-${suffix.toString(36).padStart(6, "0")}`, + root: documentId, + user, + collection, + }; + + yield* dispatcher.publishToTopic("tg.flow.document", { metadata, documentId }).pipe( + Effect.mapError((cause) => messagingLifecycleError("gateway", "publish-load", cause)), + ); + + return json({ status: "processing", documentId, flow: params.flow ?? "default" }); + }).pipe( + Effect.catch((error) => Effect.succeed(dispatchError(error))), + ) ); - const documentId = typeof body.documentId === "string" ? body.documentId : undefined; - if (documentId === undefined || documentId.length === 0) { - return badRequest("documentId is required"); - } - - const user = typeof body.user === "string" ? body.user : "default"; - const collection = typeof body.collection === "string" ? body.collection : "default"; - const timestamp = yield* Clock.currentTimeMillis; - const suffix = yield* Random.nextIntBetween(0, 36 ** 6, { halfOpen: true }); - const metadata = { - id: `load-${timestamp}-${suffix.toString(36).padStart(6, "0")}`, - root: documentId, - user, - collection, - }; - - yield* dispatcher.publishToTopic("tg.flow.document", { metadata, documentId }).pipe( - Effect.mapError((cause) => messagingLifecycleError("gateway", "publish-load", cause)), - ); - - return json({ status: "processing", documentId, flow: params.flow ?? "default" }); - }).pipe( - Effect.catch((error) => Effect.succeed(dispatchError(error))), - ), + }), ); const rpcRoute = ( @@ -206,14 +219,10 @@ const rpcRoute = ( return json({ error: "Unauthorized" }, 401); } - const socket = yield* request.upgrade; - yield* rpcServer.onSocket(socket, headersFrom(request.headers)).pipe( + return yield* rpcServer.httpEffect.pipe( Scope.provide(rpcScope), ); - return HttpServerResponse.empty(); - }).pipe( - Effect.catch((error) => Effect.succeed(dispatchError(error))), - ); + }); const metricsRoute = formatPrometheusMetrics.pipe( @@ -224,7 +233,7 @@ const metricsRoute = ), ); -const gatewayRoutes = ( +export const makeGatewayRoutes = ( config: GatewayConfig, dispatcher: DispatcherManager, rpcServer: GatewayRpcServer, @@ -265,7 +274,7 @@ export function createGateway(config: GatewayConfig) { ); const serverLayer = HttpRouter.serve( - gatewayRoutes(config, dispatcher, rpcServer, rpcScope), + makeGatewayRoutes(config, dispatcher, rpcServer, rpcScope), ).pipe( Layer.provideMerge(NodeHttpServer.layer(createServer, { port: config.port, @@ -279,15 +288,6 @@ export function createGateway(config: GatewayConfig) { ); } -function headersFrom(headers: Record): ReadonlyArray<[string, string]> { - return Object.entries(headers).flatMap(([key, value]) => { - if (typeof value === "string") return [[key, value] satisfies [string, string]]; - if (typeof value === "number") return [[key, String(value)] satisfies [string, string]]; - if (Array.isArray(value)) return value.map((item) => [key, item] satisfies [string, string]); - return []; - }); -} - export function runMain(): void { NodeRuntime.runMain(program); } diff --git a/ts/packages/mcp/src/__tests__/server-effect.test.ts b/ts/packages/mcp/src/__tests__/server-effect.test.ts index eff27f07..bff415c4 100644 --- a/ts/packages/mcp/src/__tests__/server-effect.test.ts +++ b/ts/packages/mcp/src/__tests__/server-effect.test.ts @@ -1,8 +1,8 @@ import { describe, expect, it } from "@effect/vitest"; import type { BaseApi } from "@trustgraph/client"; -import { Effect, Layer, Stream } from "effect"; +import { Effect, Layer } from "effect"; import * as S from "effect/Schema"; -import { LanguageModel, McpServer } from "effect/unstable/ai"; +import { McpServer } from "effect/unstable/ai"; import * as McpSchema from "effect/unstable/ai/McpSchema"; import { FetchHttpClient, HttpRouter } from "effect/unstable/http"; import { RpcSerialization } from "effect/unstable/rpc"; @@ -52,7 +52,7 @@ interface FakeSocketCalls { } interface NativeTestClientOptions { - readonly languageText?: string | undefined; + readonly textCompletion?: (() => Promise) | undefined; readonly graphRag?: (() => Promise) | undefined; } @@ -60,6 +60,7 @@ const decodeJsonText = S.decodeUnknownSync(S.UnknownFromJsonString); const makeFakeSocket = ( options: { + readonly textCompletion?: (() => Promise) | undefined; readonly graphRag?: (() => Promise) | undefined; } = {}, ) => { @@ -73,7 +74,9 @@ const makeFakeSocket = ( flow: (flowId: string) => { calls.flowIds.push(flowId); return { - textCompletion: () => Promise.resolve("legacy text completion should not be used"), + textCompletion: () => options.textCompletion === undefined + ? Promise.resolve("gateway text completion") + : options.textCompletion(), graphRag: (query: string, ragOptions: unknown, collection?: string) => { calls.graphRag.push({ query, options: ragOptions, collection }); return options.graphRag === undefined @@ -121,15 +124,6 @@ const makeFakeSocket = ( return { socket, calls }; }; -const makeLanguageModelLayer = (text: string) => - Layer.effect( - LanguageModel.LanguageModel, - LanguageModel.make({ - generateText: () => Effect.succeed([{ type: "text", text }]), - streamText: () => Stream.empty, - }), - ); - const testConfig = TrustGraphMcpConfig.of({ gatewayUrl: "ws://localhost:8088/api/v1/rpc", user: "mcp-test", @@ -138,8 +132,6 @@ const testConfig = TrustGraphMcpConfig.of({ name: "trustgraph", version: "0.1.0-test", mcpPath: "/mcp", - openAiModel: "test-model", - openAiApiKey: undefined, port: 3000, }); @@ -151,10 +143,12 @@ const makeNativeTestClient = ( const makeNativeTestClientEffect = Effect.fn("makeNativeTestClient")(function*( options: NativeTestClientOptions, ) { - const { socket, calls } = makeFakeSocket({ graphRag: options.graphRag }); + const { socket, calls } = makeFakeSocket({ + textCompletion: options.textCompletion, + graphRag: options.graphRag, + }); const serverLayer = McpServer.toolkit(TrustGraphMcpToolkit).pipe( Layer.provide(TrustGraphMcpToolkitLive), - Layer.provide(makeLanguageModelLayer(options.languageText ?? "direct ai answer")), Layer.provide(Layer.succeed(TrustGraphSocket, TrustGraphSocket.of(socket))), Layer.provide(Layer.succeed(TrustGraphMcpConfig, testConfig)), Layer.provide(McpServer.layerHttp({ @@ -225,7 +219,6 @@ describe("Effect MCP server", () => { gatewayUrl: "ws://localhost:8088/api/v1/rpc", user: "mcp-test", flowId: "default", - openAiApiKey: "test-key", }), ).toBeDefined(); @@ -252,11 +245,11 @@ describe("Effect MCP server", () => { ); it.effect( - "calls text_completion through the direct Effect language model", + "calls text_completion through the configured TrustGraph flow", Effect.fnUntraced(function*() { yield* Effect.scoped(Effect.gen(function*() { const { client, calls } = yield* makeNativeTestClient({ - languageText: "direct model response", + textCompletion: () => Promise.resolve("gateway model response"), }); const result = yield* client["tools/call"]({ @@ -268,9 +261,9 @@ describe("Effect MCP server", () => { }); expect(result.isError).toBe(false); - expect(result.structuredContent).toEqual({ text: "direct model response" }); - expect(decodeJsonText(textContent(result))).toEqual({ text: "direct model response" }); - expect(calls.flowIds).toEqual([]); + expect(result.structuredContent).toEqual({ text: "gateway model response" }); + expect(decodeJsonText(textContent(result))).toEqual({ text: "gateway model response" }); + expect(calls.flowIds).toEqual(["default"]); })); }), ); @@ -319,12 +312,32 @@ describe("Effect MCP server", () => { }, }); - expect(result.structuredContent).toEqual({ - _tag: "GraphRagError", - message: "gateway unavailable", + expect(result.isError).toBe(true); + expect(result.structuredContent).toBeUndefined(); + expect(textContent(result)).toContain("gateway unavailable"); + })); + }), + ); + + it.effect( + "returns MCP errors for text_completion failures", + Effect.fnUntraced(function*() { + yield* Effect.scoped(Effect.gen(function*() { + const { client } = yield* makeNativeTestClient({ + textCompletion: () => Promise.reject(new Error("text service down")), }); - expect(result.structuredContent).not.toHaveProperty("cause"); - expect(decodeJsonText(textContent(result))).toEqual(result.structuredContent); + + const result = yield* client["tools/call"]({ + name: "text_completion", + arguments: { + system: "You are concise.", + prompt: "Say hello.", + }, + }); + + expect(result.isError).toBe(true); + expect(result.structuredContent).toBeUndefined(); + expect(textContent(result)).toContain("text service down"); })); }), ); diff --git a/ts/packages/mcp/src/server-effect.ts b/ts/packages/mcp/src/server-effect.ts index 51e4d25c..91ba9904 100644 --- a/ts/packages/mcp/src/server-effect.ts +++ b/ts/packages/mcp/src/server-effect.ts @@ -1,12 +1,11 @@ -import {OpenAiClient, OpenAiLanguageModel} from "@effect/ai-openai"; import {BunHttpServer, BunRuntime} from "@effect/platform-bun"; import {NodeRuntime, NodeStdio} from "@effect/platform-node"; import {createTrustGraphSocket, type BaseApi, type Term as ClientTerm} from "@trustgraph/client"; -import {Config, Context, Effect, Layer, Redacted} from "effect"; +import {Config, Context, Effect, Layer} from "effect"; import * as O from "effect/Option"; import * as Predicate from "effect/Predicate"; -import {LanguageModel, McpServer, Prompt, Tool, Toolkit} from "effect/unstable/ai"; -import {FetchHttpClient, HttpRouter} from "effect/unstable/http"; +import {McpServer, Tool, Toolkit} from "effect/unstable/ai"; +import {HttpRouter} from "effect/unstable/http"; import {HttpApi, HttpApiBuilder, HttpApiEndpoint, HttpApiGroup, HttpApiSchema, OpenApi} from "effect/unstable/httpapi"; import * as S from "effect/Schema"; @@ -160,7 +159,7 @@ export const TextCompletionTool = annotateTool( parameters: TextCompletionParameters, success: TextCompletionSuccess, failure: TextCompletionError, - failureMode: "return", + failureMode: "error", }), { title: "Text Completion", @@ -210,7 +209,7 @@ export const GraphRagTool = annotateTool( parameters: GraphRagParameters, success: GraphRagSuccess, failure: GraphRagError, - failureMode: "return", + failureMode: "error", }), { title: "Graph RAG", @@ -258,7 +257,7 @@ export const DocumentRagTool = annotateTool( parameters: DocumentRagParameters, success: DocumentRagSuccess, failure: DocumentRagError, - failureMode: "return", + failureMode: "error", }), { title: "Document RAG", @@ -298,7 +297,7 @@ export const AgentTool = annotateTool( parameters: AgentParameters, success: AgentSuccess, failure: AgentError, - failureMode: "return", + failureMode: "error", description: "Ask the TrustGraph agent a question" }), { @@ -341,7 +340,7 @@ export const EmbeddingsTool = annotateTool( parameters: EmbeddingsParameters, success: EmbeddingsSuccess, failure: EmbeddingsError, - failureMode: "return", + failureMode: "error", description: "Generate text embeddings" }), { @@ -397,7 +396,7 @@ export const TriplesQueryTool = annotateTool( parameters: TriplesQueryParameters, success: TriplesQuerySuccess, failure: TriplesQueryError, - failureMode: "return", + failureMode: "error", description: "Query the knowledge graph for triples matching a pattern" }), { @@ -456,7 +455,7 @@ export const GraphEmbeddingsQueryTool = annotateTool( parameters: GraphEmbeddingsQueryParameters, success: GraphEmbeddingsQuerySuccess, failure: GraphEmbeddingsQueryError, - failureMode: "return", + failureMode: "error", description: "Find entities similar to a text query using vector embeddings" }), { @@ -493,7 +492,7 @@ export const GetConfigAllTool = annotateTool( parameters: GetConfigAllParameters, success: GetConfigAllSuccess, failure: GetConfigAllError, - failureMode: "return", + failureMode: "error", description: "Get all configuration values" }), { @@ -544,7 +543,7 @@ export const GetConfigTool = annotateTool( parameters: GetConfigParameters, success: GetConfigSuccess, failure: GetConfigError, - failureMode: "return", + failureMode: "error", description: "Get specific configuration values" }), { @@ -597,7 +596,7 @@ export const PutConfigTool = annotateTool( parameters: PutConfigParameters, success: PutConfigSuccess, failure: PutConfigError, - failureMode: "return", + failureMode: "error", description: "Set configuration values" }), { @@ -641,7 +640,7 @@ export const DeleteConfigTool = annotateTool( parameters: DeleteConfigParameters, success: DeleteConfigSuccess, failure: DeleteConfigError, - failureMode: "return", + failureMode: "error", description: "Delete a configuration entry" }), { @@ -682,7 +681,7 @@ export const GetFlowTool = annotateTool( parameters: GetFlowParameters, success: GetFlowSuccess, failure: GetFlowError, - failureMode: "return", + failureMode: "error", description: "Get a specific flow definition" }), { @@ -721,7 +720,7 @@ export const GetFlowsTool = annotateTool( parameters: GetFlowsParameters, success: GetFlowsSuccess, failure: GetFlowsError, - failureMode: "return", + failureMode: "error", description: "List all available flows" }), { @@ -771,7 +770,7 @@ export const StartFlowTool = annotateTool( parameters: StartFlowParameters, success: StartFlowSuccess, failure: StartFlowError, - failureMode: "return", + failureMode: "error", description: "Start a flow instance" }), { @@ -812,7 +811,7 @@ export const StopFlowTool = annotateTool( parameters: StopFlowParameters, success: StopFlowSuccess, failure: StopFlowError, - failureMode: "return", + failureMode: "error", description: "Stop a running flow" }), { @@ -851,7 +850,7 @@ export const GetDocumentsTool = annotateTool( parameters: GetDocumentsParameters, success: GetDocumentsSuccess, failure: GetDocumentsError, - failureMode: "return", + failureMode: "error", description: "List all documents in the library" }), { @@ -907,7 +906,7 @@ export const LoadDocumentTool = annotateTool( parameters: LoadDocumentParameters, success: LoadDocumentSuccess, failure: LoadDocumentError, - failureMode: "return", + failureMode: "error", description: "Upload a document to the library" }), { @@ -951,7 +950,7 @@ export const RemoveDocumentTool = annotateTool( parameters: RemoveDocumentParameters, success: RemoveDocumentSuccess, failure: RemoveDocumentError, - failureMode: "return", + failureMode: "error", description: "Remove a document from the library" }), { @@ -990,7 +989,7 @@ export const GetPromptsTool = annotateTool( parameters: GetPromptsParameters, success: GetPromptsSuccess, failure: GetPromptsError, - failureMode: "return", + failureMode: "error", description: "List available prompt templates" }), { @@ -1031,7 +1030,7 @@ export const GetPromptTool = annotateTool( parameters: GetPromptParameters, success: GetPromptSuccess, failure: GetPromptError, - failureMode: "return", + failureMode: "error", description: "Get a specific prompt template" }), { @@ -1070,7 +1069,7 @@ export const GetKnowledgeCoresTool = annotateTool( parameters: GetKnowledgeCoresParameters, success: GetKnowledgeCoresSuccess, failure: GetKnowledgeCoresError, - failureMode: "return", + failureMode: "error", description: "List available knowledge graph cores" }), { @@ -1114,7 +1113,7 @@ export const DeleteKgCoreTool = annotateTool( parameters: DeleteKgCoreParameters, success: DeleteKgCoreSuccess, failure: DeleteKgCoreError, - failureMode: "return", + failureMode: "error", description: "Delete a knowledge graph core" }), { @@ -1161,7 +1160,7 @@ export const LoadKgCoreTool = annotateTool( parameters: LoadKgCoreParameters, success: LoadKgCoreSuccess, failure: LoadKgCoreError, - failureMode: "return", + failureMode: "error", description: "Load a knowledge graph core" }), { @@ -1207,8 +1206,6 @@ export interface TrustGraphMcpOptions { readonly name?: string | undefined readonly version?: string | undefined readonly mcpPath?: HttpRouter.PathInput | undefined - readonly openAiModel?: string | undefined - readonly openAiApiKey?: string | undefined readonly port?: number | undefined } @@ -1220,8 +1217,6 @@ export interface TrustGraphMcpConfigShape { readonly name: string readonly version: string readonly mcpPath: HttpRouter.PathInput - readonly openAiModel: string - readonly openAiApiKey: Redacted.Redacted | undefined readonly port: number } @@ -1244,9 +1239,6 @@ export const loadTrustGraphMcpConfig = Effect.fn("loadTrustGraphMcpConfig")(func const gatewaySecret = O.getOrUndefined(yield* Config.string("GATEWAY_SECRET").pipe(Config.option)) const token = readNonEmpty(gatewaySecret) const flowId = O.getOrUndefined(yield* Config.string("FLOW_ID").pipe(Config.option)) - const openAiModel = O.getOrUndefined(yield* Config.string("OPENAI_MODEL").pipe(Config.option)) - const openAiApiKey = O.getOrUndefined(yield* Config.redacted("OPENAI_API_KEY").pipe(Config.option)) - const openAiToken = O.getOrUndefined(yield* Config.redacted("OPENAI_TOKEN").pipe(Config.option)) const port = O.getOrUndefined(yield* Config.string("PORT").pipe(Config.option)) return { @@ -1257,10 +1249,6 @@ export const loadTrustGraphMcpConfig = Effect.fn("loadTrustGraphMcpConfig")(func name: options.name ?? "trustgraph", version: options.version ?? "0.1.0", mcpPath: options.mcpPath ?? "/mcp", - openAiModel: options.openAiModel ?? openAiModel ?? "gpt-4.1", - openAiApiKey: options.openAiApiKey === undefined - ? openAiApiKey ?? openAiToken - : Redacted.make(options.openAiApiKey), port: options.port ?? parsePort(readNonEmpty(port)), } }) @@ -1328,46 +1316,19 @@ const decodeJsonArrayOrFail = ( const asIriTerm = (value: string | undefined): ClientTerm | undefined => value !== undefined && value.length > 0 ? {t: "i", i: value} : undefined -const openAiApiKeyOptions = (apiKey: Redacted.Redacted | undefined) => - apiKey === undefined - ? {} - : {apiKey} - -const makeOpenAiProviderLayerFromConfig = ( - config: TrustGraphMcpConfigShape, -) => - OpenAiLanguageModel.layer({ - model: config.openAiModel, - config: { - strictJsonSchema: true, - }, - }).pipe( - Layer.provide(OpenAiClient.layer(openAiApiKeyOptions(config.openAiApiKey))), - Layer.provide(FetchHttpClient.layer), - ) - -export const makeOpenAiProviderLayer = ( - options: TrustGraphMcpOptions = {}, -) => - Layer.unwrap( - loadTrustGraphMcpConfig(options).pipe( - Effect.map(makeOpenAiProviderLayerFromConfig), - ), - ) - export const TrustGraphMcpToolkitLive = TrustGraphMcpToolkit.toLayer( Effect.gen(function*() { const config = yield* TrustGraphMcpConfig const socket = yield* TrustGraphSocket - const model = yield* LanguageModel.LanguageModel return TrustGraphMcpToolkit.of({ - text_completion: Effect.fn("TrustGraphMcp.text_completion")(function*({system, prompt}) { - const response = yield* model.generateText({ - prompt: Prompt.make(prompt).pipe(Prompt.setSystem(system)), - }) - return TextCompletionSuccess.make({text: response.text}) - }), + text_completion: ({system, prompt}) => + Effect.tryPromise({ + try: () => socket.flow(config.flowId).textCompletion(system, prompt), + catch: (cause) => TextCompletionError.make({message: toErrorMessage(cause)}), + }).pipe( + Effect.map((text) => TextCompletionSuccess.make({text})), + ), graph_rag: ({query, entity_limit, triple_limit, collection}) => Effect.tryPromise({ @@ -1722,7 +1683,7 @@ export const TrustGraphMcpHttpApiRoutes = HttpApiBuilder.layer( const makeTrustGraphMcpHttpLayerFromConfig = ( config: TrustGraphMcpConfigShape, ) => { - const tools = makeTrustGraphMcpToolkitLayerFromConfig(config) + const tools = makeTrustGraphMcpToolkitLayer() return Layer.mergeAll( TrustGraphMcpHttpApiRoutes, @@ -1738,18 +1699,15 @@ const makeTrustGraphMcpHttpLayerFromConfig = ( ) } -const makeTrustGraphMcpToolkitLayerFromConfig = ( - config: TrustGraphMcpConfigShape, -) => +const makeTrustGraphMcpToolkitLayer = () => McpServer.toolkit(TrustGraphMcpToolkit).pipe( Layer.provide(TrustGraphMcpToolkitLive), - Layer.provide(makeOpenAiProviderLayerFromConfig(config)), ) const makeTrustGraphMcpStdioLayerFromConfig = ( config: TrustGraphMcpConfigShape, ) => - makeTrustGraphMcpToolkitLayerFromConfig(config).pipe( + makeTrustGraphMcpToolkitLayer().pipe( Layer.provide(McpServer.layerStdio({ name: config.name, version: config.version,