Validate RPC protocol frames with Schema

This commit is contained in:
elpresidank 2026-06-04 07:04:15 -05:00
parent 67b5e0dd5b
commit 5a945af345
3 changed files with 241 additions and 19 deletions

View file

@ -0,0 +1,148 @@
import { Effect, Queue } from "effect";
import * as O from "effect/Option";
import * as RpcMessage from "effect/unstable/rpc/RpcMessage";
import * as RpcSerialization from "effect/unstable/rpc/RpcSerialization";
import * as Socket from "effect/unstable/socket/Socket";
import { describe, expect, it } from "vitest";
import { makeSocketRpcProtocol } from "../gateway/rpc-protocol.js";
interface ReceivedMessage {
readonly clientId: number;
readonly message: RpcMessage.FromClientEncoded;
}
interface ProtocolRunResult {
readonly messages: ReadonlyArray<ReceivedMessage>;
readonly writes: ReadonlyArray<string | Uint8Array | CloseEvent>;
}
const jsonFrame = (value: unknown): string => `${JSON.stringify(value)}\n`;
const optionToArray = <A>(value: O.Option<A>): Array<A> =>
O.match(value, {
onNone: () => [],
onSome: (item) => [item],
});
const runProtocolFrames = (
frames: ReadonlyArray<string | Uint8Array>,
headers?: ReadonlyArray<[string, string]>,
): Promise<ProtocolRunResult> =>
Effect.runPromise(
Effect.scoped(
Effect.gen(function* () {
const received = yield* Queue.unbounded<ReceivedMessage>();
const writes: Array<string | Uint8Array | CloseEvent> = [];
const { onSocket, protocol } = yield* makeSocketRpcProtocol;
yield* protocol.run((clientId, message) =>
Queue.offer(received, { clientId, message }).pipe(Effect.asVoid)
).pipe(Effect.forkScoped);
yield* Effect.yieldNow;
const socket = Socket.make({
writer: Effect.succeed((chunk) =>
Effect.sync(() => {
writes.push(chunk);
})
),
runRaw: (handler) =>
Effect.forEach(frames, (frame) =>
Effect.suspend(() => {
const result = handler(frame);
return result === undefined ? Effect.void : result;
}), { discard: true }),
});
yield* onSocket(socket, headers);
yield* Effect.yieldNow;
const first = yield* Queue.poll(received);
const second = yield* Queue.poll(received);
const third = yield* Queue.poll(received);
return {
messages: [
...optionToArray(first),
...optionToArray(second),
...optionToArray(third),
],
writes,
};
}).pipe(
Effect.provideService(RpcSerialization.RpcSerialization, RpcSerialization.ndjson),
),
),
);
describe("gateway RPC socket protocol", () => {
it("validates client request frames and prepends websocket headers", async () => {
const result = await runProtocolFrames([
jsonFrame({
_tag: "Request",
id: "1",
tag: "Dispatch",
payload: {
scope: "global",
service: "config",
request: {},
},
headers: [["rpc", "client"]],
}),
], [["socket", "header"]]);
expect(result.writes).toEqual([]);
expect(result.messages).toHaveLength(1);
const received = result.messages[0];
expect(received).toBeDefined();
if (received === undefined) return;
expect(received.clientId).toBe(0);
expect(received.message._tag).toBe("Request");
if (received.message._tag !== "Request") return;
expect(received.message.id).toBe("1");
expect(received.message.tag).toBe("Dispatch");
expect(received.message.headers).toEqual([
["socket", "header"],
["rpc", "client"],
]);
});
it("validates client control frames without mutating them", async () => {
const result = await runProtocolFrames([
jsonFrame({
_tag: "Ping",
}),
jsonFrame({
_tag: "Ack",
requestId: "1",
}),
]);
expect(result.writes).toEqual([]);
expect(result.messages.map(({ message }) => message._tag)).toEqual(["Ping", "Ack"]);
expect(result.messages[1]?.message).toEqual({
_tag: "Ack",
requestId: "1",
});
});
it("rejects server response envelopes received from the socket", async () => {
const result = await runProtocolFrames([
jsonFrame({
_tag: "Exit",
requestId: "1",
exit: {
_tag: "Success",
value: { ok: true },
},
}),
]);
expect(result.messages).toEqual([]);
expect(result.writes).toHaveLength(1);
expect(String(result.writes[0])).toContain("\"_tag\":\"Defect\"");
});
});

