Skip to content

Commit d849199

Browse files
author
Julien Ruaux
committed
feat: Added support for UPDATE statements
1 parent 4c1336b commit d849199

9 files changed

+162
-49
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ build/
66
!**/src/test/**/build/
77
hs_err*.log
88
test-output/
9+
*.dylib
910

1011
### STS ###
1112
.checkstyle

src/main/java/com/redis/trino/RediSearchMetadata.java

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@
2626
import static com.google.common.base.Preconditions.checkState;
2727
import static com.google.common.base.Verify.verify;
2828
import static com.google.common.base.Verify.verifyNotNull;
29+
import static com.google.common.collect.ImmutableList.toImmutableList;
2930
import static com.google.common.collect.ImmutableSet.toImmutableSet;
3031
import static io.airlift.slice.SliceUtf8.getCodePointAt;
3132
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
32-
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
33-
import static io.trino.spi.connector.RetryMode.NO_RETRIES;
3433
import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME;
3534
import static java.util.Objects.requireNonNull;
3635

@@ -205,9 +204,7 @@ public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandl
205204
@Override
206205
public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata,
207206
Optional<ConnectorTableLayout> layout, RetryMode retryMode) {
208-
if (retryMode != RetryMode.NO_RETRIES) {
209-
throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support query retries");
210-
}
207+
checkRetry(retryMode);
211208
List<RediSearchColumnHandle> columns = buildColumnHandles(tableMetadata);
212209

213210
rediSearchSession.createTable(tableMetadata.getTable(), columns);
@@ -218,6 +215,12 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con
218215
columns.stream().filter(c -> !c.isHidden()).collect(Collectors.toList()));
219216
}
220217

218+
private void checkRetry(RetryMode retryMode) {
219+
if (retryMode != RetryMode.NO_RETRIES) {
220+
throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support retries");
221+
}
222+
}
223+
221224
@Override
222225
public Optional<ConnectorOutputMetadata> finishCreateTable(ConnectorSession session,
223226
ConnectorOutputTableHandle tableHandle, Collection<Slice> fragments,
@@ -229,9 +232,7 @@ public Optional<ConnectorOutputMetadata> finishCreateTable(ConnectorSession sess
229232
@Override
230233
public ConnectorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle,
231234
List<ColumnHandle> insertedColumns, RetryMode retryMode) {
232-
if (retryMode != RetryMode.NO_RETRIES) {
233-
throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "This connector does not support query retries");
234-
}
235+
checkRetry(retryMode);
235236
RediSearchTableHandle table = (RediSearchTableHandle) tableHandle;
236237
List<RediSearchColumnHandle> columns = rediSearchSession.getTable(table.getSchemaTableName()).getColumns();
237238

@@ -255,14 +256,34 @@ public RediSearchColumnHandle getDeleteRowIdColumnHandle(ConnectorSession sessio
255256
@Override
256257
public RediSearchTableHandle beginDelete(ConnectorSession session, ConnectorTableHandle tableHandle,
257258
RetryMode retryMode) {
258-
if (retryMode != NO_RETRIES) {
259-
throw new TrinoException(NOT_SUPPORTED, "This connector does not support query retries");
260-
}
259+
checkRetry(retryMode);
261260
return (RediSearchTableHandle) tableHandle;
262261
}
263262

264263
@Override
265264
public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<Slice> fragments) {
265+
// Do nothing
266+
}
267+
268+
@Override
269+
public RediSearchColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle,
270+
List<ColumnHandle> updatedColumns) {
271+
return RediSearchBuiltinField.ID.getColumnHandle();
272+
}
273+
274+
@Override
275+
public RediSearchTableHandle beginUpdate(ConnectorSession session, ConnectorTableHandle tableHandle,
276+
List<ColumnHandle> updatedColumns, RetryMode retryMode) {
277+
checkRetry(retryMode);
278+
RediSearchTableHandle table = (RediSearchTableHandle) tableHandle;
279+
return new RediSearchTableHandle(table.getType(), table.getSchemaTableName(), table.getConstraint(),
280+
table.getLimit(), table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(),
281+
updatedColumns.stream().map(RediSearchColumnHandle.class::cast).collect(toImmutableList()));
282+
}
283+
284+
@Override
285+
public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<Slice> fragments) {
286+
// Do nothing
266287
}
267288

268289
@Override
@@ -285,9 +306,11 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
285306
return Optional.empty();
286307
}
287308

