Skip to content

Commit d6c82fc

Browse files
Add recursion check when parsing unknown fields in Java. (#18388)
* Internal change PiperOrigin-RevId: 663919912 * Internal change PiperOrigin-RevId: 653615736 * Add recursion check when parsing unknown fields in Java. PiperOrigin-RevId: 675657198 --------- Co-authored-by: Protobuf Team Bot <protobuf-github-bot@google.com>
1 parent 6fa3f2d commit d6c82fc

File tree

12 files changed

+493
-93
lines changed

12 files changed

+493
-93
lines changed

java/core/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ junit_tests(
608608
"src/test/java/com/google/protobuf/DescriptorsTest.java",
609609
"src/test/java/com/google/protobuf/DebugFormatTest.java",
610610
"src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
611+
"src/test/java/com/google/protobuf/CodedInputStreamTest.java",
611612
"src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
612613
# Excluded in core_tests
613614
"src/test/java/com/google/protobuf/DecodeUtf8Test.java",
@@ -656,6 +657,7 @@ junit_tests(
656657
"src/test/java/com/google/protobuf/DescriptorsTest.java",
657658
"src/test/java/com/google/protobuf/DebugFormatTest.java",
658659
"src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
660+
"src/test/java/com/google/protobuf/CodedInputStreamTest.java",
659661
"src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
660662
# Excluded in core_tests
661663
"src/test/java/com/google/protobuf/DecodeUtf8Test.java",

java/core/src/main/java/com/google/protobuf/ArrayDecoders.java

+29-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323
*/
2424
@CheckReturnValue
2525
final class ArrayDecoders {
26+
static final int DEFAULT_RECURSION_LIMIT = 100;
2627

27-
private ArrayDecoders() {
28-
}
28+
@SuppressWarnings("NonFinalStaticField")
29+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
30+
31+
private ArrayDecoders() {}
2932

3033
/**
3134
* A helper used to return multiple values in a Java function. Java doesn't natively support
@@ -38,6 +41,7 @@ static final class Registers {
3841
public long long1;
3942
public Object object1;
4043
public final ExtensionRegistryLite extensionRegistry;
44+
public int recursionDepth;
4145

4246
Registers() {
4347
this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
@@ -245,7 +249,10 @@ static int mergeMessageField(
245249
if (length < 0 || length > limit - position) {
246250
throw InvalidProtocolBufferException.truncatedMessage();
247251
}
252+
registers.recursionDepth++;
253+
checkRecursionLimit(registers.recursionDepth);
248254
schema.mergeFrom(msg, data, position, position + length, registers);
255+
registers.recursionDepth--;
249256
registers.object1 = msg;
250257
return position + length;
251258
}
@@ -263,8 +270,11 @@ static int mergeGroupField(
263270
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
264271
// and it can't be used in group fields).
265272
final MessageSchema messageSchema = (MessageSchema) schema;
273+
registers.recursionDepth++;
274+
checkRecursionLimit(registers.recursionDepth);
266275
final int endPosition =
267276
messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
277+
registers.recursionDepth--;
268278
registers.object1 = msg;
269279
return endPosition;
270280
}
@@ -1025,6 +1035,8 @@ static int decodeUnknownField(
10251035
final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
10261036
final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
10271037
int lastTag = 0;
1038+
registers.recursionDepth++;
1039+
checkRecursionLimit(registers.recursionDepth);
10281040
while (position < limit) {
10291041
position = decodeVarint32(data, position, registers);
10301042
lastTag = registers.int1;
@@ -1033,6 +1045,7 @@ static int decodeUnknownField(
10331045
}
10341046
position = decodeUnknownField(lastTag, data, position, limit, child, registers);
10351047
}
1048+
registers.recursionDepth--;
10361049
if (position > limit || lastTag != endGroup) {
10371050
throw InvalidProtocolBufferException.parseFailure();
10381051
}
@@ -1079,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
10791092
throw InvalidProtocolBufferException.invalidTag();
10801093
}
10811094
}
1095+
1096+
/**
1097+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
1098+
* the depth of the message exceeds this limit.
1099+
*/
1100+
public static void setRecursionLimit(int limit) {
1101+
recursionLimit = limit;
1102+
}
1103+
1104+
private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
1105+
if (depth >= recursionLimit) {
1106+
throw InvalidProtocolBufferException.recursionLimitExceeded();
1107+
}
1108+
}
10821109
}

java/core/src/main/java/com/google/protobuf/CodedInputStream.java

+30-82
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,41 @@ public abstract boolean skipField(final int tag, final CodedOutputStream output)
224224
* Reads and discards an entire message. This will read either until EOF or until an endgroup tag,
225225
* whichever comes first.
226226
*/
227-
public abstract void skipMessage() throws IOException;
227+
public void skipMessage() throws IOException {
228+
while (true) {
229+
final int tag = readTag();
230+
if (tag == 0) {
231+
return;
232+
}
233+
checkRecursionLimit();
234+
++recursionDepth;
235+
boolean fieldSkipped = skipField(tag);
236+
--recursionDepth;
237+
if (!fieldSkipped) {
238+
return;
239+
}
240+
}
241+
}
228242

229243
/**
230244
* Reads an entire message and writes it to output in wire format. This will read either until EOF
231245
* or until an endgroup tag, whichever comes first.
232246
*/
233-
public abstract void skipMessage(CodedOutputStream output) throws IOException;
247+
public void skipMessage(CodedOutputStream output) throws IOException {
248+
while (true) {
249+
final int tag = readTag();
250+
if (tag == 0) {
251+
return;
252+
}
253+
checkRecursionLimit();
254+
++recursionDepth;
255+
boolean fieldSkipped = skipField(tag, output);
256+
--recursionDepth;
257+
if (!fieldSkipped) {
258+
return;
259+
}
260+
}
261+
}
234262

235263
// -----------------------------------------------------------------
236264

@@ -700,26 +728,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
700728
}
701729
}
702730