View file

@ -1,9 +1,41 @@
import { Effect, Queue, Scope } from "effect";
import * as S from "effect/Schema";
import * as RpcMessage from "effect/unstable/rpc/RpcMessage";
import * as RpcSerialization from "effect/unstable/rpc/RpcSerialization";
import * as RpcServer from "effect/unstable/rpc/RpcServer";
import * as Socket from "effect/unstable/socket/Socket";
export class RpcProtocolDecodeError extends S.TaggedErrorClass<RpcProtocolDecodeError>()(
"RpcProtocolDecodeError",
{
message: S.String,
cause: S.Unknown,
},
) {}
const HeaderSchema = S.mutable(S.Tuple([S.String, S.String]));
const FromClientEncodedSchema = S.Union([
S.TaggedStruct("Request", {
id: S.String,
tag: S.String,
payload: S.Unknown,
headers: S.Array(HeaderSchema),
traceId: S.optionalKey(S.String),
spanId: S.optionalKey(S.String),
sampled: S.optionalKey(S.Boolean),
}),
S.TaggedStruct("Ack", {
requestId: S.String,
}),
S.TaggedStruct("Interrupt", {
requestId: S.String,
}),
S.TaggedStruct("Ping", {}),
S.TaggedStruct("Eof", {}),
]).pipe(S.toTaggedUnion("_tag"));
const FromClientEncodedMessagesSchema = S.Array(FromClientEncodedSchema);
const decodeFromClientMessages = S.decodeUnknownEffect(FromClientEncodedMessagesSchema);
export const makeSocketRpcProtocol = Effect.gen(function* () {
const serialization = yield* RpcSerialization.RpcSerialization;
const disconnects = yield* Queue.make<number>();
@ -34,18 +66,17 @@ export const makeSocketRpcProtocol = Effect.gen(function* () {
});
const writeRaw = yield* socket.writer;
const encodeDefect = (cause: unknown) =>
Effect.sync(() => parser.encode(RpcMessage.ResponseDefectEncoded(cause))!);
const writeDefect = (cause: unknown) =>
Effect.sync(() => parser.encode(RpcMessage.ResponseDefectEncoded(cause))).pipe(
Effect.flatMap((encoded) => encoded === undefined ? Effect.void : writeRaw(encoded)),
);
const write = (response: RpcMessage.FromServerEncoded) =>
Effect.sync(() => parser.encode(response)).pipe(
Effect.flatMap((encoded) =>
encoded === undefined ? Effect.void : Effect.orDie(writeRaw(encoded)),
),
Effect.catchDefect((cause: unknown) =>
encodeDefect(cause).pipe(
Effect.flatMap((encoded) => Effect.orDie(writeRaw(encoded))),
Effect.orDie,
),
writeDefect(cause).pipe(Effect.orDie),
),
);
@ -53,7 +84,24 @@ export const makeSocketRpcProtocol = Effect.gen(function* () {
clientIds.add(clientId);
yield* socket.runRaw((data) =>
Effect.sync(() => parser.decode(data) as ReadonlyArray<RpcMessage.FromClientEncoded>).pipe(
Effect.try({
try: () => parser.decode(data),
catch: (cause) =>
RpcProtocolDecodeError.make({
message: "Failed to decode RPC socket frame",
cause,
}),
}).pipe(
Effect.flatMap((raw) =>
decodeFromClientMessages(raw).pipe(
Effect.mapError((cause) =>
RpcProtocolDecodeError.make({
message: "RPC socket frame did not contain valid client messages",
cause,
})
),
)
),
Effect.flatMap((decoded) =>
Effect.forEach(decoded, (message) => {
if (message._tag === "Request" && headers !== undefined) {
@ -65,11 +113,8 @@ export const makeSocketRpcProtocol = Effect.gen(function* () {
return writeRequest(clientId, message);
}, { discard: true }),
),
Effect.catchDefect((cause: unknown) =>
encodeDefect(cause).pipe(
Effect.flatMap((encoded) => writeRaw(encoded)),
),
),
Effect.catch((cause) => writeDefect(cause)),
Effect.catchDefect((cause: unknown) => writeDefect(cause)),
)
).pipe(
Effect.catchReason("SocketError", "SocketCloseError", () => Effect.void),