288-
return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getType(),
289-
handle.getSchemaTableName(), handle.getConstraint(), OptionalLong.of(limit),
290-
handle.getTermAggregations(), handle.getMetricAggregations(), handle.getWildcards()), true, false));
309+
return Optional.of(new LimitApplicationResult<>(
310+
new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(),
311+
OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations(),
312+
handle.getWildcards(), handle.getUpdatedColumns()),
313+
true, false));
291314
}
292315

293316
@Override
@@ -350,7 +373,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
350373
}
351374

352375
handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(),
353-
handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards);
376+
handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards, handle.getUpdatedColumns());
354377

355378
return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported),
356379
newExpression, false));
@@ -476,7 +499,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
476499
return Optional.empty();
477500
}
478501
RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(),
479-
table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards());
502+
table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards(),
503+
table.getUpdatedColumns());
480504
return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(),
481505
resultAssignments.build(), Map.of(), false));
482506
}

src/main/java/com/redis/trino/RediSearchPageSink.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import com.google.common.primitives.Shorts;
5353
import com.google.common.primitives.SignedBytes;
5454
import com.redis.lettucemod.api.StatefulRedisModulesConnection;
55+
import com.redis.lettucemod.api.async.RedisModulesAsyncCommands;
5556
import com.redis.lettucemod.search.CreateOptions;
5657
import com.redis.lettucemod.search.CreateOptions.DataType;
5758
import com.redis.lettucemod.search.IndexInfo;
@@ -82,6 +83,7 @@
8283