703-
@Override
704-
public void skipMessage() throws IOException {
705-
while (true) {
706-
final int tag = readTag();
707-
if (tag == 0 || !skipField(tag)) {
708-
return;
709-
}
710-
}
711-
}
712-
713-
@Override
714-
public void skipMessage(CodedOutputStream output) throws IOException {
715-
while (true) {
716-
final int tag = readTag();
717-
if (tag == 0 || !skipField(tag, output)) {
718-
return;
719-
}
720-
}
721-
}
722-
723731
// -----------------------------------------------------------------
724732

725733
@Override
@@ -1412,26 +1420,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
14121420
}
14131421
}
14141422

1415-
@Override
1416-
public void skipMessage() throws IOException {
1417-
while (true) {
1418-
final int tag = readTag();
1419-
if (tag == 0 || !skipField(tag)) {
1420-
return;
1421-
}
1422-
}
1423-
}
1424-
1425-
@Override
1426-
public void skipMessage(CodedOutputStream output) throws IOException {
1427-
while (true) {
1428-
final int tag = readTag();
1429-
if (tag == 0 || !skipField(tag, output)) {
1430-
return;
1431-
}
1432-
}
1433-
}
1434-
14351423
// -----------------------------------------------------------------
14361424

14371425
@Override
@@ -2178,26 +2166,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
21782166
}
21792167
}
21802168

2181-
@Override
2182-
public void skipMessage() throws IOException {
2183-
while (true) {
2184-
final int tag = readTag();
2185-
if (tag == 0 || !skipField(tag)) {
2186-
return;
2187-
}
2188-
}
2189-
}
2190-
2191-
@Override
2192-
public void skipMessage(CodedOutputStream output) throws IOException {
2193-
while (true) {
2194-
final int tag = readTag();
2195-
if (tag == 0 || !skipField(tag, output)) {
2196-
return;
2197-
}
2198-
}
2199-
}
2200-
22012169
/** Collects the bytes skipped and returns the data in a ByteBuffer. */
22022170
private class SkippedDataSink implements RefillCallback {
22032171
private int lastPos = pos;
@@ -3322,26 +3290,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
33223290
}
33233291
}
33243292

3325-
@Override
3326-
public void skipMessage() throws IOException {
3327-
while (true) {
3328-
final int tag = readTag();
3329-
if (tag == 0 || !skipField(tag)) {
3330-
return;
3331-
}
3332-
}
3333-
}
3334-
3335-
@Override
3336-
public void skipMessage(CodedOutputStream output) throws IOException {
3337-
while (true) {
3338-
final int tag = readTag();
3339-
if (tag == 0 || !skipField(tag, output)) {
3340-
return;
3341-
}
3342-
}
3343-
}
3344-
33453293
// -----------------------------------------------------------------
33463294

