Skip to content

Commit 7df315e

Browse files
committed
Enforce maximum payload size
1 parent 577c1d7 commit 7df315e

File tree

5 files changed

+218
-15
lines changed

5 files changed

+218
-15
lines changed

__tests__/max-payload-size.test.ts

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import {
2+
afterEach,
3+
assert,
4+
beforeEach,
5+
describe,
6+
expect,
7+
test,
8+
vitest,
9+
} from 'vitest';
10+
import { createMockTransportNetwork } from '../testUtil/fixtures/mockTransport';
11+
import {
12+
Err,
13+
Ok,
14+
Procedure,
15+
ServiceSchema,
16+
createClient,
17+
createServer,
18+
} from '../router';
19+
import { MAX_PAYLOAD_SIZE_EXCEEDED } from '../router/errors';
20+
import { Type } from '@sinclair/typebox';
21+
import { readNextResult } from '../testUtil';
22+
import { MaxPayloadSizeExceeded } from '../transport/sessionStateMachine/common';
23+
24+
describe('client exceeded max payload size', () => {
25+
let mockTransportNetwork: ReturnType<typeof createMockTransportNetwork>;
26+
27+
beforeEach(async () => {
28+
mockTransportNetwork = createMockTransportNetwork({
29+
client: { maxPayloadSizeBytes: 1024 },
30+
});
31+
});
32+
33+
afterEach(async () => {
34+
await mockTransportNetwork.cleanup();
35+
});
36+
37+
test('rpc init exceeds max payload size', async () => {
38+
const mockHandler = vitest.fn();
39+
const services = {
40+
service: ServiceSchema.define({
41+
echo: Procedure.rpc({
42+
requestInit: Type.String(),
43+
responseData: Type.String(),
44+
handler: mockHandler,
45+
}),
46+
}),
47+
};
48+
createServer(mockTransportNetwork.getServerTransport(), services);
49+
const client = createClient<typeof services>(
50+
mockTransportNetwork.getClientTransport('client'),
51+
'SERVER',
52+
);
53+
54+
const result = await client.service.echo.rpc('0'.repeat(1025));
55+
expect(result).toStrictEqual({
56+
ok: false,
57+
payload: {
58+
code: MAX_PAYLOAD_SIZE_EXCEEDED,
59+
message: 'payload exceeded maximum payload size size=1241 max=1024',
60+
},
61+
});
62+
expect(mockHandler).not.toHaveBeenCalled();
63+
});
64+
65+
test('stream message exceeds max payload size', async () => {
66+
let handlerCanceled: Promise<null> | undefined;
67+
const services = {
68+
service: ServiceSchema.define({
69+
echo: Procedure.stream({
70+
requestInit: Type.String(),
71+
requestData: Type.String(),
72+
responseData: Type.String(),
73+
responseError: Type.Object({
74+
code: Type.Literal('ERROR'),
75+
message: Type.String(),
76+
}),
77+
handler: async ({ ctx, reqInit, reqReadable, resWritable }) => {
78+
handlerCanceled = new Promise((resolve) => {
79+
ctx.signal.onabort = () => resolve(null);
80+
});
81+
82+
resWritable.write(Ok(reqInit));
83+
for await (const msg of reqReadable) {
84+
if (msg.ok) {
85+
resWritable.write(Ok(msg.payload));
86+
} else {
87+
resWritable.write(
88+
Err({
89+
code: 'ERROR',
90+
message: 'error reading from client',
91+
}),
92+
);
93+
break;
94+
}
95+
}
96+
},
97+
}),
98+
}),
99+
};
100+
createServer(mockTransportNetwork.getServerTransport(), services);
101+
const transport = mockTransportNetwork.getClientTransport('client');
102+
const client = createClient<typeof services>(transport, 'SERVER');
103+
104+
const stream = client.service.echo.stream('start');
105+
let result = await readNextResult(stream.resReadable);
106+
expect(result).toStrictEqual({ ok: true, payload: 'start' });
107+
108+
let error;
109+
try {
110+
stream.reqWritable.write('0'.repeat(1025));
111+
} catch (e) {
112+
error = e;
113+
}
114+
expect(error).toBeInstanceOf(MaxPayloadSizeExceeded);
115+
116+
result = await readNextResult(stream.resReadable);
117+
expect(result).toStrictEqual({
118+
ok: false,
119+
payload: {
120+
code: MAX_PAYLOAD_SIZE_EXCEEDED,
121+
message: 'payload exceeded maximum payload size size=1148 max=1024',
122+
},
123+
});
124+
assert(handlerCanceled);
125+
await handlerCanceled;
126+
});
127+
});

router/client.ts

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ import {
3636
CANCEL_CODE,
3737
ReaderErrorSchema,
3838
UNEXPECTED_DISCONNECT_CODE,
39+
MAX_PAYLOAD_SIZE_EXCEEDED,
3940
} from './errors';
41+
import { MaxPayloadSizeExceeded } from '../transport/sessionStateMachine/common';
4042

4143
const ReaderErrResultSchema = ErrResultSchema(ReaderErrorSchema);
4244