8384
public class RediSearchPageSink implements ConnectorPageSink {
8485

86+
private static final String KEY_SEPARATOR = ":";
8587
private final RediSearchSession session;
8688
private final SchemaTableName schemaTableName;
8789
private final List<RediSearchColumnHandle> columns;
@@ -96,22 +98,24 @@ public RediSearchPageSink(RediSearchSession rediSearchSession, SchemaTableName s
9698

9799
@Override
98100
public CompletableFuture<?> appendPage(Page page) {
99-
String prefix = prefix().orElse(schemaTableName.getTableName());
101+
String prefix = prefix().orElse(schemaTableName.getTableName() + KEY_SEPARATOR);
100102
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
101103
connection.setAutoFlushCommands(false);
104+
RedisModulesAsyncCommands<String, String> commands = connection.async();
102105
List<RedisFuture<?>> futures = new ArrayList<>();
103106
for (int position = 0; position < page.getPositionCount(); position++) {
107+
String key = prefix + factory.create().toString();
104108
Map<String, String> map = new HashMap<>();
105-
String key = prefix + ":" + factory.create().toString();
106109
for (int channel = 0; channel < page.getChannelCount(); channel++) {
107110
RediSearchColumnHandle column = columns.get(channel);
108111
Block block = page.getBlock(channel);
109112
if (block.isNull(position)) {
110113
continue;
111114
}
112-
map.put(column.getName(), getObjectValue(columns.get(channel).getType(), block, position));
115+
String value = value(column.getType(), block, position);
116+
map.put(column.getName(), value);
113117
}
114-
RedisFuture<Long> future = connection.async().hset(key, map);
118+
RedisFuture<Long> future = commands.hset(key, map);
115119
futures.add(future);
116120
}
117121
connection.flushCommands();
@@ -136,16 +140,16 @@ private Optional<String> prefix() {
136140
if (prefix.equals("*")) {
137141
return Optional.empty();
138142
}
139-
if (prefix.endsWith(":")) {
140-
return Optional.of(prefix.substring(0, prefix.length() - 1));
143+
if (prefix.endsWith(KEY_SEPARATOR)) {
144+
return Optional.of(prefix);
141145
}
142-
return Optional.of(prefix);
146+
return Optional.of(prefix + KEY_SEPARATOR);
143147
} catch (Exception e) {
144148
return Optional.empty();
145149
}
146150
}
147151

148-
private String getObjectValue(Type type, Block block, int position) {
152+
public static String value(Type type, Block block, int position) {
149153
if (type.equals(BooleanType.BOOLEAN)) {
150154
return String.valueOf(type.getBoolean(block, position));
151155
}
@@ -205,5 +209,6 @@ public CompletableFuture<Collection<Slice>> finish() {
205209

206210
@Override
207211
public void abort() {
212+
// Do nothing
208213
}
209214
}

src/main/java/com/redis/trino/RediSearchPageSinkProvider.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
*/
2424
package com.redis.trino;
2525

26+
import java.util.List;
27+
2628
import javax.inject.Inject;
2729

2830
import io.trino.spi.connector.ConnectorInsertTableHandle;
@@ -32,27 +34,32 @@
3234
import io.trino.spi.connector.ConnectorPageSinkProvider;
3335
import io.trino.spi.connector.ConnectorSession;
3436
import io.trino.spi.connector.ConnectorTransactionHandle;
37+
import io.trino.spi.connector.SchemaTableName;
3538

3639
public class RediSearchPageSinkProvider implements ConnectorPageSinkProvider {
3740

38-
private final RediSearchSession rediSearchSession;
41+
private final RediSearchSession session;
3942

4043
@Inject
4144
public RediSearchPageSinkProvider(RediSearchSession rediSearchSession) {
42-
this.rediSearchSession = rediSearchSession;
45+
this.session = rediSearchSession;
4346
}
4447

4548
@Override
4649
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session,
4750
ConnectorOutputTableHandle outputTableHandle, ConnectorPageSinkId pageSinkId) {
4851
RediSearchOutputTableHandle handle = (RediSearchOutputTableHandle) outputTableHandle;
49-
return new RediSearchPageSink(rediSearchSession, handle.getSchemaTableName(), handle.getColumns());
52+
return pageSink(handle.getSchemaTableName(), handle.getColumns());
5053
}
5154

5255
@Override
5356
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session,
5457
ConnectorInsertTableHandle insertTableHandle, ConnectorPageSinkId pageSinkId) {
5558
RediSearchInsertTableHandle handle = (RediSearchInsertTableHandle) insertTableHandle;
56-
return new RediSearchPageSink(rediSearchSession, handle.getSchemaTableName(), handle.getColumns());
59+
return pageSink(handle.getSchemaTableName(), handle.getColumns());
60+
}
61+
62+
private RediSearchPageSink pageSink(SchemaTableName schemaTableName, List<RediSearchColumnHandle> columns) {
63+
return new RediSearchPageSink(session, schemaTableName, columns);
5764
}
5865
}

src/main/java/com/redis/trino/RediSearchPageSource.java

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,30 +30,38 @@
3030
import java.util.ArrayList;
3131
import java.util.Collection;
3232
import java.util.Collections;
33+
import java.util.HashMap;
3334
import java.util.Iterator;
3435
import java.util.List;
36+
import java.util.Map;
3537
import java.util.concurrent.CompletableFuture;
3638
import java.util.stream.Collectors;
3739

3840
import com.fasterxml.jackson.core.JsonFactory;
3941
import com.fasterxml.jackson.core.JsonGenerator;
42+
import com.redis.lettucemod.api.StatefulRedisModulesConnection;
43+
import com.redis.lettucemod.api.async.RedisModulesAsyncCommands;
4044
import com.redis.lettucemod.search.Document;
4145

4246
import io.airlift.slice.Slice;
4347
import io.airlift.slice.SliceOutput;
48+
import io.lettuce.core.LettuceFutures;
49+
import io.lettuce.core.RedisFuture;
4450
import io.trino.spi.Page;
4551
import io.trino.spi.PageBuilder;
4652
import io.trino.spi.block.Block;
4753
import io.trino.spi.block.BlockBuilder;
4854
import io.trino.spi.connector.UpdatablePageSource;
4955
import io.trino.spi.type.Type;
56+
import io.trino.spi.type.VarcharType;
5057

5158
public class RediSearchPageSource implements UpdatablePageSource {
5259

5360
private static final int ROWS_PER_REQUEST = 1024;
5461

5562
private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter();
5663
private final RediSearchSession session;
64+
private final RediSearchTableHandle table;
5765
private final Iterator<Document<String, String>> cursor;
5866
private final String[] columnNames;
5967
private final List<Type> columnTypes;
@@ -66,9 +74,10 @@ public class RediSearchPageSource implements UpdatablePageSource {
6674
public RediSearchPageSource(RediSearchSession session, RediSearchTableHandle table,
6775
List<RediSearchColumnHandle> columns) {
6876
this.session = session;
77+
this.table = table;
6978
this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).toArray(String[]::new);
7079
this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType)
71-
.collect(Collectors.toUnmodifiableList());
80+
.collect(Collectors.toList());
7281
this.cursor = session.search(table, columnNames).iterator();
7382
this.currentDoc = null;
7483
this.pageBuilder = new PageBuilder(columnTypes);
@@ -127,14 +136,45 @@ public Page getNextPage() {
127136
@Override
128137
public void deleteRows(Block rowIds) {
129138
List<String> docIds = new ArrayList<>(rowIds.getPositionCount());
130-
for (int i = 0; i < rowIds.getPositionCount(); i++) {
131-
int len = rowIds.getSliceLength(i);
132-
Slice slice = rowIds.getSlice(i, 0, len);
133-
docIds.add(slice.toStringUtf8());
139+
for (int position = 0; position < rowIds.getPositionCount(); position++) {
140+
docIds.add(VarcharType.VARCHAR.getSlice(rowIds, position).toStringUtf8());
134141
}
135142
session.deleteDocs(docIds);
136143
}
137144

145+
@Override
146+
public void updateRows(Page page, List<Integer> columnValueAndRowIdChannels) {
147+
int rowIdChannel = columnValueAndRowIdChannels.get(columnValueAndRowIdChannels.size() - 1);
148+
List<Integer> columnChannelMapping = columnValueAndRowIdChannels.subList(0,
149+
columnValueAndRowIdChannels.size() - 1);
150+
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
151+
connection.setAutoFlushCommands(false);
152+
RedisModulesAsyncCommands<String, String> commands = connection.async();
153+
List<RedisFuture<?>> futures = new ArrayList<>();
154+
for (int position = 0; position < page.getPositionCount(); position++) {
155+
Block rowIdBlock = page.getBlock(rowIdChannel);
156+
if (rowIdBlock.isNull(position)) {
157+
continue;
158+
}
159+
String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8();
160+
Map<String, String> map = new HashMap<>();
161+
for (int channel = 0; channel < columnChannelMapping.size(); channel++) {
162+
RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel));
163+
Block block = page.getBlock(channel);
164+
if (block.isNull(position)) {
165+
continue;
166+
}
167+
String value = RediSearchPageSink.value(column.getType(), block, position);
168+
map.put(column.getName(), value);
169+
}
170+
RedisFuture<Long> future = commands.hset(key, map);
171+
futures.add(future);
172+
}
173+
connection.flushCommands();
174+
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
175+
connection.setAutoFlushCommands(true);
176+
}
177+
138178
private String currentValue(String columnName) {
139179
if (RediSearchBuiltinField.isBuiltinColumn(columnName)) {
140180
if (RediSearchBuiltinField.ID.getName().equals(columnName)) {

src/main/java/com/redis/trino/RediSearchQueryBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@ public Optional<Group> group(RediSearchTableHandle table) {
242242
List<RediSearchAggregation> aggregates = table.getMetricAggregations();
243243
List<String> groupFields = new ArrayList<>();
244244
if (terms != null && !terms.isEmpty()) {
245-
groupFields = terms.stream().map(RediSearchAggregationTerm::getTerm).collect(Collectors.toUnmodifiableList());
245+
groupFields = terms.stream().map(RediSearchAggregationTerm::getTerm).collect(Collectors.toList());
246246
}
247-
List<Reducer> reducers = aggregates.stream().map(this::reducer).collect(Collectors.toUnmodifiableList());
247+
List<Reducer> reducers = aggregates.stream().map(this::reducer).collect(Collectors.toList());
248248
if (reducers.isEmpty()) {
249249
return Optional.empty();
250250
}

src/main/java/com/redis/trino/RediSearchSession.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ public void createTable(SchemaTableName schemaTableName, List<RediSearchColumnHa
185185
String index = index(schemaTableName);
186186
if (!connection.sync().ftList().contains(index)) {
187187
List<Field<String>> fields = columns.stream().filter(c -> !c.getName().equals("_id"))
188-
.map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toUnmodifiableList());
188+
.map(c -> buildField(c.getName(), c.getType())).collect(Collectors.toList());
189189
CreateOptions.Builder<String, String> options = CreateOptions.<String, String>builder();
190190
options.prefix(index + ":");
191191
connection.sync().ftCreate(index, options.build(), fields.toArray(Field[]::new));
@@ -308,7 +308,7 @@ public AggregateWithCursorResults<String> aggregate(RediSearchTableHandle table)
308308
AggregateWithCursorResults<String> results = connection.sync().ftAggregate(aggregation.getIndex(),
309309
aggregation.getQuery(), aggregation.getCursorOptions(), aggregation.getOptions());
310310
List<AggregateOperation<?, ?>> groupBys = aggregation.getOptions().getOperations().stream()
311-
.filter(o -> o.getType() == AggregateOperation.Type.GROUP).collect(Collectors.toUnmodifiableList());
311+
.filter(o -> o.getType() == AggregateOperation.Type.GROUP).collect(Collectors.toList());
312312
if (results.isEmpty() && !groupBys.isEmpty()) {
313313
Group groupBy = (Group) groupBys.get(0);
314314
Optional<String> as = groupBy.getReducers()[0].getAs();

0 commit comments

Comments
 (0)