Skip to content

Commit 60a55bf

Browse files
maskri17copybara-github
authored andcommitted
Checker and parser changes to support comprehensionsV2
PiperOrigin-RevId: 788649636
1 parent e1b6c11 commit 60a55bf

File tree

13 files changed

+302
-1
lines changed

13 files changed

+302
-1
lines changed

checker/src/main/java/dev/cel/checker/ExprChecker.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.auto.value.AutoValue;
2323
import com.google.common.base.Joiner;
2424
import com.google.common.base.Optional;
25+
import com.google.common.base.Strings;
2526
import com.google.common.collect.ImmutableList;
2627
import com.google.common.collect.Maps;
2728
import com.google.errorprone.annotations.CheckReturnValue;
@@ -510,13 +511,21 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
510511
CelType accuType = env.getType(visitedInit);
511512
CelType rangeType = inferenceContext.specialize(env.getType(visitedRange));
512513
CelType varType;
514+
CelType varType2 = null;
513515
switch (rangeType.kind()) {
514516
case LIST:
515517
varType = ((ListType) rangeType).elemType();
518+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
519+
varType2 = varType;
520+
varType = SimpleType.INT;
521+
}
516522
break;
517523
case MAP:
518524
// Ranges over the keys.
519525
varType = ((MapType) rangeType).keyType();
526+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
527+
varType2 = ((MapType) rangeType).valueType();
528+
}
520529
break;
521530
case DYN:
522531
case ERROR:
@@ -547,6 +556,9 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
547556
// Declare iteration variable on inner scope.
548557
env.enterScope();
549558
env.add(CelIdentDecl.newIdentDeclaration(compre.iterVar(), varType));
559+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
560+
env.add(CelIdentDecl.newIdentDeclaration(compre.iterVar2(), varType2));
561+
}
550562
CelExpr condition = visit(compre.loopCondition());
551563
assertType(condition, SimpleType.BOOL);
552564
CelExpr visitedStep = visit(compre.loopStep());

checker/src/test/java/dev/cel/checker/ExprCheckerTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,26 @@ public void quantifiers() throws Exception {
792792
runTest();
793793
}
794794

