23
23
*/
24
24
@ CheckReturnValue
25
25
final class ArrayDecoders {
26
+ static final int DEFAULT_RECURSION_LIMIT = 100 ;
26
27
27
- private ArrayDecoders () {
28
- }
28
+ @ SuppressWarnings ("NonFinalStaticField" )
29
+ private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT ;
30
+
31
+ private ArrayDecoders () {}
29
32
30
33
/**
31
34
* A helper used to return multiple values in a Java function. Java doesn't natively support
@@ -38,6 +41,7 @@ static final class Registers {
38
41
public long long1 ;
39
42
public Object object1 ;
40
43
public final ExtensionRegistryLite extensionRegistry ;
44
+ public int recursionDepth ;
41
45
42
46
Registers () {
43
47
this .extensionRegistry = ExtensionRegistryLite .getEmptyRegistry ();
@@ -245,7 +249,10 @@ static int mergeMessageField(
245
249
if (length < 0 || length > limit - position ) {
246
250
throw InvalidProtocolBufferException .truncatedMessage ();
247
251
}
252
+ registers .recursionDepth ++;
253
+ checkRecursionLimit (registers .recursionDepth );
248
254
schema .mergeFrom (msg , data , position , position + length , registers );
255
+ registers .recursionDepth --;
249
256
registers .object1 = msg ;
250
257
return position + length ;
251
258
}
@@ -263,8 +270,11 @@ static int mergeGroupField(
263
270
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
264
271
// and it can't be used in group fields).
265
272
final MessageSchema messageSchema = (MessageSchema ) schema ;
273
+ registers .recursionDepth ++;
274
+ checkRecursionLimit (registers .recursionDepth );
266
275
final int endPosition =
267
276
messageSchema .parseMessage (msg , data , position , limit , endGroup , registers );
277
+ registers .recursionDepth --;
268
278
registers .object1 = msg ;
269
279
return endPosition ;
270
280
}
@@ -1025,6 +1035,8 @@ static int decodeUnknownField(
1025
1035
final UnknownFieldSetLite child = UnknownFieldSetLite .newInstance ();
1026
1036
final int endGroup = (tag & ~0x7 ) | WireFormat .WIRETYPE_END_GROUP ;
1027
1037
int lastTag = 0 ;
1038
+ registers .recursionDepth ++;
1039
+ checkRecursionLimit (registers .recursionDepth );
1028
1040
while (position < limit ) {
1029
1041
position = decodeVarint32 (data , position , registers );
1030
1042
lastTag = registers .int1 ;
@@ -1033,6 +1045,7 @@ static int decodeUnknownField(
1033
1045
}
1034
1046
position = decodeUnknownField (lastTag , data , position , limit , child , registers );
1035
1047
}
1048
+ registers .recursionDepth --;
1036
1049
if (position > limit || lastTag != endGroup ) {
1037
1050
throw InvalidProtocolBufferException .parseFailure ();
1038
1051
}
@@ -1079,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
1079
1092
throw InvalidProtocolBufferException .invalidTag ();
1080
1093
}
1081
1094
}
1095
+
1096
+ /**
1097
+ * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
1098
+ * the depth of the message exceeds this limit.
1099
+ */
1100
+ public static void setRecursionLimit (int limit ) {
1101
+ recursionLimit = limit ;
1102
+ }
1103
+
1104
+ private static void checkRecursionLimit (int depth ) throws InvalidProtocolBufferException {
1105
+ if (depth >= recursionLimit ) {
1106
+ throw InvalidProtocolBufferException .recursionLimitExceeded ();
1107
+ }
1108
+ }
1082
1109
}
0 commit comments