@@ -297,11 +299,42 @@ function handleProc(
297299
let cleanClose = true;
298300
const reqWritable = new WritableImpl<Static<PayloadType>>({
299301
writeCb: (rawIn) => {
300-
sessionScopedSend({
301-
streamId,
302-
payload: rawIn,
303-
controlFlags: 0,
304-
});
302+
try {
303+
sessionScopedSend({
304+
streamId,
305+
payload: rawIn,
306+
controlFlags: 0,
307+
});
308+
} catch (e) {
309+
if (!(e instanceof MaxPayloadSizeExceeded)) {
310+
throw e;
311+
}
312+
313+
cleanClose = false;
314+
if (!resReadable.isClosed()) {
315+
resReadable._pushValue(
316+
Err({
317+
code: MAX_PAYLOAD_SIZE_EXCEEDED,
318+
message: e.message,
319+
}),
320+
);
321+
closeReadable();
322+
}
323+
324+
reqWritable.close();
325+
// TODO: Is this the right error to send to the server?
326+
sessionScopedSend(
327+
cancelMessage(
328+
streamId,
329+
Err({
330+
code: CANCEL_CODE,
331+
message: 'cancelled by client',
332+
}),
333+
),
334+
);
335+
336+
throw e;
337+
}
305338
},
306339
// close callback
307340
closeCb: () => {
@@ -480,16 +513,33 @@ function handleProc(
480513
transport.addEventListener('message', onMessage);
481514
transport.addEventListener('sessionStatus', onSessionStatus);
482515

483-
sessionScopedSend({
484-
streamId,
485-
serviceName,
486-
procedureName,
487-
tracing: getPropagationContext(ctx),
488-
payload: init,
489-
controlFlags: procClosesWithInit
490-
? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit
491-
: ControlFlags.StreamOpenBit,
492-
});
516+
try {
517+
sessionScopedSend({
518+
streamId,
519+
serviceName,
520+
procedureName,
521+
tracing: getPropagationContext(ctx),
522+
payload: init,
523+
controlFlags: procClosesWithInit
524+
? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit
525+
: ControlFlags.StreamOpenBit,
526+
});
527+
} catch (e) {
528+
if (!(e instanceof MaxPayloadSizeExceeded)) {
529+
throw e;
530+
}
531+
532+
cleanClose = false;
533+
resReadable._pushValue(
534+
Err({
535+
code: MAX_PAYLOAD_SIZE_EXCEEDED,
536+
message: e.message,
537+
}),
538+
);
539+
closeReadable();
540+
541+
reqWritable.close();
542+
}
493543

494544
if (procClosesWithInit) {
495545
reqWritable.close();

router/errors.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ export const INVALID_REQUEST_CODE = 'INVALID_REQUEST';
2727
* {@link CANCEL_CODE} is the code used when either server or client cancels the stream.
2828
*/
2929
export const CANCEL_CODE = 'CANCEL';
30+
/**
31+
* {@link MAX_PAYLOAD_SIZE_EXCEEDED} is the code used when a request's payload exceeds the maximum allowed size.
32+
*/
33+
export const MAX_PAYLOAD_SIZE_EXCEEDED = 'MAX_PAYLOAD_SIZE_EXCEEDED';
3034

3135
type TLiteralString = TLiteral<string>;
3236

@@ -72,6 +76,10 @@ export const ReaderErrorSchema = Type.Union([
7276
code: Type.Literal(CANCEL_CODE),
7377
message: Type.String(),
7478
}),
79+
Type.Object({
80+
code: Type.Literal(MAX_PAYLOAD_SIZE_EXCEEDED),
81+
message: Type.String(),
82+
}),
7583
]) satisfies ProcedureErrorSchemaType;
7684

7785
/**

transport/options.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export const defaultTransportOptions: TransportOptions = {
1313
connectionTimeoutMs: 2_000,
1414
handshakeTimeoutMs: 1_000,
1515
enableTransparentSessionReconnects: true,
16+
maxPayloadSizeBytes: 4 * 1024 * 1024,
1617
codec: NaiveJsonCodec,
1718
};
1819

transport/sessionStateMachine/common.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ export interface SessionOptions {
141141
* The codec to use for encoding/decoding messages over the wire
142142
*/
143143
codec: Codec;
144+
/**
145+
* The maximum payload size that is allowed to be sent or received.
146+
*/
147+
maxPayloadSizeBytes: number;
144148
}
145149

146150
// all session states have a from and options
@@ -209,6 +213,12 @@ export interface IdentifiedSessionProps extends CommonSessionProps {
209213
protocolVersion: ProtocolVersion;
210214
}
211215

216+
export class MaxPayloadSizeExceeded extends Error {
217+
constructor(size: number, max: number) {
218+
super(`payload exceeded maximum payload size size=${size} max=${max}`);
219+
}
220+
}
221+
212222
export abstract class IdentifiedSession extends CommonSession {
213223
readonly id: SessionId;
214224
readonly telemetry: TelemetryInfo;
@@ -276,6 +286,13 @@ export abstract class IdentifiedSession extends CommonSession {
276286
data: this.options.codec.toBuffer(msg),
277287
};
278288

289+
if (encodedMsg.data.byteLength > this.options.maxPayloadSizeBytes) {
290+
throw new MaxPayloadSizeExceeded(
291+
encodedMsg.data.byteLength,
292+
this.options.maxPayloadSizeBytes,
293+
);
294+
}
295+
279296
this.seq++;
280297

281298
return encodedMsg;

0 commit comments

Comments
 (0)