795+
@Test
796+
public void twoVarComprehensions() throws Exception {
797+
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
798+
declareVariable("x", messageType);
799+
source =
800+
"x.map_string_string.all(i, v, i < v) "
801+
+ "&& x.repeated_int64.all(i, v, i < v) "
802+
+ "&& [1, 2, 3, 4].all(i, v, i < 5 && v > 0)"
803+
+ "&& {'a': 1, 'b': 2}.all(k, v, k.startsWith('a') && v == 1)";
804+
runTest();
805+
}
806+
807+
@Test
808+
public void twoVarComprehensionsErrors() throws Exception {
809+
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
810+
declareVariable("x", messageType);
811+
source = "x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)";
812+
runTest();
813+
}
814+
795815
@Test
796816
public void quantifiersErrors() throws Exception {
797817
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
Source: x.map_string_string.all(i, v, i < v) && x.repeated_int64.all(i, v, i < v) && [1, 2, 3, 4].all(i, v, i < 5 && v > 0)&& {'a': 1, 'b': 2}.all(k, v, k.startsWith('a') && v == 1)
2+
declare x {
3+
value cel.expr.conformance.proto3.TestAllTypes
4+
}
5+
=====>
6+
_&&_(
7+
_&&_(
8+
__comprehension__(
9+
// Variable
10+
i,
11+
v,
12+
// Target
13+
x~cel.expr.conformance.proto3.TestAllTypes^x.map_string_string~map(string, string),
14+
// Accumulator
15+
@result,
16+
// Init
17+
true~bool,
18+
// LoopCondition
19+
@not_strictly_false(
20+
@result~bool^@result
21+
)~bool^not_strictly_false,
22+
// LoopStep
23+
_&&_(
24+
@result~bool^@result,
25+
_<_(
26+
i~string^i,
27+
v~string^v
28+
)~bool^less_string
29+
)~bool^logical_and,
30+
// Result
31+
@result~bool^@result)~bool,
32+
__comprehension__(
33+
// Variable
34+
i,
35+
v,
36+
// Target
37+
x~cel.expr.conformance.proto3.TestAllTypes^x.repeated_int64~list(int),
38+
// Accumulator
39+
@result,
40+
// Init
41+
true~bool,
42+
// LoopCondition
43+
@not_strictly_false(
44+
@result~bool^@result
45+
)~bool^not_strictly_false,
46+
// LoopStep
47+
_&&_(
48+
@result~bool^@result,
49+
_<_(
50+
i~int^i,
51+
v~int^v
52+
)~bool^less_int64
53+
)~bool^logical_and,
54+
// Result
55+
@result~bool^@result)~bool
56+
)~bool^logical_and,
57+
_&&_(
58+
__comprehension__(
59+
// Variable
60+
i,
61+
v,
62+
// Target
63+
[
64+
1~int,
65+
2~int,
66+
3~int,
67+
4~int
68+
]~list(int),
69+
// Accumulator
70+
@result,
71+
// Init
72+
true~bool,
73+
// LoopCondition
74+
@not_strictly_false(
75+
@result~bool^@result
76+
)~bool^not_strictly_false,
77+
// LoopStep
78+
_&&_(
79+
@result~bool^@result,
80+
_&&_(
81+
_<_(
82+
i~int^i,
83+
5~int
84+
)~bool^less_int64,
85+
_>_(
86+
v~int^v,
87+
0~int
88+
)~bool^greater_int64
89+
)~bool^logical_and
90+
)~bool^logical_and,
91+
// Result
92+
@result~bool^@result)~bool,
93+
__comprehension__(
94+
// Variable
95+
k,
96+
v,
97+
// Target
98+
{
99+
"a"~string:1~int,
100+
"b"~string:2~int
101+
}~map(string, int),
102+
// Accumulator
103+
@result,
104+
// Init
105+
true~bool,
106+
// LoopCondition
107+
@not_strictly_false(
108+
@result~bool^@result
109+
)~bool^not_strictly_false,
110+
// LoopStep
111+
_&&_(
112+
@result~bool^@result,
113+
_&&_(
114+
k~string^k.startsWith(
115+
"a"~string
116+
)~bool^starts_with_string,
117+
_==_(
118+
v~int^v,
119+
1~int
120+
)~bool^equals
121+
)~bool^logical_and
122+
)~bool^logical_and,
123+
// Result
124+
@result~bool^@result)~bool
125+
)~bool^logical_and
126+
)~bool^logical_and
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Source: x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
2+
declare x {
3+
value cel.expr.conformance.proto3.TestAllTypes
4+
}
5+
=====>
6+
ERROR: test_location:1:27: The argument must be a simple name
7+
| x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
8+
| ..........................^
9+
ERROR: test_location:1:71: The argument must be a simple name
10+
| x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
11+
| ......................................................................^

common/src/main/java/dev/cel/common/ast/CelExprFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ public final CelExpr.CelStruct.Entry newMessageField(String field, CelExpr value
138138
.build();
139139
}
140140

141-
/** Fold creates a fold comprehension instruction. */
141+
/** Fold creates a fold for one variable comprehension instruction. */
142142
public final CelExpr fold(
143143
String iterVar,
144144
CelExpr iterRange,

extensions/BUILD.bazel

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,8 @@ java_library(
5151
name = "sets_function",
5252
exports = ["//extensions/src/main/java/dev/cel/extensions:sets_function"],
5353
)
54+
55+
java_library(
56+
name = "comprehensions",
57+
exports = ["//extensions/src/main/java/dev/cel/extensions:comprehensions"],
58+
)

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ java_library(
3030
],
3131
deps = [
3232
":bindings",
33+
":comprehensions",
3334
":encoders",
3435
":lists",
3536
":math",
@@ -293,3 +294,17 @@ java_library(
293294
"@maven//:com_google_re2j_re2j",
294295
],
295296
)
297+
298+
java_library(
299+
name = "comprehensions",
300+
srcs = ["CelComprehensionsExtensions.java"],
301+
deps = [
302+
"//common:compiler_common",
303+
"//common/ast",
304+
"//compiler:compiler_builder",
305+
"//parser:macro",
306+
"//parser:operator",
307+
"//parser:parser_builder",
308+
"@maven//:com_google_guava_guava",
309+
],
310+
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dev.cel.extensions;
16+
17+
import static com.google.common.base.Preconditions.checkArgument;
18+
import static com.google.common.base.Preconditions.checkNotNull;
19+
20+
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableSet;
22+
import dev.cel.common.CelIssue;
23+
import dev.cel.common.ast.CelExpr;
24+
import dev.cel.compiler.CelCompilerLibrary;
25+
import dev.cel.parser.CelMacro;
26+
import dev.cel.parser.CelMacroExprFactory;
27+
import dev.cel.parser.CelParserBuilder;
28+
import dev.cel.parser.Operator;
29+
import java.util.Optional;
30+
31+
/** Internal implementation of CEL comprehensions extensions. */
32+
public final class CelComprehensionsExtensions implements CelCompilerLibrary {
33+
// TODO: Implement CelExtensionLibrary.FeatureSet interface.
34+
public ImmutableSet<CelMacro> macros() {
35+
return ImmutableSet.of(
36+
CelMacro.newReceiverMacro(
37+
Operator.ALL.getFunction(), 3, CelComprehensionsExtensions::expandAllMacro));
38+
}
39+
40+
@Override
41+
public void setParserOptions(CelParserBuilder parserBuilder) {
42+
parserBuilder.addMacros(macros());
43+
}
44+
45+
private static Optional<CelExpr> expandAllMacro(
46+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
47+
checkNotNull(exprFactory);
48+
checkNotNull(target);
49+
checkArgument(arguments.size() == 3);
50+
CelExpr arg0 = checkNotNull(arguments.get(0));
51+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
52+
return Optional.of(reportArgumentError(exprFactory, arg0));
53+
}
54+
CelExpr arg1 = checkNotNull(arguments.get(1));
55+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
56+
return Optional.of(reportArgumentError(exprFactory, arg1));
57+
}
58+
CelExpr arg2 = checkNotNull(arguments.get(2));
59+
CelExpr accuInit = exprFactory.newBoolLiteral(true);
60+
CelExpr condition =
61+
exprFactory.newGlobalCall(
62+
Operator.NOT_STRICTLY_FALSE.getFunction(),
63+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
64+
CelExpr step =
65+
exprFactory.newGlobalCall(
66+
Operator.LOGICAL_AND.getFunction(),
67+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
68+
arg2);
69+
CelExpr result = exprFactory.newIdentifier(exprFactory.getAccumulatorVarName());
70+
return Optional.of(
71+
exprFactory.fold(
72+
arg0.ident().name(),
73+
arg1.ident().name(),
74+
target,
75+
exprFactory.getAccumulatorVarName(),
76+
accuInit,
77+
condition,
78+
step,
79+
result));
80+
}
81+
82+
private static CelExpr reportArgumentError(CelMacroExprFactory exprFactory, CelExpr argument) {
83+
return exprFactory.reportError(
84+
CelIssue.formatError(
85+
exprFactory.getSourceLocation(argument), "The argument must be a simple name"));
86+
}
87+
}

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ public final class CelExtensions {
3434
private static final CelProtoExtensions PROTO_EXTENSIONS = new CelProtoExtensions();
3535
private static final CelBindingsExtensions BINDINGS_EXTENSIONS = new CelBindingsExtensions();
3636
private static final CelRegexExtensions REGEX_EXTENSIONS = new CelRegexExtensions();
37+
private static final CelComprehensionsExtensions COMPREHENSIONS_EXTENSIONS =
38+
new CelComprehensionsExtensions();
3739

3840
/**
3941
* Implementation of optional values.
@@ -305,6 +307,20 @@ public static CelRegexExtensions regex() {
305307
return REGEX_EXTENSIONS;
306308
}
307309

310+
/**
311+
* Extended functions for Two Variable Comprehensions Expressions.
312+
*
313+
* <p>Refer to README.md for available functions.
314+
*
315+
* <p>This will include all functions denoted in {@link CelComprehensionsExtensions.Function},
316+
* including any future additions.
317+
*/
318+
// TODO: Remove visibility restrictions and make this public once the feature is
319+
// ready.
320+
private static CelComprehensionsExtensions comprehensions() {
321+
return COMPREHENSIONS_EXTENSIONS;
322+
}
323+
308324
/**
309325
* Retrieves all function names used by every extension libraries.
310326
*

parser/src/main/java/dev/cel/parser/CelMacroExprFactory.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ public final CelExpr copy(CelExpr expr) {
133133
builder.setComprehension(
134134
CelExpr.CelComprehension.newBuilder()
135135
.setIterVar(expr.comprehension().iterVar())
136+
.setIterVar2(expr.comprehension().iterVar2())
136137
.setIterRange(copy(expr.comprehension().iterRange()))
137138
.setAccuVar(expr.comprehension().accuVar())
138139
.setAccuInit(copy(expr.comprehension().accuInit()))

0 commit comments

Comments
 (0)