33473295
@Override

java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) {
132132
static InvalidProtocolBufferException recursionLimitExceeded() {
133133
return new InvalidProtocolBufferException(
134134
"Protocol message had too many levels of nesting. May be malicious. "
135-
+ "Use CodedInputStream.setRecursionLimit() to increase the depth limit.");
135+
+ "Use setRecursionLimit() to increase the recursion depth limit.");
136136
}
137137

138138
static InvalidProtocolBufferException sizeLimitExceeded() {

java/core/src/main/java/com/google/protobuf/MessageSchema.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -3006,7 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
30063006
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
30073007
}
30083008
// Unknown field.
3009-
if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3009+
if (unknownFieldSchema.mergeOneFieldFrom(
3010+
unknownFields, reader, /* currentDepth= */ 0)) {
30103011
continue;
30113012
}
30123013
}
@@ -3381,7 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33813382
if (unknownFields == null) {
33823383
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
33833384
}
3384-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3385+
if (!unknownFieldSchema.mergeOneFieldFrom(
3386+
unknownFields, reader, /* currentDepth= */ 0)) {
33853387
return;
33863388
}
33873389
break;
@@ -3397,7 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33973399
if (unknownFields == null) {
33983400
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
33993401
}
3400-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3402+
if (!unknownFieldSchema.mergeOneFieldFrom(
3403+
unknownFields, reader, /* currentDepth= */ 0)) {
34013404
return;
34023405
}
34033406
}

java/core/src/main/java/com/google/protobuf/MessageSetSchema.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
278278
reader, extension, extensionRegistry, extensions);
279279
return true;
280280
} else {
281-
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
281+
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
282282
}
283283
} else {
284284
return reader.skipField();

java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java

+24-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
@CheckReturnValue
1414
abstract class UnknownFieldSchema<T, B> {
1515

16+
static final int DEFAULT_RECURSION_LIMIT = 100;
17+
18+
@SuppressWarnings("NonFinalStaticField")
19+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
20+
1621
/** Whether unknown fields should be dropped. */
1722
abstract boolean shouldDiscardUnknownFields(Reader reader);
1823

@@ -56,7 +61,8 @@ abstract class UnknownFieldSchema<T, B> {
5661
abstract void makeImmutable(Object message);
5762

5863
/** Merges one field into the unknown fields. */
59-
final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
64+
final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
65+
throws IOException {
6066
int tag = reader.getTag();
6167
int fieldNumber = WireFormat.getTagFieldNumber(tag);
6268
switch (WireFormat.getTagWireType(tag)) {
@@ -75,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
7581
case WireFormat.WIRETYPE_START_GROUP:
7682
final B subFields = newBuilder();
7783
int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
78-
mergeFrom(subFields, reader);
84+
currentDepth++;
85+
if (currentDepth >= recursionLimit) {
86+
throw InvalidProtocolBufferException.recursionLimitExceeded();
87+
}
88+
mergeFrom(subFields, reader, currentDepth);
89+
currentDepth--;
7990
if (endGroupTag != reader.getTag()) {
8091
throw InvalidProtocolBufferException.invalidEndTag();
8192
}
@@ -88,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
8899
}
89100
}
90101

91-
final void mergeFrom(B unknownFields, Reader reader) throws IOException {
102+
private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
103+
throws IOException {
92104
while (true) {
93105
if (reader.getFieldNumber() == Reader.READ_DONE
94-
|| !mergeOneFieldFrom(unknownFields, reader)) {
106+
|| !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
95107
break;
96108
}
97109
}
@@ -108,4 +120,12 @@ final void mergeFrom(B unknownFields, Reader reader) throws IOException {
108120
abstract int getSerializedSizeAsMessageSet(T message);
109121

110122
abstract int getSerializedSize(T unknowns);
123+
124+
/**
125+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
126+
* the depth of the message exceeds this limit.
127+
*/
128+
public void setRecursionLimit(int limit) {
129+
recursionLimit = limit;
130+
}
111131
}

0 commit comments

Comments
 (0)