Skip to content

Commit 35ee644

Browse files
authored
ESQL: Check a few toStrings for aggs (#132700)
Adds `toString` checking for aggregators to the generic aggs test cases so we can make sure they spit out sensible looking results. We have this for scalar functions but it isn't plugged in for aggs and I noticed it while working on #132603 where I stuck `asdf` for the toString thinking I'd fix it when the test failed. It didn't. There's to many changes to grab this in one go so I've made a hook that tests can opt into. We'll drop the hook once everything has opted into it.
1 parent 063fa89 commit 35ee644

File tree

5 files changed

+109
-28
lines changed

5 files changed

+109
-28
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242
import java.util.stream.IntStream;
4343

4444
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
45+
import static org.hamcrest.Matchers.endsWith;
4546
import static org.hamcrest.Matchers.equalTo;
4647
import static org.hamcrest.Matchers.instanceOf;
4748
import static org.hamcrest.Matchers.is;
4849
import static org.hamcrest.Matchers.lessThan;
4950
import static org.hamcrest.Matchers.not;
5051
import static org.hamcrest.Matchers.nullValue;
5152
import static org.hamcrest.Matchers.oneOf;
53+
import static org.hamcrest.Matchers.startsWith;
5254

5355
/**
5456
* Base class for aggregation tests.
@@ -140,12 +142,30 @@ public void testAggregate() {
140142
resolveExpression(expression, this::aggregateSingleMode, this::evaluate);
141143
}
142144

145+
public void testAggregateToString() {
146+
Expression expression = randomBoolean() ? buildDeepCopyOfFieldExpression(testCase) : buildFieldExpression(testCase);
147+
resolveExpression(expression, e -> {
148+
try (var aggregator = aggregator(e, initialInputChannels(), AggregatorMode.SINGLE)) {
149+
assertAggregatorToString(aggregator);
150+
}
151+
}, this::evaluate);
152+
}
153+
143154
public void testGroupingAggregate() {
144155
Expression expression = randomBoolean() ? buildDeepCopyOfFieldExpression(testCase) : buildFieldExpression(testCase);
145156

146157
resolveExpression(expression, this::aggregateGroupingSingleMode, this::evaluate);
147158
}
148159

160+
public void testGroupingAggregateToString() {
161+
Expression expression = randomBoolean() ? buildDeepCopyOfFieldExpression(testCase) : buildFieldExpression(testCase);
162+
resolveExpression(expression, e -> {
163+
try (var aggregator = groupingAggregator(e, initialInputChannels(), AggregatorMode.SINGLE)) {
164+
assertAggregatorToString(aggregator);
165+
}
166+
}, this::evaluate);
167+
}
168+
149169
public void testAggregateIntermediate() {
150170
Expression expression = randomBoolean() ? buildDeepCopyOfFieldExpression(testCase) : buildFieldExpression(testCase);
151171

@@ -164,6 +184,7 @@ public void testFold() {
164184
private void aggregateSingleMode(Expression expression) {
165185
Object result;
166186
try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
187+
assertAggregatorToString(aggregator);
167188
for (Page inputPage : rows(testCase.getMultiRowFields())) {
168189
try (
169190
BooleanVector noMasking = driverContext().blockFactory().newConstantBooleanVector(true, inputPage.getPositionCount())
@@ -187,6 +208,7 @@ private void aggregateGroupingSingleMode(Expression expression) {
187208
assumeFalse("Grouping aggregations must receive data to check results", pages.isEmpty());
188209

189210
try (var aggregator = groupingAggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
211+
assertAggregatorToString(aggregator);
190212
var groupCount = randomIntBetween(1, 1000);
191213
for (Page inputPage : pages) {
192214
processPageGrouping(aggregator, inputPage, groupCount);
@@ -482,4 +504,43 @@ private void processPageGrouping(GroupingAggregator aggregator, Page inputPage,
482504
}
483505
}
484506
}
507+
508+
private void assertAggregatorToString(Object aggregator) {
509+
if (optIntoToAggregatorToStringChecks() == false) {
510+
return;
511+
}
512+
String expectedStart = switch (aggregator) {
513+
case Aggregator a -> "Aggregator[aggregatorFunction=";
514+
case GroupingAggregator a -> "GroupingAggregator[aggregatorFunction=";
515+
default -> throw new UnsupportedOperationException("can't check toString for [" + aggregator.getClass() + "]");
516+
};
517+
String expectedEnd = switch (aggregator) {
518+
case Aggregator a -> "AggregatorFunction[channels=[0]], mode=SINGLE]";
519+
case GroupingAggregator a -> "GroupingAggregatorFunction[channels=[0]], mode=SINGLE]";
520+
default -> throw new UnsupportedOperationException("can't check toString for [" + aggregator.getClass() + "]");
521+
};
522+
523+
String toString = aggregator.toString();
524+
assertThat(toString, startsWith(expectedStart));
525+
assertThat(toString.substring(expectedStart.length(), toString.length() - expectedEnd.length()), testCase.evaluatorToString());
526+
assertThat(toString, endsWith(expectedEnd));
527+
}
528+
529+
protected boolean optIntoToAggregatorToStringChecks() {
530+
// TODO remove this when everyone has opted in
531+
return false;
532+
}
533+
534+
protected static String standardAggregatorName(String prefix, DataType type) {
535+
String typeName = switch (type) {
536+
case BOOLEAN -> "Boolean";
537+
case KEYWORD, TEXT, VERSION -> "BytesRef";
538+
case DOUBLE -> "Double";
539+
case INTEGER -> "Int";
540+
case IP -> "Ip";
541+
case DATETIME, DATE_NANOS, LONG, UNSIGNED_LONG -> "Long";
542+
default -> throw new UnsupportedOperationException("name for [" + type + "]");
543+
};
544+
return prefix + typeName;
545+
}
485546
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public static Iterable<Object[]> parameters() {
7373
List.of(DataType.INTEGER),
7474
() -> new TestCaseSupplier.TestCase(
7575
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")),
76-
"Max[field=Attribute[channel=0]]",
76+
standardAggregatorName("Max", DataType.INTEGER),
7777
DataType.INTEGER,
7878
equalTo(200)
7979
)
@@ -82,7 +82,7 @@ public static Iterable<Object[]> parameters() {
8282
List.of(DataType.LONG),
8383
() -> new TestCaseSupplier.TestCase(
8484
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")),
85-
"Max[field=Attribute[channel=0]]",
85+
standardAggregatorName("Max", DataType.LONG),
8686
DataType.LONG,
8787
equalTo(200L)
8888
)
@@ -94,7 +94,7 @@ public static Iterable<Object[]> parameters() {
9494
TestCaseSupplier.TypedData.multiRow(List.of(new BigInteger("200")), DataType.UNSIGNED_LONG, "field")
9595
.withAppliesTo(unsignedLongAppliesTo)
9696
),
97-
"Max[field=Attribute[channel=0]]",
97+
standardAggregatorName("Max", DataType.UNSIGNED_LONG),
9898
DataType.UNSIGNED_LONG,
9999
equalTo(new BigInteger("200"))
100100
)
@@ -103,7 +103,7 @@ public static Iterable<Object[]> parameters() {
103103
List.of(DataType.DOUBLE),
104104
() -> new TestCaseSupplier.TestCase(
105105
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")),
106-
"Max[field=Attribute[channel=0]]",
106+
standardAggregatorName("Max", DataType.DOUBLE),
107107
DataType.DOUBLE,
108108
equalTo(200.)
109109
)
@@ -112,7 +112,7 @@ public static Iterable<Object[]> parameters() {
112112
List.of(DataType.DATETIME),
113113
() -> new TestCaseSupplier.TestCase(
114114
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")),
115-
"Max[field=Attribute[channel=0]]",
115+
standardAggregatorName("Max", DataType.DATETIME),
116116
DataType.DATETIME,
117117
equalTo(200L)
118118
)
@@ -121,7 +121,7 @@ public static Iterable<Object[]> parameters() {
121121
List.of(DataType.DATE_NANOS),
122122
() -> new TestCaseSupplier.TestCase(
123123
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATE_NANOS, "field")),
124-
"Max[field=Attribute[channel=0]]",
124+
standardAggregatorName("Max", DataType.DATE_NANOS),
125125
DataType.DATE_NANOS,
126126
equalTo(200L)
127127
)
@@ -130,7 +130,7 @@ public static Iterable<Object[]> parameters() {
130130
List.of(DataType.BOOLEAN),
131131
() -> new TestCaseSupplier.TestCase(
132132
List.of(TestCaseSupplier.TypedData.multiRow(List.of(true), DataType.BOOLEAN, "field")),
133-
"Max[field=Attribute[channel=0]]",
133+
standardAggregatorName("Max", DataType.BOOLEAN),
134134
DataType.BOOLEAN,
135135
equalTo(true)
136136
)
@@ -145,7 +145,7 @@ public static Iterable<Object[]> parameters() {
145145
"field"
146146
)
147147
),
148-
"Max[field=Attribute[channel=0]]",
148+
standardAggregatorName("Max", DataType.IP),
149149
DataType.IP,
150150
equalTo(new BytesRef(InetAddressPoint.encode(InetAddresses.forString("127.0.0.1"))))
151151
)
@@ -154,7 +154,7 @@ public static Iterable<Object[]> parameters() {
154154
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
155155
return new TestCaseSupplier.TestCase(
156156
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.KEYWORD, "field")),
157-
"Max[field=Attribute[channel=0]]",
157+
standardAggregatorName("Max", DataType.KEYWORD),
158158
DataType.KEYWORD,
159159
equalTo(value)
160160
);
@@ -163,7 +163,7 @@ public static Iterable<Object[]> parameters() {
163163
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
164164
return new TestCaseSupplier.TestCase(
165165
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.TEXT, "field")),
166-
"Max[field=Attribute[channel=0]]",
166+
standardAggregatorName("Max", DataType.TEXT),
167167
DataType.KEYWORD,
168168
equalTo(value)
169169
);
@@ -175,7 +175,7 @@ public static Iterable<Object[]> parameters() {
175175
.toBytesRef();
176176
return new TestCaseSupplier.TestCase(
177177
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.VERSION, "field")),
178-
"Max[field=Attribute[channel=0]]",
178+
standardAggregatorName("Max", DataType.VERSION),
179179
DataType.VERSION,
180180
equalTo(value)
181181
);
@@ -203,10 +203,15 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
203203

204204
return new TestCaseSupplier.TestCase(
205205
List.of(fieldTypedData),
206-
"Max[field=Attribute[channel=0]]",
206+
standardAggregatorName("Max", fieldSupplier.type()),
207207
fieldSupplier.type(),
208208
equalTo(expected)
209209
);
210210
});
211211
}
212+
213+
@Override
214+
protected boolean optIntoToAggregatorToStringChecks() {
215+
return true;
216+
}
212217
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public static Iterable<Object[]> parameters() {
4747
List.of(DataType.INTEGER),
4848
() -> new TestCaseSupplier.TestCase(
4949
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "number")),
50-
"Median[field=Attribute[channel=0]]",
50+
standardAggregatorName("Percentile", DataType.INTEGER),
5151
DataType.DOUBLE,
5252
equalTo(200.)
5353
)
@@ -56,7 +56,7 @@ public static Iterable<Object[]> parameters() {
5656
List.of(DataType.LONG),
5757
() -> new TestCaseSupplier.TestCase(
5858
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "number")),
59-
"Median[field=Attribute[channel=0]]",
59+
standardAggregatorName("Percentile", DataType.LONG),
6060
DataType.DOUBLE,
6161
equalTo(200.)
6262
)
@@ -65,7 +65,7 @@ public static Iterable<Object[]> parameters() {
6565
List.of(DataType.DOUBLE),
6666
() -> new TestCaseSupplier.TestCase(
6767
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "number")),
68-
"Median[field=Attribute[channel=0]]",
68+
standardAggregatorName("Percentile", DataType.DOUBLE),
6969
DataType.DOUBLE,
7070
equalTo(200.)
7171
)
@@ -94,11 +94,16 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
9494

9595
return new TestCaseSupplier.TestCase(
9696
List.of(fieldTypedData),
97-
"Median[number=Attribute[channel=0]]",
97+
standardAggregatorName("Percentile", fieldSupplier.type()),
9898
DataType.DOUBLE,
9999
equalTo(expected)
100100
);
101101
}
102102
});
103103
}
104+
105+
@Override
106+
protected boolean optIntoToAggregatorToStringChecks() {
107+
return true;
108+
}
104109
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public static Iterable<Object[]> parameters() {
7373
List.of(DataType.INTEGER),
7474
() -> new TestCaseSupplier.TestCase(
7575
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")),
76-
"Min[field=Attribute[channel=0]]",
76+
standardAggregatorName("Min", DataType.INTEGER),
7777
DataType.INTEGER,
7878
equalTo(200)
7979
)
@@ -82,7 +82,7 @@ public static Iterable<Object[]> parameters() {
8282
List.of(DataType.LONG),
8383
() -> new TestCaseSupplier.TestCase(
8484
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")),
85-
"Min[field=Attribute[channel=0]]",
85+
standardAggregatorName("Min", DataType.LONG),
8686
DataType.LONG,
8787
equalTo(200L)
8888
)
@@ -94,7 +94,7 @@ public static Iterable<Object[]> parameters() {
9494
TestCaseSupplier.TypedData.multiRow(List.of(new BigInteger("200")), DataType.UNSIGNED_LONG, "field")
9595
.withAppliesTo(unsignedLongAppliesTo)
9696
),
97-
"Max[field=Attribute[channel=0]]",
97+
standardAggregatorName("Min", DataType.UNSIGNED_LONG),
9898
DataType.UNSIGNED_LONG,
9999
equalTo(new BigInteger("200"))
100100
)
@@ -103,7 +103,7 @@ public static Iterable<Object[]> parameters() {
103103
List.of(DataType.DOUBLE),
104104
() -> new TestCaseSupplier.TestCase(
105105
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")),
106-
"Min[field=Attribute[channel=0]]",
106+
standardAggregatorName("Min", DataType.DOUBLE),
107107
DataType.DOUBLE,
108108
equalTo(200.)
109109
)
@@ -112,7 +112,7 @@ public static Iterable<Object[]> parameters() {
112112
List.of(DataType.DATETIME),
113113
() -> new TestCaseSupplier.TestCase(
114114
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")),
115-
"Min[field=Attribute[channel=0]]",
115+
standardAggregatorName("Min", DataType.DATETIME),
116116
DataType.DATETIME,
117117
equalTo(200L)
118118
)
@@ -121,7 +121,7 @@ public static Iterable<Object[]> parameters() {
121121
List.of(DataType.DATE_NANOS),
122122
() -> new TestCaseSupplier.TestCase(
123123
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATE_NANOS, "field")),
124-
"Min[field=Attribute[channel=0]]",
124+
standardAggregatorName("Min", DataType.DATE_NANOS),
125125
DataType.DATE_NANOS,
126126
equalTo(200L)
127127
)
@@ -130,7 +130,7 @@ public static Iterable<Object[]> parameters() {
130130
List.of(DataType.BOOLEAN),
131131
() -> new TestCaseSupplier.TestCase(
132132
List.of(TestCaseSupplier.TypedData.multiRow(List.of(true), DataType.BOOLEAN, "field")),
133-
"Min[field=Attribute[channel=0]]",
133+
standardAggregatorName("Min", DataType.BOOLEAN),
134134
DataType.BOOLEAN,
135135
equalTo(true)
136136
)
@@ -145,7 +145,7 @@ public static Iterable<Object[]> parameters() {
145145
"field"
146146
)
147147
),
148-
"Min[field=Attribute[channel=0]]",
148+
standardAggregatorName("Min", DataType.IP),
149149
DataType.IP,
150150
equalTo(new BytesRef(InetAddressPoint.encode(InetAddresses.forString("127.0.0.1"))))
151151
)
@@ -154,7 +154,7 @@ public static Iterable<Object[]> parameters() {
154154
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
155155
return new TestCaseSupplier.TestCase(
156156
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.KEYWORD, "field")),
157-
"Min[field=Attribute[channel=0]]",
157+
standardAggregatorName("Min", DataType.KEYWORD),
158158
DataType.KEYWORD,
159159
equalTo(value)
160160
);
@@ -163,7 +163,7 @@ public static Iterable<Object[]> parameters() {
163163
var value = new BytesRef(randomAlphaOfLengthBetween(0, 50));
164164
return new TestCaseSupplier.TestCase(
165165
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.TEXT, "field")),
166-
"Min[field=Attribute[channel=0]]",
166+
standardAggregatorName("Min", DataType.TEXT),
167167
DataType.KEYWORD,
168168
equalTo(value)
169169
);
@@ -175,7 +175,7 @@ public static Iterable<Object[]> parameters() {
175175
.toBytesRef();
176176
return new TestCaseSupplier.TestCase(
177177
List.of(TestCaseSupplier.TypedData.multiRow(List.of(value), DataType.VERSION, "field")),
178-
"Min[field=Attribute[channel=0]]",
178+
standardAggregatorName("Min", DataType.VERSION),
179179
DataType.VERSION,
180180
equalTo(value)
181181
);
@@ -203,10 +203,15 @@ private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier
203203

204204
return new TestCaseSupplier.TestCase(
205205
List.of(fieldTypedData),
206-
"Min[field=Attribute[channel=0]]",
206+
standardAggregatorName("Min", fieldSupplier.type()),
207207
fieldSupplier.type(),
208208
equalTo(expected)
209209
);
210210
});
211211
}
212+
213+
@Override
214+
protected boolean optIntoToAggregatorToStringChecks() {
215+
return true;
216+
}
212217
}

0 commit comments

Comments
 (0)