Skip to content

Commit 1519150

Browse files
committed
Handle retract elements in JSON array agg (#186)
1 parent 50c7189 commit 1519150

File tree

5 files changed

+118
-65
lines changed

5 files changed

+118
-65
lines changed

stdlib/stdlib-json/src/main/java/com/datasqrl/flinkrunner/stdlib/json/ArrayAgg.java renamed to stdlib/stdlib-json/src/main/java/com/datasqrl/flinkrunner/stdlib/json/ArrayAggAccumulator.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,31 @@
1515
*/
1616
package com.datasqrl.flinkrunner.stdlib.json;
1717

18-
import java.util.List;
18+
import java.util.LinkedList;
1919
import lombok.AllArgsConstructor;
2020
import lombok.Getter;
2121
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
2222
import org.apache.flink.table.annotation.DataTypeHint;
2323

2424
@AllArgsConstructor
2525
@Getter
26-
public final class ArrayAgg {
26+
public final class ArrayAggAccumulator {
2727

2828
@DataTypeHint("RAW")
29-
private final List<JsonNode> objects;
29+
private final LinkedList<JsonNode> elements;
30+
31+
@DataTypeHint("RAW")
32+
private final LinkedList<JsonNode> retractElements;
3033

3134
public void add(JsonNode value) {
32-
objects.add(value);
35+
elements.add(value);
36+
}
37+
38+
public void addRetract(JsonNode value) {
39+
retractElements.add(value);
3340
}
3441

35-
public void remove(JsonNode value) {
36-
objects.remove(value);
42+
public boolean remove(JsonNode value) {
43+
return elements.remove(value);
3744
}
3845
}

stdlib/stdlib-json/src/main/java/com/datasqrl/flinkrunner/stdlib/json/ObjectAgg.java renamed to stdlib/stdlib-json/src/main/java/com/datasqrl/flinkrunner/stdlib/json/ObjectAggAccumulator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
@AllArgsConstructor
2525
@Getter
26-
public final class ObjectAgg {
26+
public final class ObjectAggAccumulator {
2727

2828
@DataTypeHint("RAW")
2929
private final Map<String, JsonNode> objects;

stdlib/stdlib-json/src/main/java/com/datasqrl/flinkrunner/stdlib/json/jsonb_array_agg.java

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,83 +17,101 @@
1717

1818
import com.datasqrl.flinkrunner.stdlib.utils.AutoRegisterSystemFunction;
1919
import com.google.auto.service.AutoService;
20-
import java.util.ArrayList;
21-
import lombok.SneakyThrows;
20+
import java.util.LinkedList;
21+
import java.util.List;
22+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
2223
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
2324
import org.apache.flink.table.functions.AggregateFunction;
2425
import org.apache.flink.util.jackson.JacksonMapperFactory;
2526

2627
/** Aggregation function that aggregates JSON objects into a JSON array. */
2728
@AutoService(AutoRegisterSystemFunction.class)
28-
public class jsonb_array_agg extends AggregateFunction<FlinkJsonType, ArrayAgg>
29+
public class jsonb_array_agg extends AggregateFunction<FlinkJsonType, ArrayAggAccumulator>
2930
implements AutoRegisterSystemFunction {
3031

3132
private static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper();
3233

3334
@Override
34-
public ArrayAgg createAccumulator() {
35-
return new ArrayAgg(new ArrayList<>());
35+
public ArrayAggAccumulator createAccumulator() {
36+
return new ArrayAggAccumulator(new LinkedList<>(), new LinkedList<>());
3637
}
3738

38-
public void accumulate(ArrayAgg accumulator, String value) {
39-
accumulator.add(mapper.getNodeFactory().textNode(value));
39+
public void accumulate(ArrayAggAccumulator acc, String value) {
40+
acc.add(mapper.getNodeFactory().textNode(value));
4041
}
4142

42-
@SneakyThrows
43-
public void accumulate(ArrayAgg accumulator, FlinkJsonType value) {
44-
if (value != null) {
45-
accumulator.add(value.json);
46-
} else {
47-
accumulator.add(null);
48-
}
43+
public void accumulate(ArrayAggAccumulator acc, FlinkJsonType value) {
44+
acc.add(value == null ? null : value.json);
4945
}
5046

51-
public void accumulate(ArrayAgg accumulator, Double value) {
52-
accumulator.add(mapper.getNodeFactory().numberNode(value));
47+
public void accumulate(ArrayAggAccumulator acc, Double value) {
48+
acc.add(mapper.getNodeFactory().numberNode(value));
5349
}
5450

55-
public void accumulate(ArrayAgg accumulator, Long value) {
56-
accumulator.add(mapper.getNodeFactory().numberNode(value));
51+
public void accumulate(ArrayAggAccumulator acc, Long value) {
52+
acc.add(mapper.getNodeFactory().numberNode(value));
5753
}
5854

59-
public void accumulate(ArrayAgg accumulator, Integer value) {
60-
accumulator.add(mapper.getNodeFactory().numberNode(value));
55+
public void accumulate(ArrayAggAccumulator acc, Integer value) {
56+
acc.add(mapper.getNodeFactory().numberNode(value));
6157
}
6258

63-
public void retract(ArrayAgg accumulator, String value) {
64-
accumulator.remove(mapper.getNodeFactory().textNode(value));
59+
public void retract(ArrayAggAccumulator acc, String value) {
60+
var nodeVal = mapper.getNodeFactory().textNode(value);
61+
if (!acc.remove(nodeVal)) {
62+
acc.addRetract(nodeVal);
63+
}
6564
}
6665

67-
@SneakyThrows
68-
public void retract(ArrayAgg accumulator, FlinkJsonType value) {
69-
if (value != null) {
70-
accumulator.remove(value.json);
71-
} else {
72-
accumulator.remove(null);
66+
public void retract(ArrayAggAccumulator acc, FlinkJsonType value) {
67+
var finalVal = value == null ? null : value.json;
68+
if (!acc.remove(finalVal)) {
69+
acc.addRetract(finalVal);
7370
}
7471
}
7572

76-
public void retract(ArrayAgg accumulator, Double value) {
77-
accumulator.remove(mapper.getNodeFactory().numberNode(value));
73+
public void retract(ArrayAggAccumulator acc, Double value) {
74+
var nodeVal = mapper.getNodeFactory().numberNode(value);
75+
if (!acc.getElements().remove(nodeVal)) {
76+
acc.addRetract(nodeVal);
77+
}
7878
}
7979

80-
public void retract(ArrayAgg accumulator, Long value) {
81-
accumulator.remove(mapper.getNodeFactory().numberNode(value));
80+
public void retract(ArrayAggAccumulator acc, Long value) {
81+
var nodeVal = mapper.getNodeFactory().numberNode(value);
82+
if (!acc.getElements().remove(nodeVal)) {
83+
acc.addRetract(nodeVal);
84+
}
8285
}
8386

84-
public void retract(ArrayAgg accumulator, Integer value) {
85-
accumulator.remove(mapper.getNodeFactory().numberNode(value));
87+
public void retract(ArrayAggAccumulator acc, Integer value) {
88+
var nodeVal = mapper.getNodeFactory().numberNode(value);
89+
if (!acc.getElements().remove(nodeVal)) {
90+
acc.addRetract(nodeVal);
91+
}
8692
}
8793

88-
public void merge(ArrayAgg accumulator, java.lang.Iterable<ArrayAgg> iterable) {
89-
iterable.forEach(o -> accumulator.getObjects().addAll(o.getObjects()));
94+
public void merge(ArrayAggAccumulator acc, Iterable<ArrayAggAccumulator> iterable) {
95+
for (ArrayAggAccumulator otherAcc : iterable) {
96+
acc.getElements().addAll(otherAcc.getElements());
97+
acc.getRetractElements().addAll(otherAcc.getRetractElements());
98+
}
99+
100+
List<JsonNode> newRetractBuffer = new LinkedList<>();
101+
for (JsonNode elem : acc.getRetractElements()) {
102+
if (!acc.remove(elem)) {
103+
newRetractBuffer.add(elem);
104+
}
105+
}
106+
107+
acc.getRetractElements().clear();
108+
acc.getRetractElements().addAll(newRetractBuffer);
90109
}
91110

92111
@Override
93-
public FlinkJsonType getValue(ArrayAgg accumulator) {
94-
// Replacing var with explicit type declaration for Java 11 compatibility
112+
public FlinkJsonType getValue(ArrayAggAccumulator acc) {
95113
var arrayNode = mapper.createArrayNode();
96-
for (Object o : accumulator.getObjects()) {
114+
for (Object o : acc.getElements()) {
97115
if (o instanceof FlinkJsonType) {
98116
arrayNode.add(((FlinkJsonType) o).json);
99117
} else {

stdlib/stdlib-json/src/main/java/com/datasqrl/flinkrunner/stdlib/json/jsonb_object_agg.java

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,76 +36,81 @@
3636
bridgedTo = FlinkJsonType.class,
3737
rawSerializer = FlinkJsonTypeSerializer.class))
3838
@AutoService(AutoRegisterSystemFunction.class)
39-
public class jsonb_object_agg extends AggregateFunction<Object, ObjectAgg>
39+
public class jsonb_object_agg extends AggregateFunction<Object, ObjectAggAccumulator>
4040
implements AutoRegisterSystemFunction {
4141

4242
private static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper();
4343

4444
@Override
45-
public ObjectAgg createAccumulator() {
46-
return new ObjectAgg(new LinkedHashMap<>());
45+
public ObjectAggAccumulator createAccumulator() {
46+
return new ObjectAggAccumulator(new LinkedHashMap<>());
4747
}
4848

49-
public void accumulate(ObjectAgg accumulator, String key, String value) {
49+
public void accumulate(ObjectAggAccumulator accumulator, String key, String value) {
5050
accumulateObject(accumulator, key, value);
5151
}
5252

5353
public void accumulate(
54-
ObjectAgg accumulator, String key, @DataTypeHint(inputGroup = InputGroup.ANY) Object value) {
54+
ObjectAggAccumulator accumulator,
55+
String key,
56+
@DataTypeHint(inputGroup = InputGroup.ANY) Object value) {
5557
if (value instanceof FlinkJsonType) {
5658
accumulateObject(accumulator, key, ((FlinkJsonType) value).getJson());
5759
} else {
5860
accumulator.add(key, mapper.getNodeFactory().pojoNode(value));
5961
}
6062
}
6163

62-
public void accumulate(ObjectAgg accumulator, String key, Double value) {
64+
public void accumulate(ObjectAggAccumulator accumulator, String key, Double value) {
6365
accumulateObject(accumulator, key, value);
6466
}
6567

66-
public void accumulate(ObjectAgg accumulator, String key, Long value) {
68+
public void accumulate(ObjectAggAccumulator accumulator, String key, Long value) {
6769
accumulateObject(accumulator, key, value);
6870
}
6971

70-
public void accumulate(ObjectAgg accumulator, String key, Integer value) {
72+
public void accumulate(ObjectAggAccumulator accumulator, String key, Integer value) {
7173
accumulateObject(accumulator, key, value);
7274
}
7375

74-
public void accumulateObject(ObjectAgg accumulator, String key, Object value) {
76+
public void accumulateObject(ObjectAggAccumulator accumulator, String key, Object value) {
7577
accumulator.add(key, mapper.getNodeFactory().pojoNode(value));
7678
}
7779

78-
public void retract(ObjectAgg accumulator, String key, String value) {
80+
public void retract(ObjectAggAccumulator accumulator, String key, String value) {
7981
retractObject(accumulator, key);
8082
}
8183

8284
public void retract(
83-
ObjectAgg accumulator, String key, @DataTypeHint(inputGroup = InputGroup.ANY) Object value) {
85+
ObjectAggAccumulator accumulator,
86+
String key,
87+
@DataTypeHint(inputGroup = InputGroup.ANY) Object value) {
8488
retractObject(accumulator, key);
8589
}
8690

87-
public void retract(ObjectAgg accumulator, String key, Double value) {
91+
public void retract(ObjectAggAccumulator accumulator, String key, Double value) {
8892
retractObject(accumulator, key);
8993
}
9094

91-
public void retract(ObjectAgg accumulator, String key, Long value) {
95+
public void retract(ObjectAggAccumulator accumulator, String key, Long value) {
9296
retractObject(accumulator, key);
9397
}
9498

95-
public void retract(ObjectAgg accumulator, String key, Integer value) {
99+
public void retract(ObjectAggAccumulator accumulator, String key, Integer value) {
96100
retractObject(accumulator, key);
97101
}
98102

99-
public void retractObject(ObjectAgg accumulator, String key) {
103+
public void retractObject(ObjectAggAccumulator accumulator, String key) {
100104
accumulator.remove(key);
101105
}
102106

103-
public void merge(ObjectAgg accumulator, java.lang.Iterable<ObjectAgg> iterable) {
107+
public void merge(
108+
ObjectAggAccumulator accumulator, java.lang.Iterable<ObjectAggAccumulator> iterable) {
104109
iterable.forEach(o -> accumulator.getObjects().putAll(o.getObjects()));
105110
}
106111

107112
@Override
108-
public FlinkJsonType getValue(ObjectAgg accumulator) {
113+
public FlinkJsonType getValue(ObjectAggAccumulator accumulator) {
109114
var objectNode = mapper.createObjectNode();
110115
accumulator.getObjects().forEach(objectNode::putPOJO);
111116
return new FlinkJsonType(objectNode);

stdlib/stdlib-json/src/test/java/com/datasqrl/flinkrunner/stdlib/json/JsonFunctionsTest.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.assertj.core.api.Assertions.assertThat;
1919
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2020

21+
import java.util.List;
2122
import lombok.SneakyThrows;
2223
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
2324
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
@@ -317,7 +318,7 @@ void testNullInput2() {
317318
}
318319

319320
@Nested
320-
class JsonArrayAggTest {
321+
class JsonArrayAggAccumulatorTest {
321322

322323
@Test
323324
void testAggregateJsonTypes() {
@@ -406,10 +407,32 @@ void testRetractNullFromNonExisting() {
406407
assertThat(result).isNotNull();
407408
assertThat(result.getJson().toString()).isEqualTo("[{\"key\":\"value1\"}]");
408409
}
410+
411+
@Test
412+
void testMergeWithProperRetract() {
413+
var acc = JsonFunctions.JSON_ARRAYAGG.createAccumulator();
414+
JsonFunctions.JSON_ARRAYAGG.accumulate(acc, 1);
415+
JsonFunctions.JSON_ARRAYAGG.accumulate(acc, 2);
416+
JsonFunctions.JSON_ARRAYAGG.accumulate(acc, 3);
417+
JsonFunctions.JSON_ARRAYAGG.retract(acc, 4);
418+
419+
var otherAcc = JsonFunctions.JSON_ARRAYAGG.createAccumulator();
420+
JsonFunctions.JSON_ARRAYAGG.accumulate(otherAcc, 1);
421+
JsonFunctions.JSON_ARRAYAGG.accumulate(otherAcc, 2);
422+
JsonFunctions.JSON_ARRAYAGG.accumulate(otherAcc, 3);
423+
JsonFunctions.JSON_ARRAYAGG.accumulate(otherAcc, 4);
424+
JsonFunctions.JSON_ARRAYAGG.accumulate(otherAcc, 5);
425+
JsonFunctions.JSON_ARRAYAGG.retract(otherAcc, 1);
426+
427+
JsonFunctions.JSON_ARRAYAGG.merge(acc, List.of(otherAcc));
428+
429+
var res = JsonFunctions.JSON_ARRAYAGG.getValue(acc);
430+
assertThat(res.getJson().toString()).isEqualTo("[1,2,3,2,3,5]");
431+
}
409432
}
410433

411434
@Nested
412-
class JsonObjectAggTest {
435+
class JsonObjectAggAccumulatorTest {
413436

414437
@Test
415438
void testAggregateJsonTypes() {

0 commit comments

Comments
 (0)