diff --git a/pom.xml b/pom.xml index d6b5bc6d6e..f0af3d6791 100644 --- a/pom.xml +++ b/pom.xml @@ -348,6 +348,14 @@ src/test/java/redis/clients/jedis/commands/jedis/ClusterStreamsCommandsTest.java src/test/java/redis/clients/jedis/commands/jedis/PooledStreamsCommandsTest.java src/test/java/redis/clients/jedis/resps/StreamEntryDeletionResultTest.java + **/Maintenance*.java + **/Push*.java + **/Rebind*.java + src/test/java/redis/clients/jedis/upgrade/*.java + src/test/java/redis/clients/jedis/util/server/*.java + **/TimeoutOptions.java + **/*Handler.java + **/ConnectionTestHelper.java diff --git a/src/main/java/redis/clients/jedis/AbstractListenerHandler.java b/src/main/java/redis/clients/jedis/AbstractListenerHandler.java new file mode 100644 index 0000000000..6bb0cb4f79 --- /dev/null +++ b/src/main/java/redis/clients/jedis/AbstractListenerHandler.java @@ -0,0 +1,25 @@ +package redis.clients.jedis; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +abstract class AbstractListenerHandler implements ListenerHandler { + private final List listeners = new CopyOnWriteArrayList<>(); + + public void addListener(T listener) { + listeners.add(listener); + } + + public void removeListener(T listener) { + listeners.remove(listener); + } + + public void removeAllListeners() { + listeners.clear(); + } + + public Collection getListeners() { + return listeners; + } +} diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index de473d0b8e..de4ccbbe0b 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -1,9 +1,11 @@ package redis.clients.jedis; +import static redis.clients.jedis.PushConsumerChain.PROPAGATE_ALL_HANDLER; import static redis.clients.jedis.util.SafeEncoder.encode; import java.io.Closeable; import java.io.IOException; +import java.lang.ref.WeakReference; import java.net.Socket; import java.net.SocketAddress; import java.net.SocketException; @@ -16,9 +18,12 @@ import java.util.function.Supplier; import java.util.concurrent.atomic.AtomicReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.Protocol.Keyword; import redis.clients.jedis.annots.Experimental; +import redis.clients.jedis.annots.VisibleForTesting; import redis.clients.jedis.args.ClientAttributeOption; import redis.clients.jedis.args.Rawable; import redis.clients.jedis.authentication.AuthXManager; @@ -28,10 +33,13 @@ import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.exceptions.JedisValidationException; import redis.clients.jedis.util.IOUtils; +import redis.clients.jedis.util.NumberUtils; import redis.clients.jedis.util.RedisInputStream; import redis.clients.jedis.util.RedisOutputStream; +import redis.clients.jedis.util.SafeEncoder; public class Connection implements Closeable { + public static Logger logger = LoggerFactory.getLogger(Connection.class); private ConnectionPool memberOf; protected RedisProtocol protocol; @@ -39,6 +47,9 @@ public class Connection implements Closeable { private Socket socket; private RedisOutputStream outputStream; private RedisInputStream inputStream; + private boolean relaxedTimeoutEnabled = false; + private int relaxedTimeout = NumberUtils.safeToInt(TimeoutOptions.DISABLED_TIMEOUT.toMillis()); + private int relaxedBlockingTimeout = NumberUtils.safeToInt(TimeoutOptions.DISABLED_TIMEOUT.toMillis()); private int soTimeout = 0; private int infiniteSoTimeout = 0; private boolean broken = false; @@ -48,7 +59,11 @@ public class Connection implements Closeable { protected String version; private AtomicReference currentCredentials = new AtomicReference<>(null); private AuthXManager authXManager; + private boolean isBlocking = false; + private boolean isRelaxed = false; + private boolean rebindRequested = false; + protected PushConsumerChain pushConsumer; public Connection() { this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT); } @@ -68,15 +83,62 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC public Connection(final JedisSocketFactory socketFactory) { this.socketFactory = socketFactory; this.authXManager = null; + + initPushConsumers(null); } public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) { this.socketFactory = socketFactory; this.soTimeout = clientConfig.getSocketTimeoutMillis(); this.infiniteSoTimeout = clientConfig.getBlockingSocketTimeoutMillis(); + this.relaxedTimeout = NumberUtils.safeToInt(clientConfig.getTimeoutOptions().getRelaxedTimeout().toMillis()); + this.relaxedBlockingTimeout = NumberUtils.safeToInt(clientConfig.getTimeoutOptions().getRelaxedBlockingTimeout().toMillis()); + this.relaxedTimeoutEnabled = TimeoutOptions.isRelaxedTimeoutEnabled(relaxedTimeout) || + TimeoutOptions.isRelaxedTimeoutEnabled(relaxedBlockingTimeout); + initPushConsumers(clientConfig); initializeFromClientConfig(clientConfig); } + + protected void initPushConsumers(JedisClientConfig config) { + /* + * Default consumers to process push messages. + * Marks all @{link PushMessage}s as processed, except for pub/sub. + * Pub/sub messages are propagated to the client. + */ + this.pushConsumer = PushConsumerChain.of( + PushConsumerChain.PUBSUB_ONLY_HANDLER + ); + + if (config != null) { + + /* + * Add consumer to handle server maintenance events. + * Maintenance events are propagated to the registered {@link MaintenanceEventListener}s. + */ + MaintenanceEventHandler maintenanceEventHandler = config.getMaintenanceEventHandler(); + if (maintenanceEventHandler != null) { + this.pushConsumer.add(new MaintenanceEventConsumer(maintenanceEventHandler)); + + if (config.isProactiveRebindEnabled()) { + maintenanceEventHandler.addListener(new ConnectionRebindHandler()); + } + + if (TimeoutOptions.isRelaxedTimeoutEnabled(config.getTimeoutOptions().getRelaxedTimeout())) { + maintenanceEventHandler.addListener(new AdaptiveTimeoutHandler(Connection.this)); + } + } + + /* + * Add consumer to notify registered {@link PushListener}s. + */ + PushHandler pushHandler = config.getPushHandler(); + if (pushHandler != null) { + this.pushConsumer.add(new ListenerNotificationConsumer(pushHandler)); + } + } + } + @Override public String toString() { return getClass().getSimpleName() + "{" + socketFactory + "}"; @@ -152,7 +214,8 @@ public void setTimeoutInfinite() { public void rollbackTimeout() { try { - socket.setSoTimeout(this.soTimeout); + int timeout = getDesiredTimeout(); + socket.setSoTimeout(timeout); } catch (SocketException ex) { setBroken(); throw new JedisConnectionException(ex); @@ -175,9 +238,11 @@ public T executeCommand(final CommandObject commandObject) { return commandObject.getBuilder().build(getOne()); } else { try { + isBlocking = true; setTimeoutInfinite(); return commandObject.getBuilder().build(getOne()); } finally { + isBlocking = false; rollbackTimeout(); } } @@ -261,7 +326,7 @@ public void close() { if (this.memberOf != null) { ConnectionPool pool = this.memberOf; this.memberOf = null; - if (isBroken()) { + if (isBroken() || isRebindRequested()) { pool.returnBrokenResource(this); } else { pool.returnResource(this); @@ -271,6 +336,10 @@ public void close() { } } + private boolean isRebindRequested() { + return rebindRequested; + } + /** * Close the socket and disconnect the server. */ @@ -303,7 +372,7 @@ public void setBroken() { public String getStatusCodeReply() { flush(); - final byte[] resp = (byte[]) readProtocolWithCheckingBroken(); + final byte[] resp = (byte[]) readProtocolWithCheckingBroken(pushConsumer); if (null == resp) { return null; } else { @@ -322,12 +391,12 @@ public String getBulkReply() { public byte[] getBinaryBulkReply() { flush(); - return (byte[]) readProtocolWithCheckingBroken(); + return (byte[]) readProtocolWithCheckingBroken(pushConsumer); } public Long getIntegerReply() { flush(); - return (Long) readProtocolWithCheckingBroken(); + return (Long) readProtocolWithCheckingBroken(pushConsumer); } public List getMultiBulkReply() { @@ -337,7 +406,7 @@ public List getMultiBulkReply() { @SuppressWarnings("unchecked") public List getBinaryMultiBulkReply() { flush(); - return (List) readProtocolWithCheckingBroken(); + return (List) readProtocolWithCheckingBroken(pushConsumer); } /** @@ -346,28 +415,28 @@ public List getBinaryMultiBulkReply() { @Deprecated @SuppressWarnings("unchecked") public List getUnflushedObjectMultiBulkReply() { - return (List) readProtocolWithCheckingBroken(); + return (List) readProtocolWithCheckingBroken(pushConsumer); } @SuppressWarnings("unchecked") public Object getUnflushedObject() { - return readProtocolWithCheckingBroken(); + return readProtocolWithCheckingBroken(pushConsumer); } public List getObjectMultiBulkReply() { flush(); - return (List) readProtocolWithCheckingBroken(); + return (List) readProtocolWithCheckingBroken(pushConsumer); } @SuppressWarnings("unchecked") public List getIntegerMultiBulkReply() { flush(); - return (List) readProtocolWithCheckingBroken(); + return (List) readProtocolWithCheckingBroken(pushConsumer); } public Object getOne() { flush(); - return readProtocolWithCheckingBroken(); + return readProtocolWithCheckingBroken(pushConsumer); } protected void flush() { @@ -380,21 +449,39 @@ protected void flush() { } @Experimental - protected Object protocolRead(RedisInputStream is) { - return Protocol.read(is); + protected Object protocolRead(RedisInputStream is, PushConsumer handler) { + return Protocol.read(is, handler); } @Experimental protected void protocolReadPushes(RedisInputStream is) { } + protected Object readProtocolWithCheckingBroken(PushConsumer handler) { + if (broken) { + throw new JedisConnectionException("Attempting to read from a broken connection."); + } + + try { + return protocolRead(inputStream, handler); + } catch (JedisConnectionException exc) { + broken = true; + throw exc; + } + } + + /** + * @deprecated Use {@link #readProtocolWithCheckingBroken(PushConsumer)} + * @return + */ + @Deprecated protected Object readProtocolWithCheckingBroken() { if (broken) { throw new JedisConnectionException("Attempting to read from a broken connection."); } try { - return protocolRead(inputStream); + return protocolRead(inputStream, PROPAGATE_ALL_HANDLER); } catch (JedisConnectionException exc) { broken = true; throw exc; @@ -424,7 +511,7 @@ public List getMany(final int count) { final List responses = new ArrayList<>(count); for (int i = 0; i < count; i++) { try { - responses.add(readProtocolWithCheckingBroken()); + responses.add(readProtocolWithCheckingBroken(pushConsumer)); } catch (JedisDataException e) { responses.add(e); } @@ -614,4 +701,243 @@ protected boolean isTokenBasedAuthenticationEnabled() { protected AuthXManager getAuthXManager() { return authXManager; } + + @Experimental + @VisibleForTesting + PushConsumerChain getPushConsumer() { + return this.pushConsumer; + } + + @Experimental + public boolean isRelaxedTimeoutActive() { + return isRelaxed; + } + + /** + * Calculate the desired timeout based on current state (blocking/non-blocking and relaxed/normal). + * When relaxed timeouts are enabled, use configured relaxed timeout if available, otherwise fallback to default timeout. + */ + private int getDesiredTimeout() { + if (!isRelaxed) { + if (!isBlocking) { + return soTimeout; + } else { + return infiniteSoTimeout; + } + } else { + if (!isBlocking) { + return TimeoutOptions.isRelaxedTimeoutEnabled(relaxedTimeout) ? relaxedTimeout : soTimeout; + } else { + return TimeoutOptions.isRelaxedTimeoutEnabled(relaxedBlockingTimeout) ? relaxedBlockingTimeout : infiniteSoTimeout; + } + } + } + + @Experimental + public void relaxTimeouts() { + if (!relaxedTimeoutEnabled) { + return; + } + + if (!isRelaxed) { + isRelaxed = true; + try { + if (isConnected()) { + socket.setSoTimeout(getDesiredTimeout()); + } + } catch (SocketException ex) { + setBroken(); + throw new JedisConnectionException(ex); + } + } + } + + @Experimental + public void disableRelaxedTimeout() { + if (isRelaxed) { + isRelaxed = false; + try { + if (isConnected()) { + socket.setSoTimeout(getDesiredTimeout()); + } + } catch (SocketException ex) { + setBroken(); + throw new JedisConnectionException(ex); + } + } + } + + /** + * Push consumer that delegates to a {@link PushHandler} for listener notification. + */ + private static class ListenerNotificationConsumer implements PushConsumer { + private final PushHandler pushHandler; + + public ListenerNotificationConsumer(PushHandler pushHandler) { + this.pushHandler = pushHandler; + } + + @Override + public void accept(PushConsumerContext context) { + if (pushHandler != null) { + notifyListeners(context.getMessage()); + } + } + + private void notifyListeners(PushMessage pushMessage) { + try { + pushHandler.getListeners().forEach(pushListener -> { + try { + pushListener.onPush(pushMessage); + } catch (Exception e) { + // ignore + } + }); + } catch (Exception e) { + // Log notification failures + } + } + } + + + private static class MaintenanceEventConsumer implements PushConsumer { + private final MaintenanceEventHandler eventHandler; + + public MaintenanceEventConsumer(MaintenanceEventHandler eventHandler) { + this.eventHandler = eventHandler; + } + + @Override + public void accept(PushConsumerContext context) { + PushMessage message = context.getMessage(); + + switch ( message.getType()) { + case "MOVING": + onMoving(message); + break; + case "MIGRATING": + onMigrating(); + break; + case "MIGRATED": + onMigrated(); + break; + case "FAILING_OVER": + onFailOver(); + break; + case "FAILED_OVER": + onFailedOver(); + break; + } + } + private void onMoving(PushMessage message) { + HostAndPort rebindTarget = getRebindTarget(message); + eventHandler.getListeners().forEach(listener -> listener.onRebind(rebindTarget)); + } + + private void onMigrating() { + eventHandler.getListeners().forEach(MaintenanceEventListener::onMigrating); + } + + private void onMigrated() { + eventHandler.getListeners().forEach(MaintenanceEventListener::onMigrated); + } + + private void onFailOver() { + eventHandler.getListeners().forEach(MaintenanceEventListener::onFailOver); + } + + private void onFailedOver() { + eventHandler.getListeners().forEach(MaintenanceEventListener::onFailedOver); + } + + private HostAndPort getRebindTarget(PushMessage message) { + // Extract domain/ip and port from the message + // MOVING push message format: ["MOVING", slot, "host:port"] + List content = message.getContent(); + + if (content.size() < 3) { + logger.warn("MOVING push message is malformed: {}", message); + return null; + } + + Object addressObject = content.get(2); // Get the 3rd element (index 2) + if (!(addressObject instanceof byte[])) { + logger.warn("Invalid re-bind message format, expected 3rd element to be a byte[], got {}", + addressObject.getClass()); + return null; + } + + try { + String addressAndPort = SafeEncoder.encode((byte[]) addressObject); + String[] parts = addressAndPort.split(":"); + if (parts.length != 2) { + logger.warn("Invalid re-bind message format, expected 'host:port', got {}", + addressAndPort); + return null; + } + + String address = parts[0]; + int port = Integer.parseInt(parts[1]); + return new HostAndPort(address, port); + } catch (Exception e) { + logger.warn("Error parsing re-bind target from message: {}", message, e); + return null; + } + } + } + + private class ConnectionRebindHandler implements MaintenanceEventListener { + public void onRebind(HostAndPort target) { + rebindRequested = true; + } + } + + private static class AdaptiveTimeoutHandler implements MaintenanceEventListener { + + private final WeakReference connectionRef; + + /** + * Creates a new maintenance listener for the specified connection. + * + * @param connection The connection to manage timeouts for + */ + public AdaptiveTimeoutHandler(Connection connection) { + this.connectionRef = new WeakReference<>(connection); + } + + public void onMigrating() { + Connection connection = connectionRef.get(); + if (connection != null) { + connection.relaxTimeouts(); + } + } + + public void onMigrated() { + Connection connection = connectionRef.get(); + if (connection != null) { + connection.disableRelaxedTimeout(); + } + } + + public void onFailOver() { + Connection connection = connectionRef.get(); + if (connection != null) { + connection.relaxTimeouts(); + } + } + + public void onFailedOver() { + Connection connection = connectionRef.get(); + if (connection != null) { + connection.disableRelaxedTimeout(); + } + } + + public void onRebind(HostAndPort target) { + Connection connection = connectionRef.get(); + if (connection != null) { + connection.relaxTimeouts(); + } + } + } } diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index 7440417152..d24139c79b 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -19,7 +19,7 @@ /** * PoolableObjectFactory custom impl. */ -public class ConnectionFactory implements PooledObjectFactory { +public class ConnectionFactory implements PooledObjectFactory , RebindAware { private static final Logger logger = LoggerFactory.getLogger(ConnectionFactory.class); @@ -140,4 +140,18 @@ private void reAuthenticate(Connection jedis) throws Exception { throw e; } } + + + @Override + public void rebind(HostAndPort newHostAndPort) { + // TODO : extract interface from DefaultJedisSocketFactory so that we can support custom socket factories + if (!(jedisSocketFactory instanceof DefaultJedisSocketFactory)) { + throw new IllegalStateException("Rebind not supported for custom JedisSocketFactory implementations"); + } + + DefaultJedisSocketFactory factory = (DefaultJedisSocketFactory) jedisSocketFactory; + logger.debug("Rebinding to {}", newHostAndPort); + factory.updateHostAndPort(newHostAndPort); + } + } diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index 2ae1401081..ab7c375308 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -3,12 +3,16 @@ import org.apache.commons.pool2.PooledObjectFactory; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import redis.clients.jedis.annots.Experimental; import redis.clients.jedis.authentication.AuthXManager; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.util.Pool; +import java.util.concurrent.atomic.AtomicReference; + public class ConnectionPool extends Pool { private AuthXManager authXManager; @@ -16,6 +20,7 @@ public class ConnectionPool extends Pool { public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) { this(new ConnectionFactory(hostAndPort, clientConfig)); attachAuthenticationListener(clientConfig.getAuthXManager()); + attachRebindHandler(clientConfig, (ConnectionFactory) this.getFactory()); } @Experimental @@ -23,6 +28,7 @@ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) { this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache)); attachAuthenticationListener(clientConfig.getAuthXManager()); + attachRebindHandler(clientConfig, (ConnectionFactory) this.getFactory()); } public ConnectionPool(PooledObjectFactory factory) { @@ -33,6 +39,7 @@ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, GenericObjectPoolConfig poolConfig) { this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig); attachAuthenticationListener(clientConfig.getAuthXManager()); + attachRebindHandler(clientConfig, (ConnectionFactory) this.getFactory()); } @Experimental @@ -40,6 +47,7 @@ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache, GenericObjectPoolConfig poolConfig) { this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig); attachAuthenticationListener(clientConfig.getAuthXManager()); + attachRebindHandler(clientConfig, (ConnectionFactory) this.getFactory()); } public ConnectionPool(PooledObjectFactory factory, @@ -78,4 +86,35 @@ private void attachAuthenticationListener(AuthXManager authXManager) { }); } } + + private void attachRebindHandler(JedisClientConfig clientConfig, ConnectionFactory factory) { + if (clientConfig.isProactiveRebindEnabled()) { + RebindHandler rebindHandler = new RebindHandler(this, factory); + clientConfig.getMaintenanceEventHandler().addListener(rebindHandler); + } + } + + private static class RebindHandler implements MaintenanceEventListener { + private final ConnectionPool pool; + private final ConnectionFactory factory; + private final AtomicReference rebindTarget = new AtomicReference<>(); + + public RebindHandler(ConnectionPool pool, ConnectionFactory factory) { + this.pool = pool; + this.factory = factory; + } + + @Override + public void onRebind(HostAndPort target) { + if (target == null) { + return; + } + + HostAndPort previous = rebindTarget.getAndSet(target); + if (!target.equals(previous)) { + this.factory.rebind(target); + this.pool.clear(); + } + } + } } diff --git a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java index 25a4737ec0..0c39aa67df 100644 --- a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java @@ -33,6 +33,14 @@ public final class DefaultJedisClientConfig implements JedisClientConfig { private final AuthXManager authXManager; + private final TimeoutOptions timeoutOptions; + + private final boolean proactiveRebindEnabled; + + private final PushHandler pushHandler; + + private final MaintenanceEventHandler maintenanceEventHandler; + private DefaultJedisClientConfig(DefaultJedisClientConfig.Builder builder) { this.redisProtocol = builder.redisProtocol; this.connectionTimeoutMillis = builder.connectionTimeoutMillis; @@ -50,6 +58,19 @@ private DefaultJedisClientConfig(DefaultJedisClientConfig.Builder builder) { this.clientSetInfoConfig = builder.clientSetInfoConfig; this.readOnlyForRedisClusterReplicas = builder.readOnlyForRedisClusterReplicas; this.authXManager = builder.authXManager; + this.timeoutOptions = builder.timeoutOptions; + this.proactiveRebindEnabled = builder.proactiveRebindEnabled; + this.pushHandler = builder.pushHandler; + + if ((builder.proactiveRebindEnabled + || TimeoutOptions.isRelaxedTimeoutEnabled(builder.timeoutOptions.getRelaxedTimeout()) + || TimeoutOptions.isRelaxedTimeoutEnabled(builder.timeoutOptions.getRelaxedBlockingTimeout())) + && builder.maintenanceEventHandler == null) { + // Proactive rebind or relaxed timeouts require a maintenance event handler + this.maintenanceEventHandler = new MaintenanceEventHandlerImpl(); + } else { + this.maintenanceEventHandler = builder.maintenanceEventHandler; + } } @Override @@ -143,6 +164,27 @@ public boolean isReadOnlyForRedisClusterReplicas() { return readOnlyForRedisClusterReplicas; } + @Override + public TimeoutOptions getTimeoutOptions() { + return timeoutOptions; + } + + @Override + public boolean isProactiveRebindEnabled() { + return proactiveRebindEnabled; + } + + @Override + public PushHandler getPushHandler() { + return pushHandler; + } + + + @Override + public MaintenanceEventHandler getMaintenanceEventHandler() { + return maintenanceEventHandler; + } + public static Builder builder() { return new Builder(); } @@ -175,6 +217,14 @@ public static class Builder { private AuthXManager authXManager = null; + private TimeoutOptions timeoutOptions = TimeoutOptions.create(); + + private boolean proactiveRebindEnabled = false; + + private PushHandler pushHandler = null; + + private MaintenanceEventHandler maintenanceEventHandler = null; + private Builder() { } @@ -297,6 +347,26 @@ public Builder authXManager(AuthXManager authXManager) { return this; } + public Builder timeoutOptions(TimeoutOptions timeoutOptions) { + this.timeoutOptions = timeoutOptions; + return this; + } + + public Builder proactiveRebindEnabled(boolean proactiveRebindEnabled) { + this.proactiveRebindEnabled = proactiveRebindEnabled; + return this; + } + + public Builder pushHandler(PushHandler pushHandler) { + this.pushHandler = pushHandler; + return this; + } + + public Builder maintenanceEventHandler(MaintenanceEventHandler maintenanceEventHandler) { + this.maintenanceEventHandler = maintenanceEventHandler; + return this; + } + public Builder from(JedisClientConfig instance) { this.redisProtocol = instance.getRedisProtocol(); this.connectionTimeoutMillis = instance.getConnectionTimeoutMillis(); @@ -314,6 +384,10 @@ public Builder from(JedisClientConfig instance) { this.clientSetInfoConfig = instance.getClientSetInfoConfig(); this.readOnlyForRedisClusterReplicas = instance.isReadOnlyForRedisClusterReplicas(); this.authXManager = instance.getAuthXManager(); + this.timeoutOptions = instance.getTimeoutOptions(); + this.proactiveRebindEnabled = instance.isProactiveRebindEnabled(); + this.pushHandler = instance.getPushHandler(); + this.maintenanceEventHandler = instance.getMaintenanceEventHandler(); return this; } } @@ -375,6 +449,10 @@ public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) { } builder.authXManager(copy.getAuthXManager()); + builder.timeoutOptions(copy.getTimeoutOptions()); + if (copy.isProactiveRebindEnabled()) { + builder.proactiveRebindEnabled(true); + } return builder.build(); } diff --git a/src/main/java/redis/clients/jedis/JedisClientConfig.java b/src/main/java/redis/clients/jedis/JedisClientConfig.java index ce7fd82de4..ab3039bc00 100644 --- a/src/main/java/redis/clients/jedis/JedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/JedisClientConfig.java @@ -115,4 +115,32 @@ default boolean isReadOnlyForRedisClusterReplicas() { default ClientSetInfoConfig getClientSetInfoConfig() { return ClientSetInfoConfig.DEFAULT; } + + default TimeoutOptions getTimeoutOptions() { + return TimeoutOptions.create(); + } + + /** + * Configure whether the driver should listen for server events that indicate the current endpoint is being re-bound. + * When enabled, the proactive re-bind will help with the connection handover and reduce the number of failed commands. + * This feature requires the server to support proactive re-binds. + * Enabling this feature requires also setting a {@link #getMaintenanceEventHandler() maintenance event handler} + * + * Defaults to {@code false}. + */ + default boolean isProactiveRebindEnabled() { + return false; + } + + default PushHandler getPushHandler() { + return null; + } + + /** + * @return The event handler to use for server maintenance events. + */ + default MaintenanceEventHandler getMaintenanceEventHandler(){ + return null; + } + } diff --git a/src/main/java/redis/clients/jedis/ListenerHandler.java b/src/main/java/redis/clients/jedis/ListenerHandler.java new file mode 100644 index 0000000000..bf8bb437de --- /dev/null +++ b/src/main/java/redis/clients/jedis/ListenerHandler.java @@ -0,0 +1,14 @@ +package redis.clients.jedis; + +import java.util.Collection; + +public interface ListenerHandler { + void addListener(T listener); + + void removeListener(T listener); + + void removeAllListeners(); + + Collection getListeners(); + +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/MaintenanceEventHandler.java b/src/main/java/redis/clients/jedis/MaintenanceEventHandler.java new file mode 100644 index 0000000000..60c57cfc8e --- /dev/null +++ b/src/main/java/redis/clients/jedis/MaintenanceEventHandler.java @@ -0,0 +1,5 @@ +package redis.clients.jedis; + +public interface MaintenanceEventHandler extends ListenerHandler { + +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/MaintenanceEventHandlerImpl.java b/src/main/java/redis/clients/jedis/MaintenanceEventHandlerImpl.java new file mode 100644 index 0000000000..87d132ca01 --- /dev/null +++ b/src/main/java/redis/clients/jedis/MaintenanceEventHandlerImpl.java @@ -0,0 +1,5 @@ +package redis.clients.jedis; + +public class MaintenanceEventHandlerImpl extends AbstractListenerHandler + implements MaintenanceEventHandler { +} diff --git a/src/main/java/redis/clients/jedis/MaintenanceEventListener.java b/src/main/java/redis/clients/jedis/MaintenanceEventListener.java new file mode 100644 index 0000000000..8b917b65bb --- /dev/null +++ b/src/main/java/redis/clients/jedis/MaintenanceEventListener.java @@ -0,0 +1,19 @@ +package redis.clients.jedis; + +public interface MaintenanceEventListener { + + default void onMigrating() { + }; + + default void onMigrated() { + }; + + default void onFailOver() { + }; + + default void onFailedOver() { + }; + + default void onRebind(HostAndPort target) { + }; +} diff --git a/src/main/java/redis/clients/jedis/Protocol.java b/src/main/java/redis/clients/jedis/Protocol.java index 226702cd9f..e4238e5dfd 100644 --- a/src/main/java/redis/clients/jedis/Protocol.java +++ b/src/main/java/redis/clients/jedis/Protocol.java @@ -19,6 +19,8 @@ import redis.clients.jedis.util.RedisOutputStream; import redis.clients.jedis.util.SafeEncoder; +import static redis.clients.jedis.PushConsumerChain.PROPAGATE_ALL_HANDLER; + public final class Protocol { public static final String DEFAULT_HOST = "127.0.0.1"; @@ -127,7 +129,7 @@ private static String[] parseTargetHostAndSlot(String clusterRedirectResponse) { return response; } - private static Object process(final RedisInputStream is) { + private static Object process(final RedisInputStream is, PushConsumer pushConsumer) { final byte b = is.readByte(); // System.out.println("BYTE: " + (char) b); switch (b) { @@ -153,7 +155,8 @@ private static Object process(final RedisInputStream is) { case TILDE_BYTE: // TODO: return processMultiBulkReply(is); case GREATER_THAN_BYTE: - return processMultiBulkReply(is); + // return processMultiBulkReply(is) + return processPush(is, pushConsumer); case MINUS_BYTE: processError(is); return null; @@ -193,7 +196,7 @@ private static List processMultiBulkReply(final RedisInputStream is) { final List ret = new ArrayList<>(num); for (int i = 0; i < num; i++) { try { - ret.add(process(is)); + ret.add(process(is, null)); } catch (JedisDataException e) { ret.add(e); } @@ -211,22 +214,49 @@ private static List processMapKeyValueReply(final RedisInputStream is) default: final List ret = new ArrayList<>(num); for (int i = 0; i < num; i++) { - ret.add(new KeyValue(process(is), process(is))); + ret.add(new KeyValue(process(is, null), process(is,null))); } return ret; } } public static Object read(final RedisInputStream is) { - return process(is); + // for backward compatibility propagate all push events to application + Object reply = process(is, PROPAGATE_ALL_HANDLER); + + if (reply != null & reply instanceof PushConsumerContext) { + PushConsumerContext context = (PushConsumerContext) reply; + if (!context.isForwardToClient()) { + return null; + } + return context.getMessage().getContent(); + } + + return reply; } @Experimental - public static Object read(final RedisInputStream is, final Cache cache) { - Object unhandledPush = readPushes(is, cache, false); - return unhandledPush == null ? process(is) : unhandledPush; + public static Object read(final RedisInputStream is, PushConsumer pushConsumer) { + // read until we have a non-push event, + // or push-event is not handled and need to be propagated to application + Object reply; + do { + reply = process(is, pushConsumer); + + } while (isPush(reply) && !((PushConsumerContext) reply).isForwardToClient()); + + if (isPush(reply)) { + return ((PushConsumerContext) reply).getMessage().getContent(); + } + + return reply; + } + + private static boolean isPush(Object reply) { + return reply instanceof PushConsumerContext; } + // TODO : Refactor to use PushHandler @Experimental public static Object readPushes(final RedisInputStream is, final Cache cache, boolean onlyPendingBuffer) { @@ -247,6 +277,13 @@ public static Object readPushes(final RedisInputStream is, final Cache cache, return unhandledPush; } + private static PushConsumerContext processPush(final RedisInputStream is, PushConsumer handler) { + List list = processMultiBulkReply(is); + PushConsumerContext context = new PushConsumerContext(new PushMessage(list)); + handler.accept(context); + return context; + } + private static Object processPush(final RedisInputStream is, Cache cache) { is.readByte(); List list = processMultiBulkReply(is); diff --git a/src/main/java/redis/clients/jedis/PushConsumer.java b/src/main/java/redis/clients/jedis/PushConsumer.java new file mode 100644 index 0000000000..24eefd7186 --- /dev/null +++ b/src/main/java/redis/clients/jedis/PushConsumer.java @@ -0,0 +1,19 @@ +package redis.clients.jedis; + +import redis.clients.jedis.annots.Internal; + +@Internal +@FunctionalInterface +public interface PushConsumer { + + /** + * Handle a push message. + *

+ * Messages are not processed by default. Handlers should update the context's processed flag to + * true if they have processed the message. + *

+ * @param context The context of the message to respond to. + */ + void accept(PushConsumerContext context); + +} diff --git a/src/main/java/redis/clients/jedis/PushConsumerChain.java b/src/main/java/redis/clients/jedis/PushConsumerChain.java new file mode 100644 index 0000000000..e84d883c54 --- /dev/null +++ b/src/main/java/redis/clients/jedis/PushConsumerChain.java @@ -0,0 +1,144 @@ +package redis.clients.jedis; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import redis.clients.jedis.annots.Internal; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A chain of PushHandlers that processes events in order. + *

+ * Uses a context object for tracking the processed state. + *

+ */ +@Internal +public final class PushConsumerChain implements PushConsumer { + /** + * Handler that allows all push events to be propagated to the client. + */ + public static final PushConsumer PROPAGATE_ALL_HANDLER = (context) -> { + // mark as not-processed, always propagate + context.setForwardToClient(true); + }; + /** + * Handler that allows only pub/sub related events to be propagated to the client + *

+ * Marks non-pub/sub events as processed, preventing their propagation. + *

+ */ + public static final PushConsumer PUBSUB_ONLY_HANDLER = new PushConsumer() { + final Set pubSubCommands = new HashSet<>(); + + { + pubSubCommands.add("message"); + pubSubCommands.add("pmessage"); + pubSubCommands.add("smessage"); + pubSubCommands.add("subscribe"); + pubSubCommands.add("ssubscribe"); + pubSubCommands.add("psubscribe"); + pubSubCommands.add("unsubscribe"); + pubSubCommands.add("sunsubscribe"); + pubSubCommands.add("punsubscribe"); + } + + @Override + public void accept(PushConsumerContext context) { + if (pubSubCommands.contains(context.getMessage().getType())) { + // Ensure pub/sub events are propagated to application + context.setForwardToClient(true); + } + } + }; + private static final Logger log = LoggerFactory.getLogger(PushConsumerChain.class); + private final List consumers; + + /** + * Create a new empty handler chain. + */ + public PushConsumerChain() { + this.consumers = new ArrayList<>(); + } + + /** + * Create a chain with the specified handlers. + * @param consumers The handlers to add to the chain + */ + public PushConsumerChain(PushConsumer... consumers) { + this.consumers = new ArrayList<>(Arrays.asList(consumers)); + } + + /** + * Create a chain with the specified handlers. + * @param handlers The handlers to add to the chain + * @return A new handler chain with the specified handlers + */ + public static PushConsumerChain of(PushConsumer... handlers) { + return new PushConsumerChain(handlers); + } + + /** + * Add a handler to the end of the chain. + * @param handler The handler to add + * @return this chain for method chaining + */ + public PushConsumerChain add(PushConsumer handler) { + if (handler != null) { + consumers.add(handler); + } + return this; + } + + /** + * Insert a handler at the specified position. + * @param index The position to insert at (0-based) + * @param handler The handler to insert + * @return this chain for method chaining + */ + public PushConsumerChain insert(int index, PushConsumer handler) { + if (handler != null) { + consumers.add(index, handler); + } + return this; + } + + /** + * Remove a handler from the chain. + * @param handler The handler to remove + * @return true if the handler was removed + */ + public boolean remove(PushConsumer handler) { + return consumers.remove(handler); + } + + /** + * Get the number of handlers in the chain. + * @return The number of handlers + */ + public int size() { + return consumers.size(); + } + + /** + * Clear all handlers from the chain. + */ + public void clear() { + consumers.clear(); + } + + @Override + public void accept(PushConsumerContext context) { + if (consumers.isEmpty()) { + return; + } + + for (PushConsumer handler : consumers) { + handler.accept(context); + } + } + +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/PushConsumerContext.java b/src/main/java/redis/clients/jedis/PushConsumerContext.java new file mode 100644 index 0000000000..df429ce7a7 --- /dev/null +++ b/src/main/java/redis/clients/jedis/PushConsumerContext.java @@ -0,0 +1,27 @@ +package redis.clients.jedis; + +import redis.clients.jedis.annots.Internal; + +@Internal +public class PushConsumerContext { + private final PushMessage message; + + private boolean forwardToClient = false; + + public PushConsumerContext(PushMessage message) { + this.message = message; + } + + public PushMessage getMessage() { + return message; + } + + public boolean isForwardToClient() { + return forwardToClient; + } + + public void setForwardToClient(boolean forwardToClient) { + this.forwardToClient = forwardToClient; + } + +} diff --git a/src/main/java/redis/clients/jedis/PushHandelrImpl.java b/src/main/java/redis/clients/jedis/PushHandelrImpl.java new file mode 100644 index 0000000000..38ab0000c8 --- /dev/null +++ b/src/main/java/redis/clients/jedis/PushHandelrImpl.java @@ -0,0 +1,4 @@ +package redis.clients.jedis; + +class PushHandlerImpl extends AbstractListenerHandler implements PushHandler { +} diff --git a/src/main/java/redis/clients/jedis/PushHandler.java b/src/main/java/redis/clients/jedis/PushHandler.java new file mode 100644 index 0000000000..43de22816f --- /dev/null +++ b/src/main/java/redis/clients/jedis/PushHandler.java @@ -0,0 +1,47 @@ +package redis.clients.jedis; + +import java.util.Collection; +import java.util.Collections; + +/** + * A handler object that provides access to {@link PushListener}s. + * @author Ivo Gaydajiev + * @since 6.1 + */ +public interface PushHandler extends ListenerHandler { + + /** + * A no-operation implementation of PushHandler that doesn't maintain any listeners + *

+ * All operations are no-ops and getPushListeners() returns an empty list. + *

+ */ + PushHandler NOOP = new NoOpPushHandler(); + +} + +final class NoOpPushHandler implements PushHandler { + + NoOpPushHandler() { + } + + @Override + public void addListener(PushListener listener) { + // No-op + } + + @Override + public void removeListener(PushListener listener) { + // No-op + } + + @Override + public void removeAllListeners() { + // No-op + } + + @Override + public Collection getListeners() { + return Collections.emptyList(); + } +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/PushListener.java b/src/main/java/redis/clients/jedis/PushListener.java new file mode 100644 index 0000000000..26ca46b6bd --- /dev/null +++ b/src/main/java/redis/clients/jedis/PushListener.java @@ -0,0 +1,15 @@ +package redis.clients.jedis; + +@FunctionalInterface +public interface PushListener { + + /** + * Interface to be implemented by push message listeners that are interested in listening to + * {@link PushMessage}. Requires Redis 6+ using RESP3. + * @author Ivo Gaydajiev + * @since 6.1 + * @see PushMessage + */ + void onPush(PushMessage push); + +} diff --git a/src/main/java/redis/clients/jedis/PushMessage.java b/src/main/java/redis/clients/jedis/PushMessage.java new file mode 100644 index 0000000000..48c0364c41 --- /dev/null +++ b/src/main/java/redis/clients/jedis/PushMessage.java @@ -0,0 +1,25 @@ +package redis.clients.jedis; + +import redis.clients.jedis.util.SafeEncoder; + +import java.util.List; + +public class PushMessage { + String type; + List content; + + public PushMessage(List content) { + this.content = content; + if (content.size() > 0) { + type = SafeEncoder.encode((byte[]) content.get(0)); + } + } + + public String getType() { + return type; + } + + public List getContent() { + return content; + } +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/RebindAware.java b/src/main/java/redis/clients/jedis/RebindAware.java new file mode 100644 index 0000000000..89c6e9fd00 --- /dev/null +++ b/src/main/java/redis/clients/jedis/RebindAware.java @@ -0,0 +1,28 @@ +package redis.clients.jedis; + +import redis.clients.jedis.annots.Experimental; + +/** + * Interface for components that support rebinding to a new host and port. + *

+ * Implementations of this interface can be notified when a Redis server sends a MOVING notification + * during maintenance events. This interface can be implemented by various components such as: - + * Connection pools - Socket factories - Connection providers - Any component that manages + * connections to Redis servers + *

+ */ +@Experimental +public interface RebindAware { + + /** + * Notifies the component that a re-bind to a new host and port is scheduled. + *

+ * This is called when a MOVING notification is received. Components that implement this interface + * should update their internal state to reflect the new host and port, and return true if the + * re-bind was accepted. Components might decide to reject the re-bind request if they are not in + * a state to support it. + *

+ * @param newHostAndPort The new host and port to use for new connections + */ + void rebind(HostAndPort newHostAndPort); +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/TimeoutOptions.java b/src/main/java/redis/clients/jedis/TimeoutOptions.java new file mode 100644 index 0000000000..ce91490986 --- /dev/null +++ b/src/main/java/redis/clients/jedis/TimeoutOptions.java @@ -0,0 +1,116 @@ +package redis.clients.jedis; + +import redis.clients.jedis.util.JedisAsserts; + +import java.time.Duration; + +public class TimeoutOptions { + + private static final int DISABLED_TIMEOUT_MS = -1; + + public static final Duration DISABLED_TIMEOUT = Duration.ofMillis(DISABLED_TIMEOUT_MS); + + public static final Duration DEFAULT_RELAXED_TIMEOUT = DISABLED_TIMEOUT; + + public static final Duration DEFAULT_RELAXED_BLOCKING_TIMEOUT = DISABLED_TIMEOUT; + + private final Duration relaxedTimeout; + + private final Duration relaxedBlockingTimeout; + + private TimeoutOptions(Duration relaxedTimeout, Duration relaxedBlockingTimeout) { + this.relaxedTimeout = relaxedTimeout; + this.relaxedBlockingTimeout = relaxedBlockingTimeout; + } + + public static boolean isRelaxedTimeoutEnabled(Duration relaxedTimeout) { + return relaxedTimeout != null && !relaxedTimeout.equals(DISABLED_TIMEOUT); + } + + public static boolean isRelaxedTimeoutEnabled(int relaxedTimeout) { + return relaxedTimeout != DISABLED_TIMEOUT_MS; + } + + /** + * @return the {@link Duration} to relax timeouts proactively, {@link #DISABLED_TIMEOUT} if + * disabled. + */ + public Duration getRelaxedTimeout() { + return relaxedTimeout; + } + + /** + * @return the {@link Duration} to relax timeouts proactively for blocking commands, + * {@link #DISABLED_TIMEOUT} if disabled. + */ + public Duration getRelaxedBlockingTimeout() { + return relaxedBlockingTimeout; + } + + /** + * Returns a new {@link TimeoutOptions.Builder} to construct {@link TimeoutOptions}. + * @return a new {@link TimeoutOptions.Builder} to construct {@link TimeoutOptions}. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Create a new instance of {@link TimeoutOptions} with default settings. + * @return a new instance of {@link TimeoutOptions} with default settings. + */ + public static TimeoutOptions create() { + return builder().build(); + } + + public static class Builder { + private Duration relaxedTimeout = DEFAULT_RELAXED_TIMEOUT; + private Duration relaxedBlockingTimeout = DEFAULT_RELAXED_BLOCKING_TIMEOUT; + + /** + * Enable proactive timeout relaxing. Disabled by default, see {@link #DEFAULT_RELAXED_TIMEOUT}. + *

+ * If the Redis server supports this, and the client is set up to use it , the client would + * listen to notifications that the current endpoint is about to go down (as part of some + * maintenance activity, for example). In such cases, the driver could extend the existing + * timeout settings for newly issued commands, or such that are in flight, to make sure they do + * not time out during this process. + *

+ * @param duration {@link Duration} to relax timeouts proactively, must not be {@code null}. + * @return {@code this} + */ + public Builder proactiveTimeoutsRelaxing(Duration duration) { + JedisAsserts.notNull(duration, "Duration must not be null"); + + this.relaxedTimeout = duration; + return this; + } + + /** + * Enable proactive timeout relaxing for blocking commands. Disabled by default, see + * {@link #DEFAULT_RELAXED_BLOCKING_TIMEOUT}. + *

+ * If the Redis server supports this, and the client is set up to use it, the client would + * listen to notifications that the current endpoint is about to go down (as part of some + * maintenance activity, for example). In such cases, the driver could extend the existing + * timeout settings for blocking commands that are in flight, to make sure they do not time out + * during this process. If not configured, the infinite timeout for blocking commands will be + * preserved. + *

+ * @param duration {@link Duration} to relax timeouts proactively for blocking commands, must + * not be {@code null}. + * @return {@code this} + */ + public Builder proactiveBlockingTimeoutsRelaxing(Duration duration) { + JedisAsserts.notNull(duration, "Duration must not be null"); + + this.relaxedBlockingTimeout = duration; + return this; + } + + public TimeoutOptions build() { + return new TimeoutOptions(relaxedTimeout, relaxedBlockingTimeout); + } + } + +} diff --git a/src/main/java/redis/clients/jedis/csc/CacheConnection.java b/src/main/java/redis/clients/jedis/csc/CacheConnection.java index f157d95a94..c891eda4ff 100644 --- a/src/main/java/redis/clients/jedis/csc/CacheConnection.java +++ b/src/main/java/redis/clients/jedis/csc/CacheConnection.java @@ -1,5 +1,6 @@ package redis.clients.jedis.csc; +import java.util.List; import java.util.Objects; import java.util.concurrent.locks.ReentrantLock; @@ -8,6 +9,8 @@ import redis.clients.jedis.JedisClientConfig; import redis.clients.jedis.JedisSocketFactory; import redis.clients.jedis.Protocol; +import redis.clients.jedis.PushConsumer; +import redis.clients.jedis.PushConsumerContext; import redis.clients.jedis.RedisProtocol; import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.util.RedisInputStream; @@ -19,6 +22,21 @@ public class CacheConnection extends Connection { private static final String REDIS = "redis"; private static final String MIN_REDIS_VERSION = "7.4"; + private static class PushInvalidateConsumer implements PushConsumer { + private final Cache cache; + public PushInvalidateConsumer(Cache cache) { + this.cache = cache; + } + + @Override + public void accept(PushConsumerContext event) { + if (event.getMessage().getType().equals("invalidate")) { + cache.deleteByRedisKeys((List) event.getMessage().getContent().get(1)); + event.setForwardToClient(false); + } + } + } + public CacheConnection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig, Cache cache) { super(socketFactory, clientConfig); @@ -33,6 +51,7 @@ public CacheConnection(final JedisSocketFactory socketFactory, JedisClientConfig } } this.cache = Objects.requireNonNull(cache); + initializeClientSideCache(); } @@ -43,10 +62,11 @@ protected void initializeFromClientConfig(JedisClientConfig config) { } @Override - protected Object protocolRead(RedisInputStream inputStream) { + protected Object protocolRead(RedisInputStream inputStream, PushConsumer consumer) { lock.lock(); try { - return Protocol.read(inputStream, cache); + // return Protocol.read(inputStream, cache); + return Protocol.read(inputStream, pushConsumer); } finally { lock.unlock(); } @@ -102,6 +122,7 @@ public Cache getCache() { } private void initializeClientSideCache() { + this.pushConsumer.add(new PushInvalidateConsumer(cache)); sendCommand(Protocol.Command.CLIENT, "TRACKING", "ON"); String reply = getStatusCodeReply(); if (!"OK".equals(reply)) { diff --git a/src/main/java/redis/clients/jedis/util/NumberUtils.java b/src/main/java/redis/clients/jedis/util/NumberUtils.java new file mode 100644 index 0000000000..5d3d92e374 --- /dev/null +++ b/src/main/java/redis/clients/jedis/util/NumberUtils.java @@ -0,0 +1,13 @@ +package redis.clients.jedis.util; + +public final class NumberUtils { + + public static int safeToInt(long millis) { + if (millis > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } + + return (int) millis; + } + +} diff --git a/src/test/java/redis/clients/jedis/ConnectionTestHelper.java b/src/test/java/redis/clients/jedis/ConnectionTestHelper.java new file mode 100644 index 0000000000..dbf4a15584 --- /dev/null +++ b/src/test/java/redis/clients/jedis/ConnectionTestHelper.java @@ -0,0 +1,7 @@ +package redis.clients.jedis; + +public class ConnectionTestHelper { + public static HostAndPort getHostAndPort(Connection connection) { + return connection.getHostAndPort(); + } +} diff --git a/src/test/java/redis/clients/jedis/JedisPubSubBaseTest.java b/src/test/java/redis/clients/jedis/JedisPubSubBaseTest.java index 98b0907735..1c3d58025e 100644 --- a/src/test/java/redis/clients/jedis/JedisPubSubBaseTest.java +++ b/src/test/java/redis/clients/jedis/JedisPubSubBaseTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static redis.clients.jedis.Protocol.ResponseKeyword.MESSAGE; diff --git a/src/test/java/redis/clients/jedis/JedisShardedPubSubBaseTest.java b/src/test/java/redis/clients/jedis/JedisShardedPubSubBaseTest.java index 6803d44e96..af22522869 100644 --- a/src/test/java/redis/clients/jedis/JedisShardedPubSubBaseTest.java +++ b/src/test/java/redis/clients/jedis/JedisShardedPubSubBaseTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static redis.clients.jedis.Protocol.ResponseKeyword.SMESSAGE; diff --git a/src/test/java/redis/clients/jedis/ProtocolTest.java b/src/test/java/redis/clients/jedis/ProtocolTest.java index e9891b3a93..26a0982f16 100644 --- a/src/test/java/redis/clients/jedis/ProtocolTest.java +++ b/src/test/java/redis/clients/jedis/ProtocolTest.java @@ -4,6 +4,7 @@ import redis.clients.jedis.util.FragmentedByteArrayInputStream; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -139,4 +140,117 @@ public void busyReply() { } fail("Expected a JedisBusyException to be thrown."); } + + @Test + public void readPushEventsAreNotPropagatedAsReadOutputIfProcessed() { + // Create a mock push listener + final List receivedMessages = new ArrayList<>(); + PushConsumer handler = pushContext -> { + receivedMessages.add(pushContext.getMessage()); + pushContext.setForwardToClient(false); + }; + + // Create a stream with a push message followed by a regular response + byte[] data = (">2\r\n$10\r\ninvalidate\r\n*1\r\n$3\r\nfoo\r\n+OK\r\n").getBytes(); + RedisInputStream is = new RedisInputStream(new ByteArrayInputStream(data)); + + // Read the response, which should process the push message first + Object response = Protocol.read(is, handler); + + // Verify the response + assertArrayEquals(SafeEncoder.encode("OK"), (byte[]) response); + + // Verify the push message was received + assertEquals(1, receivedMessages.size()); + PushMessage pushMessage = receivedMessages.get(0); + assertEquals(2, pushMessage.getContent().size()); + assertEquals("invalidate", pushMessage.getType()); + assertArrayEquals(SafeEncoder.encode("invalidate"), (byte[]) pushMessage.getContent().get(0)); + + // The second element should be a list with one element "foo" + assertInstanceOf(List.class, pushMessage.getContent().get(1)); + List keys = (List) pushMessage.getContent().get(1); + assertEquals(1, keys.size()); + assertArrayEquals(SafeEncoder.encode("foo"), (byte[]) keys.get(0)); + } + + @Test + public void readMultiplePushEventsAreNotPropagatedAsReadOutputIfProcessed() { + // Create a mock push listener + final List receivedMessages = new ArrayList<>(); + PushConsumer handler = pushContext -> { receivedMessages.add(pushContext.getMessage()); pushContext.setForwardToClient(false); }; + + + // Create a stream with multiple push messages followed by a regular response + byte[] data = ( + ">2\r\n$10\r\ninvalidate\r\n*1\r\n$3\r\nfoo\r\n" + + ">2\r\n$10\r\ninvalidate\r\n*1\r\n$3\r\nbar\r\n" + + ">2\r\n$7\r\nmessage\r\n$5\r\nhello\r\n" + + ":123\r\n" + ).getBytes(); + RedisInputStream is = new RedisInputStream(new ByteArrayInputStream(data)); + + // Read the response, which should process all push messages first + Object response = Protocol.read(is, handler); + + // Verify the response + assertEquals(123L, response); + + // Verify all push messages were received + assertEquals(3, receivedMessages.size()); + + // First push message (invalidate foo) + PushMessage pushMessage1 = receivedMessages.get(0); + assertArrayEquals(SafeEncoder.encode("invalidate"), (byte[]) pushMessage1.getContent().get(0)); + List keys1 = (List) pushMessage1.getContent().get(1); + assertArrayEquals(SafeEncoder.encode("foo"), (byte[]) keys1.get(0)); + + // Second push message (invalidate bar) + PushMessage pushMessage2 = receivedMessages.get(1); + assertArrayEquals(SafeEncoder.encode("invalidate"), (byte[]) pushMessage2.getContent().get(0)); + List keys2 = (List) pushMessage2.getContent().get(1); + assertArrayEquals(SafeEncoder.encode("bar"), (byte[]) keys2.get(0)); + + // Third push message (message hello) + PushMessage pushMessage3 = receivedMessages.get(2); + assertArrayEquals(SafeEncoder.encode("message"), (byte[]) pushMessage3.getContent().get(0)); + assertArrayEquals(SafeEncoder.encode("hello"), (byte[]) pushMessage3.getContent().get(1)); + } + + @Test + public void readPushEventsArePropagateAsReadOutputIfNotProcessed() { + // Create a mock push listener + final List receivedMessages = new ArrayList<>(); + PushConsumer handler = pushContext -> { + receivedMessages.add(pushContext.getMessage()); + pushContext.setForwardToClient(true); + }; + + // Create a stream with a push message followed by a regular response + byte[] data = (">2\r\n$10\r\ninvalidate\r\n*1\r\n$3\r\nfoo\r\n+OK\r\n").getBytes(); + RedisInputStream is = new RedisInputStream(new ByteArrayInputStream(data)); + + // Read the response, which should return + // - invoke the push handler with the push message + // - propagate the push message as the read output since it was not processed + Object pushMessage = Protocol.read(is, handler); + + // Verify the push message is propagated as the read output + assertInstanceOf(ArrayList.class, pushMessage); + assertArrayEquals(SafeEncoder.encode("invalidate"), (byte[]) ((ArrayList) pushMessage).get(0)); + + // Verify the handler receives the push message + assertEquals(1, receivedMessages.size()); + PushMessage push = receivedMessages.get(0); + assertEquals(2, push.getContent().size()); + assertEquals("invalidate", push.getType()); + assertArrayEquals(SafeEncoder.encode("invalidate"), (byte[]) push.getContent().get(0)); + + + // Second read should return the command response itself + Object commandResponse = Protocol.read(is, handler); + + // Verify the response + assertArrayEquals(SafeEncoder.encode("OK"), (byte[]) commandResponse); + } } diff --git a/src/test/java/redis/clients/jedis/PushMessageNotificationTest.java b/src/test/java/redis/clients/jedis/PushMessageNotificationTest.java new file mode 100644 index 0000000000..ff3204738d --- /dev/null +++ b/src/test/java/redis/clients/jedis/PushMessageNotificationTest.java @@ -0,0 +1,363 @@ +package redis.clients.jedis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import io.redis.test.annotations.SinceRedisVersion; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import redis.clients.jedis.Protocol.Command; +import redis.clients.jedis.util.RedisVersionCondition; + +/** + * Tests for Redis RESP3 push notifications functionality. + */ +@SinceRedisVersion("6.0.0") +public class PushMessageNotificationTest { + + private static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); + + @RegisterExtension + public RedisVersionCondition versionCondition = new RedisVersionCondition(endpoint); + + private Connection connection; + private UnifiedJedis unifiedJedis; + private final String testKey = "tracking:test:key"; + private final String initialValue = "initial"; + private final String modifiedValue = "modified"; + + @BeforeEach + public void setUp() { + // Nothing to set up by default - connections are created in each test + } + + @AfterEach + public void tearDown() { + if (connection != null) { + connection.close(); + connection = null; + } + + if (unifiedJedis != null) { + try { + unifiedJedis.sendCommand(Command.CLIENT, "TRACKING", "OFF"); + } catch (Exception e) { + // Ignore exceptions during cleanup + } + unifiedJedis.close(); + unifiedJedis = null; + } + } + + /** + * Helper method to modify a key using a separate connection to trigger invalidation. + * @param key The key to modify + * @param value The new value to set + */ + private void triggerKeyInvalidation(String key, String value) { + try (Jedis modifierClient = new Jedis(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().protocol(RedisProtocol.RESP3).build())) { + modifierClient.set(key, value); + } + } + + /** + * Helper method to enable client tracking on a connection. + * @param connection The connection on which to enable tracking + */ + private void enableClientTracking(Connection connection) { + connection.sendCommand(Command.CLIENT, "TRACKING", "ON"); + assertEquals("OK", connection.getStatusCodeReply()); + } + + @Test + public void testConnectionResp3PushNotifications() { + connection = new Connection(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().protocol(RedisProtocol.RESP3).build()); + connection.connect(); + + // Enable client tracking + enableClientTracking(connection); + + // Set initial value + CommandArguments comArgs = new CommandArguments(Command.SET); + CommandObject set = new CommandObject<>(comArgs.key(testKey).add(initialValue), + BuilderFactory.STRING); + String setResult = connection.executeCommand(set); + assertEquals("OK", setResult); + + // Get the key to track it + CommandObject get = new CommandObject<>(new CommandArguments(Command.GET).key(testKey), + BuilderFactory.STRING); + String getResponse = connection.executeCommand(get); + assertEquals(initialValue, getResponse); + + // Modify the key from another connection to trigger invalidation + triggerKeyInvalidation(testKey, modifiedValue); + + // Send PING and expect to receive invalidation message first, then PONG + CommandObject ping = new CommandObject<>(new CommandArguments(Command.PING), + BuilderFactory.STRING); + String pingResponse = connection.executeCommand(ping); + assertEquals("PONG", pingResponse); + } + + @Test + public void testUnifiedJedisResp3PushNotifications() { + unifiedJedis = new UnifiedJedis(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().protocol(RedisProtocol.RESP3).build()); + + // Enable client tracking + unifiedJedis.sendCommand(Command.CLIENT, "TRACKING", "ON"); + + // Set initial value + unifiedJedis.set(testKey, initialValue); + + // Get the key to track it + String getResponse = unifiedJedis.get(testKey); + assertEquals(initialValue, getResponse); + + // Modify the key from another connection to trigger invalidation + triggerKeyInvalidation(testKey, modifiedValue); + + // Send PING command + String pingResponse = unifiedJedis.ping(); + // Next reply should be PONG + assertEquals("PONG", pingResponse); + } + + @Test + public void testUnifiedJedisCustomPushListener() { + List receivedMessages = new ArrayList<>(); + PushHandlerImpl pushHandler = new PushHandlerImpl(); + pushHandler.addListener(receivedMessages::add); + + DefaultJedisClientConfig clientConfig = endpoint.getClientConfigBuilder() + .pushHandler(pushHandler).protocol(RedisProtocol.RESP3).build(); + + unifiedJedis = new UnifiedJedis(endpoint.getHostAndPort(), clientConfig); + + // Enable client tracking + unifiedJedis.sendCommand(Command.CLIENT, "TRACKING", "ON"); + + // Set initial value + unifiedJedis.set(testKey, initialValue); + + // Get the key to track it + assertEquals(initialValue, unifiedJedis.get(testKey)); + + // Modify the key from another connection to trigger invalidation + triggerKeyInvalidation(testKey, modifiedValue); + + // Send PING command + String pingResponse = unifiedJedis.ping(); + // Next reply should be PONG + assertEquals("PONG", pingResponse); + assertEquals(1, receivedMessages.size()); + assertEquals("invalidate", receivedMessages.get(0).getType()); + } + + @Test + public void testJedisCustomPushListener() { + List receivedMessages = new ArrayList<>(); + PushHandlerImpl pushHandler = new PushHandlerImpl(); + pushHandler.addListener(receivedMessages::add); + + DefaultJedisClientConfig clientConfig = endpoint.getClientConfigBuilder() + .pushHandler(pushHandler).protocol(RedisProtocol.RESP3).build(); + + Jedis jedis = new Jedis(endpoint.getHostAndPort(), clientConfig); + + // Enable client tracking + jedis.sendCommand(Command.CLIENT, "TRACKING", "ON"); + + // Set initial value + jedis.set(testKey, initialValue); + + // Get the key to track it + assertEquals(initialValue, jedis.get(testKey)); + + // Modify the key from another connection to trigger invalidation + triggerKeyInvalidation(testKey, modifiedValue); + + // Send PING command + String pingResponse = jedis.ping(); + // Next reply should be PONG + assertEquals("PONG", pingResponse); + assertEquals(1, receivedMessages.size()); + assertEquals("invalidate", receivedMessages.get(0).getType()); + + // Clean up + jedis.close(); + } + + @Test + public void testConnectionResp3PushNotificationsWithCustomListener() { + // Create a list to store received push messages + List receivedMessages = new ArrayList<>(); + + // Create a custom push listener + PushConsumer listener = pushContext -> { + receivedMessages.add(pushContext.getMessage()); + }; + + // Create connection with RESP3 protocol + connection = new Connection(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().protocol(RedisProtocol.RESP3).build()); + connection.connect(); + + // Set the push listener + connection.getPushConsumer().add(listener); + + // Enable client tracking + enableClientTracking(connection); + + // Set and get a key to track it + CommandArguments setArgs = new CommandArguments(Command.SET); + CommandObject setCmd = new CommandObject<>(setArgs.key(testKey).add(initialValue), + BuilderFactory.STRING); + connection.executeCommand(setCmd); + + CommandObject getCmd = new CommandObject<>( + new CommandArguments(Command.GET).key(testKey), BuilderFactory.STRING); + connection.executeCommand(getCmd); + + // Modify the key from another connection to trigger invalidation + triggerKeyInvalidation(testKey, modifiedValue); + + // Send a command to trigger processing of any pending push messages + CommandObject pingCmd = new CommandObject<>(new CommandArguments(Command.PING), + BuilderFactory.STRING); + String pingResponse = connection.executeCommand(pingCmd); + assertEquals("PONG", pingResponse); + + // Verify we received at least one push message + assertTrue(!receivedMessages.isEmpty(), "Should have received at least one push message"); + + // Verify the message is an invalidation message + PushMessage pushMessage = receivedMessages.get(0); + assertNotNull(pushMessage); + assertEquals("invalidate", pushMessage.getType()); + } + + @ParameterizedTest + @MethodSource("redis.clients.jedis.commands.CommandsTestsParameters#respVersions") + public void testUnifiedJedisPubSubWithResp3PushNotifications(RedisProtocol protocol) + throws InterruptedException { + // Create a UnifiedJedis instance with RESP3 protocol for subscribing + unifiedJedis = new UnifiedJedis(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().protocol(protocol).build()); + + // Enable client tracking to generate push notifications + unifiedJedis.sendCommand(Command.CLIENT, "TRACKING", "ON"); + + // Set initial value to track + unifiedJedis.set(testKey, initialValue); + + // Get the key to track it + String getResponse = unifiedJedis.get(testKey); + assertEquals(initialValue, getResponse); + + // Create a list to store received pub/sub messages + final List receivedMessages = new ArrayList<>(); + + // Create an atomic counter to track received messages + final AtomicInteger messageCounter = new AtomicInteger(0); + + // Create a latch to signal when subscription is ready + final CountDownLatch subscriptionLatch = new CountDownLatch(1); + + // Create a JedisPubSub instance to handle pub/sub messages + JedisPubSub pubSub = new JedisPubSub() { + @Override + public void onMessage(String channel, String message) { + System.out.println("onMessage from " + channel + " : " + message); + receivedMessages.add(message); + + // If we've received both messages, unsubscribe + if (messageCounter.incrementAndGet() == 2) { + this.unsubscribe("test-channel"); + } + } + + @Override + public void onUnsubscribe(String channel, int subscribedChannels) { + // Signal that subscription is ready + System.out.println("Unsubscribed from " + channel); + } + + @Override + public void onSubscribe(String channel, int subscribedChannels) { + // Signal that subscription is ready + subscriptionLatch.countDown(); + } + }; + + // Start a thread to handle the subscription + Thread subscriberThread = new Thread(() -> { + unifiedJedis.subscribe(pubSub, "test-channel"); + }); + + // Start the subscriber thread + subscriberThread.start(); + + // Start a thread to publish messages and trigger key invalidation + Thread publisherThread = new Thread(() -> { + try (UnifiedJedis publisher = new UnifiedJedis(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().protocol(RedisProtocol.RESP3).build())) { + + // Wait for subscription to be ready + try { + if (!subscriptionLatch.await(5, TimeUnit.SECONDS)) { + System.err.println("Timed out waiting for subscription to be ready"); + return; + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + + // Publish a message + publisher.publish("test-channel", "test-message-1"); + + // Trigger key invalidation to generate a push notification + triggerKeyInvalidation(testKey, modifiedValue); + + // Publish another message + publisher.publish("test-channel", "test-message-2"); + } catch (Exception e) { + e.printStackTrace(); + } + }); + + // Start the publisher thread + publisherThread.start(); + + // Wait for the subscriber thread to complete (it will complete when unsubscribe is called) + subscriberThread.join(); + + // Wait for the publisher thread to complete + publisherThread.join(); + + // Verify that we received both pub/sub messages + assertEquals(2, receivedMessages.size(), "Should have received both pub/sub messages"); + assertEquals("test-message-1", receivedMessages.get(0)); + assertEquals("test-message-2", receivedMessages.get(1)); + + // Send a PING command to process any pending push messages + String pingResponse = unifiedJedis.ping(); + assertEquals("PONG", pingResponse); + } +} diff --git a/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java b/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java index 8d52908d27..fc698ad40e 100644 --- a/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java +++ b/src/test/java/redis/clients/jedis/commands/jedis/PublishSubscribeCommandsTest.java @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; -import org.hamcrest.Matchers; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedClass; diff --git a/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java b/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java index 7ab7589fb7..09fdd7e7d5 100644 --- a/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java +++ b/src/test/java/redis/clients/jedis/commands/jedis/TransactionCommandsTest.java @@ -30,11 +30,13 @@ import redis.clients.jedis.Jedis; import redis.clients.jedis.Protocol; +import redis.clients.jedis.PushConsumer; import redis.clients.jedis.RedisProtocol; import redis.clients.jedis.Response; import redis.clients.jedis.Transaction; import redis.clients.jedis.exceptions.JedisConnectionException; import redis.clients.jedis.exceptions.JedisDataException; +import redis.clients.jedis.util.RedisInputStream; import redis.clients.jedis.util.SafeEncoder; @ParameterizedClass @@ -176,7 +178,7 @@ public void discardFail() { trans.set("b", "b"); try (MockedStatic protocol = Mockito.mockStatic(Protocol.class)) { - protocol.when(() -> Protocol.read(any())).thenThrow(JedisConnectionException.class); + protocol.when(() -> Protocol.read(any(RedisInputStream.class), any(PushConsumer.class))).thenThrow(JedisConnectionException.class); trans.discard(); fail("Should get mocked JedisConnectionException."); @@ -196,7 +198,7 @@ public void execFail() { trans.set("b", "b"); try (MockedStatic protocol = Mockito.mockStatic(Protocol.class)) { - protocol.when(() -> Protocol.read(any())).thenThrow(JedisConnectionException.class); + protocol.when(() -> Protocol.read(any(RedisInputStream.class), any(PushConsumer.class))).thenThrow(JedisConnectionException.class); trans.exec(); fail("Should get mocked JedisConnectionException."); diff --git a/src/test/java/redis/clients/jedis/csc/ClientSideCacheTestBase.java b/src/test/java/redis/clients/jedis/csc/ClientSideCacheTestBase.java index 7e13d98da3..808095d018 100644 --- a/src/test/java/redis/clients/jedis/csc/ClientSideCacheTestBase.java +++ b/src/test/java/redis/clients/jedis/csc/ClientSideCacheTestBase.java @@ -6,10 +6,12 @@ import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.extension.RegisterExtension; import redis.clients.jedis.*; import redis.clients.jedis.util.RedisVersionCondition; +@Tag("ClientSideCache") @SinceRedisVersion(value = "7.4.0", message = "Jedis client-side caching is only supported with Redis 7.4 or later.") public abstract class ClientSideCacheTestBase { diff --git a/src/test/java/redis/clients/jedis/csc/UnifiedJedisClientSideCacheTestBase.java b/src/test/java/redis/clients/jedis/csc/UnifiedJedisClientSideCacheTestBase.java index fa4043799e..d70ac0e082 100644 --- a/src/test/java/redis/clients/jedis/csc/UnifiedJedisClientSideCacheTestBase.java +++ b/src/test/java/redis/clients/jedis/csc/UnifiedJedisClientSideCacheTestBase.java @@ -15,10 +15,12 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import redis.clients.jedis.JedisPubSub; import redis.clients.jedis.UnifiedJedis; +@Tag("ClientSideCache") public abstract class UnifiedJedisClientSideCacheTestBase { protected UnifiedJedis control; @@ -73,8 +75,7 @@ public void flushAll() { control.set("foo", "bar"); assertEquals("bar", jedis.get("foo")); control.flushAll(); - await().atMost(5, TimeUnit.SECONDS).pollInterval(50, TimeUnit.MILLISECONDS) - .untilAsserted(() -> assertNull(jedis.get("foo"))); + await().untilAsserted(() -> assertNull(jedis.get("foo"))); } } diff --git a/src/test/java/redis/clients/jedis/upgrade/ConnectionAdaptiveTimeoutTest.java b/src/test/java/redis/clients/jedis/upgrade/ConnectionAdaptiveTimeoutTest.java new file mode 100644 index 0000000000..06dbf0f430 --- /dev/null +++ b/src/test/java/redis/clients/jedis/upgrade/ConnectionAdaptiveTimeoutTest.java @@ -0,0 +1,329 @@ +package redis.clients.jedis.upgrade; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import redis.clients.jedis.CommandObjects; +import redis.clients.jedis.Connection; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.MaintenanceEventHandler; +import redis.clients.jedis.MaintenanceEventHandlerImpl; +import redis.clients.jedis.MaintenanceEventListener; +import redis.clients.jedis.RedisProtocol; +import redis.clients.jedis.TimeoutOptions; +import redis.clients.jedis.util.ReflectionTestUtils; +import redis.clients.jedis.util.server.CommandHandler; +import redis.clients.jedis.util.server.RespResponse; +import redis.clients.jedis.util.server.TcpMockServer; + +import java.io.IOException; +import java.net.Socket; +import java.net.SocketException; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; + +/** + * Test that connection adaptive timeout works as expected. Usses a mock TCP server to send + * Maintenance push messages to the client in controllable manner. + */ +@Tag("upgrade") +public class ConnectionAdaptiveTimeoutTest { + + private final int originalTimeoutMs = 2000; + private final Duration relaxedTimeout = Duration.ofSeconds(10); + private final Duration relaxedBlockingTimeout = Duration.ofSeconds(15); + private final CommandObjects commandObjects = new CommandObjects(); + private final CommandHandler mockHandler = Mockito.mock(CommandHandler.class); + private TcpMockServer mockServer; + private Connection connection; + + @BeforeEach + public void setUp() throws IOException { + // Start the mock TCP server + mockServer = new TcpMockServer(); + mockServer.setCommandHandler(mockHandler); + mockServer.start(); + + // Create client configuration with relaxed timeout and maintenance event handler + TimeoutOptions timeoutOptions = TimeoutOptions.builder() + .proactiveTimeoutsRelaxing(relaxedTimeout) + .proactiveBlockingTimeoutsRelaxing(relaxedBlockingTimeout).build(); + + MaintenanceEventHandler maintenanceEventHandler = new MaintenanceEventHandlerImpl(); + + MaintenanceEventListener testListener = new MaintenanceEventListener() { + @Override + public void onMigrating() { + System.out.println("MIGRATING"); + } + }; + maintenanceEventHandler.addListener(testListener); + + DefaultJedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .socketTimeoutMillis(originalTimeoutMs).timeoutOptions(timeoutOptions) + .maintenanceEventHandler(maintenanceEventHandler).protocol(RedisProtocol.RESP3).build(); + + // Create connection to the mock server + HostAndPort hostAndPort = new HostAndPort("localhost", mockServer.getPort()); + connection = new Connection(hostAndPort, clientConfig); + } + + @AfterEach + public void tearDown() throws IOException { + if (connection != null && connection.isConnected()) { + connection.close(); + } + if (mockServer != null) { + mockServer.stop(); + } + } + + @Test + public void testMigratingPushMessage() throws SocketException { + Socket socket = ReflectionTestUtils.getField(connection, "socket"); + + assertTrue(connection.isConnected()); + assertEquals(originalTimeoutMs, connection.getSoTimeout()); + assertEquals(originalTimeoutMs, socket.getSoTimeout()); + + // First send MIGRATING to activate relaxed timeout + mockServer.sendMigratingPushToAll(); + assertTrue(connection.ping()); + assertTrue(connection.isRelaxedTimeoutActive()); + assertEquals(relaxedTimeout.toMillis(), socket.getSoTimeout()); + + mockServer.sendMigratedPushToAll(); + assertTrue(connection.ping()); + assertFalse(connection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, socket.getSoTimeout()); + } + + @Test + public void testFailoverPushMessage() throws SocketException { + Socket socket = ReflectionTestUtils.getField(connection, "socket"); + + assertTrue(connection.isConnected()); + assertEquals(originalTimeoutMs, connection.getSoTimeout()); + assertEquals(originalTimeoutMs, socket.getSoTimeout()); + + // First send MIGRATING to activate relaxed timeout + mockServer.sendFailingOverPushToAll(); + assertTrue(connection.ping()); + assertTrue(connection.isRelaxedTimeoutActive()); + assertEquals(relaxedTimeout.toMillis(), socket.getSoTimeout()); + + mockServer.sendFailedOverPushToAll(); + assertTrue(connection.ping()); + assertFalse(connection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, socket.getSoTimeout()); + } + + @Test + public void testDisabledTimeoutRelaxationDoesNotApplyRelaxedTimeout() throws Exception { + // Create a connection with disabled timeout relaxation + Connection disabledConnection = createConnectionWithDisabledTimeoutRelaxation(); + Socket disabledSocket = ReflectionTestUtils.getField(disabledConnection, "socket"); + + try { + assertTrue(disabledConnection.isConnected()); + assertEquals(originalTimeoutMs, disabledConnection.getSoTimeout()); + assertEquals(originalTimeoutMs, disabledSocket.getSoTimeout()); + + // Verify that relaxed timeout is disabled + assertFalse(disabledConnection.isRelaxedTimeoutActive()); + + // Send MIGRATING push message - should NOT activate relaxed timeout + mockServer.sendMigratingPushToAll(); + + assertTrue(disabledConnection.ping()); + + // Verify that relaxed timeout was NOT activated + assertFalse(disabledConnection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, disabledSocket.getSoTimeout()); + + // Send FAILING_OVER push message - should also NOT activate relaxed timeout + mockServer.sendFailingOverPushToAll(); + + assertTrue(disabledConnection.ping()); + + // Verify that relaxed timeout is still NOT activated + assertFalse(disabledConnection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, disabledSocket.getSoTimeout()); + + // Send MIGRATED and FAILED_OVER messages - timeout should remain unchanged + mockServer.sendMigratedPushToAll(); + mockServer.sendFailedOverPushToAll(); + + assertTrue(disabledConnection.ping()); + assertFalse(disabledConnection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, disabledSocket.getSoTimeout()); + + } finally { + if (disabledConnection.isConnected()) { + disabledConnection.close(); + } + } + } + + @Test + public void testManualRelaxTimeoutsCallWithDisabledTimeoutRelaxation() throws Exception { + // Create a connection with disabled timeout relaxation + Connection disabledConnection = createConnectionWithDisabledTimeoutRelaxation(); + Socket disabledSocket = ReflectionTestUtils.getField(disabledConnection, "socket"); + + try { + assertTrue(disabledConnection.isConnected()); + assertEquals(originalTimeoutMs, disabledSocket.getSoTimeout()); + + // Manually call relaxTimeouts() - should have no effect when disabled + disabledConnection.relaxTimeouts(); + + // Relaxed timeout should fallback to original timeout, if relaxed timeout is disabled + assertFalse(disabledConnection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, disabledSocket.getSoTimeout()); + + // Verify connection still works + assertTrue(disabledConnection.ping()); + + } finally { + if (disabledConnection.isConnected()) { + disabledConnection.close(); + } + } + } + + @Test + public void testDefaultTimeoutOptionsDisablesRelaxedTimeout() throws Exception { + // Create a connection with null timeout options + Connection defaultTimeoutConnection = createConnectionWithDefaultTimeoutOptions(); + Socket nullTimeoutSocket = ReflectionTestUtils.getField(defaultTimeoutConnection, "socket"); + + try { + assertTrue(defaultTimeoutConnection.isConnected()); + assertEquals(originalTimeoutMs, nullTimeoutSocket.getSoTimeout()); + + // Verify that relaxed timeout is disabled + assertFalse(defaultTimeoutConnection.isRelaxedTimeoutActive()); + + // Send maintenance push messages - should NOT activate relaxed timeout + mockServer.sendMigratingPushToAll(); + + assertTrue(defaultTimeoutConnection.ping()); + + // Relaxed timeout's are disabled by default + assertFalse(defaultTimeoutConnection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, nullTimeoutSocket.getSoTimeout()); + + // Manual call should also have no effect + defaultTimeoutConnection.relaxTimeouts(); + assertFalse(defaultTimeoutConnection.isRelaxedTimeoutActive()); + assertEquals(originalTimeoutMs, nullTimeoutSocket.getSoTimeout()); + + } finally { + if (defaultTimeoutConnection.isConnected()) { + defaultTimeoutConnection.close(); + } + } + } + + @Test + public void testRelaxedBlockingTimeoutAppliedDuringBlockingCommand() + throws IOException, InterruptedException { + + // Verify initial timeout + Socket socket = ReflectionTestUtils.getField(connection, "socket"); + assertEquals(originalTimeoutMs, socket.getSoTimeout()); + + CountDownLatch blpopLatch = new CountDownLatch(1); + CountDownLatch blpopLatchAfter = new CountDownLatch(1); + doAnswer(invocation -> { + blpopLatch.countDown(); + return RespResponse.arrayOfBulkStrings("popped-item"); + }).when(mockHandler).handleCommand(eq("BLPOP"), anyList(), anyString()); + + // Send MIGRATING push notification which should trigger relaxTimeouts() + mockServer.sendMigratingPushToAll(); + + Thread t1 = new Thread(() -> { + connection.executeCommand(commandObjects.blpop(5, "test:blpop:key")); + blpopLatchAfter.countDown(); + }); + t1.start(); + + // Verify that relaxed blocking timeout was applied + blpopLatch.await(); + assertTrue(connection.isRelaxedTimeoutActive(), + "Relaxed timeout should be active during blocking command"); + assertEquals((int) relaxedBlockingTimeout.toMillis(), socket.getSoTimeout(), + "Socket timeout should be relaxed blocking timeout during blocking command"); + + blpopLatchAfter.await(); + assertTrue(connection.isRelaxedTimeoutActive(), + "Relaxed timeout should be still active after blocking command"); + assertEquals(relaxedTimeout.toMillis(), socket.getSoTimeout(), + "Socket timeout should be restored to relaxed timeout for non blocking command"); + + // Send MIGRATED push notification to disable relaxed timeout + mockServer.sendMigratedPushToAll(); + connection.executeCommand(commandObjects.ping()); + + assertFalse(connection.isRelaxedTimeoutActive(), + "Relaxed timeout should be disabled after MIGRATED"); + assertEquals(originalTimeoutMs, socket.getSoTimeout(), + "Socket timeout should be restored to original timeout"); + } + + /** + * Helper method to create a connection with disabled timeout relaxation. + */ + private Connection createConnectionWithDisabledTimeoutRelaxation() { + // Create configuration with disabled timeout relaxation + TimeoutOptions disabledTimeoutOptions = TimeoutOptions.builder() + .proactiveTimeoutsRelaxing(TimeoutOptions.DISABLED_TIMEOUT) + .proactiveBlockingTimeoutsRelaxing(TimeoutOptions.DISABLED_TIMEOUT).build(); + + MaintenanceEventHandler maintenanceEventHandler = new MaintenanceEventHandlerImpl(); + + DefaultJedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .socketTimeoutMillis(originalTimeoutMs).timeoutOptions(disabledTimeoutOptions) + .maintenanceEventHandler(maintenanceEventHandler).protocol(RedisProtocol.RESP3).build(); + + // Create connection to the mock server + HostAndPort hostAndPort = new HostAndPort("localhost", mockServer.getPort()); + Connection disabledConnection = new Connection(hostAndPort, clientConfig); + disabledConnection.connect(); + + return disabledConnection; + } + + /** + * Helper method to create a connection with null timeout options. + */ + private Connection createConnectionWithDefaultTimeoutOptions() { + MaintenanceEventHandler maintenanceEventHandler = new MaintenanceEventHandlerImpl(); + + DefaultJedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .socketTimeoutMillis(originalTimeoutMs) + // Note: not setting timeoutOptions, so it will be null + .maintenanceEventHandler(maintenanceEventHandler).protocol(RedisProtocol.RESP3).build(); + + // Create connection to the mock server + HostAndPort hostAndPort = new HostAndPort("localhost", mockServer.getPort()); + Connection nullTimeoutConnection = new Connection(hostAndPort, clientConfig); + nullTimeoutConnection.connect(); + + return nullTimeoutConnection; + } + +} diff --git a/src/test/java/redis/clients/jedis/upgrade/UnifiedJedisProactiveRebindTest.java b/src/test/java/redis/clients/jedis/upgrade/UnifiedJedisProactiveRebindTest.java new file mode 100644 index 0000000000..3c0ec7277b --- /dev/null +++ b/src/test/java/redis/clients/jedis/upgrade/UnifiedJedisProactiveRebindTest.java @@ -0,0 +1,268 @@ +package redis.clients.jedis.upgrade; + +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.*; + +import java.io.IOException; +import java.time.Duration; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import redis.clients.jedis.Connection; +import redis.clients.jedis.ConnectionPoolConfig; +import redis.clients.jedis.ConnectionTestHelper; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.RedisProtocol; +import redis.clients.jedis.util.Pool; +import redis.clients.jedis.util.server.TcpMockServer; + +/** + * Test that UnifiedJedis proactively rebinds to new target when receiving MOVING notifications. + * Uses mock TCP servers to simulate Redis cluster slot migration scenarios. + */ +@Tag("upgrade") +public class UnifiedJedisProactiveRebindTest { + + private TcpMockServer mockServer1; + private TcpMockServer mockServer2; + + private final int socketTimeoutMs = 5000; + + DefaultJedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .socketTimeoutMillis(socketTimeoutMs).protocol(RedisProtocol.RESP3) + .proactiveRebindEnabled(true) // Enable proactive rebinding + .build(); + + HostAndPort server1Address; + HostAndPort server2Address; + + ConnectionPoolConfig connectionPoolConfig; + + @BeforeEach + public void setUp() throws IOException { + // Start tcpmockedserver1 + mockServer1 = new TcpMockServer(); + mockServer1.start(); + + // Start tcpmockedserver2 + mockServer2 = new TcpMockServer(); + mockServer2.start(); + + server1Address = new HostAndPort("localhost", mockServer1.getPort()); + server2Address = new HostAndPort("localhost", mockServer2.getPort()); + + connectionPoolConfig = new ConnectionPoolConfig(); + + System.out.println("MockServer1 started on port: " + mockServer1.getPort()); + System.out.println("MockServer2 started on port: " + mockServer2.getPort()); + } + + @AfterEach + public void tearDown() throws IOException { + + if (mockServer1 != null) { + mockServer1.stop(); + } + if (mockServer2 != null) { + mockServer2.stop(); + } + } + + @Test + public void testProactiveRebind() throws Exception { + // 1. Create UnifiedJedis client and connect it to mockedserver1 + try (JedisPooled unifiedJedis = new JedisPooled(connectionPoolConfig, server1Address, + clientConfig)) { + + // 1. Perform a PING command to initiate a connection + String response1 = unifiedJedis.ping(); + assertEquals("PONG", response1); + + // Verify initial connection to server1 + assertEquals(1, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 2. Send MOVING notification on server1 -> MOVING 30 localhost:port2 + mockServer1.sendPushMessageToAll("MOVING", "30", server2Address.toString()); + + // 3. Perform PING command + // This should trigger read of the MOVING notification and rebind to server2 + // the ping command itself should be executed against server1 + // the used connection should be closed after the ping command is executed + String response2 = unifiedJedis.ping(); + assertEquals("PONG", response2); + + // drop connection to server1 + mockServer1.stop(); + + // Verify initial connection to server1 + assertEquals(0, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 4. Perform PING command + // Folowup ping command should be executed against server2 + + String response3 = unifiedJedis.ping(); + assertEquals("PONG", response3); + + // Verify that connection has moved to server2 + assertEquals(0, mockServer1.getConnectedClientCount()); + assertEquals(1, mockServer2.getConnectedClientCount()); + } + } + + @Test + public void testActiveConnectionShouldBeDisposedOnRebind() { + // 1. Create UnifiedJedis client and connect it to mockedserver1 + try (JedisPooled unifiedJedis = new JedisPooled(connectionPoolConfig, server1Address, + clientConfig)) { + Pool pool = unifiedJedis.getPool(); + + // 1. Test setup - 1 active connection, 0 idle connection + Connection activeConnection = unifiedJedis.getPool().getResource(); + assertEquals(1, pool.getNumActive()); + assertEquals(0, pool.getDestroyedCount()); + assertEquals(0, pool.getNumIdle()); + assertEquals(1, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 2. Send MOVING notification on server1 -> MOVING 30 localhost:port2 + mockServer1.sendPushMessageToAll("MOVING", "30", server2Address.toString()); + + // 3. Active connection should be still usable until closed and returned to the pools + assertTrue(activeConnection.ping()); + + // 4. When closed connection should be disposed and not returned to the pool + activeConnection.close(); + assertEquals(1, pool.getDestroyedCount()); + assertEquals(0, pool.getNumActive()); + + // 5. Wait for connection to be closed on server1 + await().pollDelay(Duration.ofMillis(1)).timeout(Duration.ofMillis(10)) + .until(() -> mockServer1.getConnectedClientCount() == 0); + assertEquals(0, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 6. Next command should create a new connection to server2 + String response2 = unifiedJedis.ping(); + assertEquals("PONG", response2); + assertEquals(0, mockServer1.getConnectedClientCount()); + assertEquals(1, mockServer2.getConnectedClientCount()); + } + } + + @Test + public void testIdleConnectionShouldBeDisposedOnRebind() { + + try (JedisPooled unifiedJedis = new JedisPooled(connectionPoolConfig, server1Address, + clientConfig)) { + Pool pool = unifiedJedis.getPool(); + + // 1. Test setup - 1 active connection, 1 idle connection + Connection activeConnection = unifiedJedis.getPool().getResource(); + Connection idleConnection = unifiedJedis.getPool().getResource(); + idleConnection.close(); + + assertEquals(1, pool.getNumActive()); + assertEquals(1, pool.getNumIdle()); + assertEquals(0, pool.getDestroyedCount()); + assertEquals(2, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 2. Send MOVING notification on server1 -> MOVING 30 localhost:port2 + String server2Address = "localhost:" + mockServer2.getPort(); + mockServer1.sendPushMessageToAll("MOVING", "30", server2Address); + + // 3. perform a command on active connection to trigger rebind + assertTrue(activeConnection.ping()); + + // 4. All IDLE connection's should be closed & disposed + assertEquals(0, pool.getNumIdle()); + assertEquals(1, pool.getNumActive()); + + // 5. Wait for connection to be closed on server1 + await().pollDelay(Duration.ofMillis(1)).timeout(Duration.ofMillis(10)) + .until(() -> mockServer1.getConnectedClientCount() == 1); + assertEquals(1, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 6. Next command should create a new connection to server2 + String response2 = unifiedJedis.ping(); + assertEquals("PONG", response2); + assertEquals(1, mockServer1.getConnectedClientCount()); + assertEquals(1, mockServer2.getConnectedClientCount()); + } + } + + @Test + public void testNewPoolConnectionsCreatedAgainstMovingTarget() { + // Create UnifiedJedis with connection pooling enabled + try (JedisPooled unifiedJedis = new JedisPooled(connectionPoolConfig, server1Address, + clientConfig)) { + + // 1. Test setup - 1 active connection + Connection activeConnection = unifiedJedis.getPool().getResource(); + + // Verify initial connection to server1 + assertEquals(1, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 2. Send MOVING notification on server1 -> MOVING 30 localhost:port2 + mockServer1.sendPushMessageToAll("MOVING", "30", server2Address.toString()); + + // 3. perform a command on active connection to trigger rebind + assertTrue(activeConnection.ping()); + + // 4. Initiate a new connection from the pool + Connection newConnection = unifiedJedis.getPool().getResource(); + assertTrue(newConnection.ping()); + + // Verify that new connections are being created against server2 + assertEquals(server2Address, ConnectionTestHelper.getHostAndPort(newConnection)); + assertEquals(1, mockServer2.getConnectedClientCount()); + } + } + + @Test + public void testPoolConnectionsWithProactiveRebindDisabled() { + // Verify that with proactive rebind disabled, connections stay on original server + DefaultJedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .from(this.clientConfig).proactiveRebindEnabled(false).build(); + try (JedisPooled unifiedJedis = new JedisPooled(connectionPoolConfig, server1Address, + clientConfig)) { + Pool pool = unifiedJedis.getPool(); + + // 1. Test setup - 1 active connection, 1 idle connection + Connection activeConnection = unifiedJedis.getPool().getResource(); + Connection idleConnection = unifiedJedis.getPool().getResource(); + idleConnection.close(); + + // Verify initial connection to server1 + assertEquals(1, pool.getNumActive()); + assertEquals(1, pool.getNumIdle()); + assertEquals(0, pool.getDestroyedCount()); + assertEquals(2, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + + // 2. Send MOVING notification on server1 -> MOVING 30 localhost:port2 + mockServer1.sendPushMessageToAll("MOVING", "30", server2Address.toString()); + + // 3. Perform PING command + // This should trigger read of the MOVING notification processing + assertTrue(activeConnection.ping()); + + // Verify initial connection to server1 + assertEquals(1, pool.getNumActive()); + assertEquals(1, pool.getNumIdle()); + assertEquals(0, pool.getDestroyedCount()); + assertEquals(2, mockServer1.getConnectedClientCount()); + assertEquals(0, mockServer2.getConnectedClientCount()); + } + } + +} diff --git a/src/test/java/redis/clients/jedis/util/ReflectionTestUtils.java b/src/test/java/redis/clients/jedis/util/ReflectionTestUtils.java new file mode 100644 index 0000000000..c4a1790543 --- /dev/null +++ b/src/test/java/redis/clients/jedis/util/ReflectionTestUtils.java @@ -0,0 +1,60 @@ +package redis.clients.jedis.util; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; + +public class ReflectionTestUtils { + + public static T getField(Object targetObject, String name) { + + Class targetClass = targetObject.getClass(); + + Field field = findField(targetClass, name); + + makeAccessible(field); + + try { + return (T) field.get(targetObject); + } catch (IllegalAccessException ex) { + throw new IllegalStateException( + "Unexpected reflection exception - " + ex.getClass().getName() + ": " + ex.getMessage()); + } + } + + public static void setField(Object targetObject, String name, Object value) { + + Class targetClass = targetObject.getClass(); + + Field field = findField(targetClass, name); + + makeAccessible(field); + + try { + field.set(targetObject, value); + } catch (IllegalAccessException ex) { + throw new IllegalStateException( + "Unexpected reflection exception - " + ex.getClass().getName() + ": " + ex.getMessage()); + } + } + + public static Field findField(Class clazz, String name) { + Class searchType = clazz; + while (Object.class != searchType && searchType != null) { + Field[] fields = searchType.getDeclaredFields(); + for (Field field : fields) { + if (name.equals(field.getName())) { + return field; + } + } + searchType = searchType.getSuperclass(); + } + return null; + } + + private static void makeAccessible(Field field) { + if ((!Modifier.isPublic(field.getModifiers()) || !Modifier.isPublic(field.getDeclaringClass().getModifiers()) + || Modifier.isFinal(field.getModifiers())) && !field.isAccessible()) { + field.setAccessible(true); + } + } +} diff --git a/src/test/java/redis/clients/jedis/util/server/CommandHandler.java b/src/test/java/redis/clients/jedis/util/server/CommandHandler.java new file mode 100644 index 0000000000..2070a02134 --- /dev/null +++ b/src/test/java/redis/clients/jedis/util/server/CommandHandler.java @@ -0,0 +1,20 @@ +package redis.clients.jedis.util.server; + +import java.util.List; + +/** + * Interface for handling custom Redis commands in TcpMockServer. This can be easily mocked with + * Mockito for testing purposes. + */ +public interface CommandHandler { + + /** + * Handle a Redis command and return a response. + * @param command The Redis command (case-insensitive) + * @param args The command arguments (excluding the command name) + * @param clientId The client identifier + * @return A RESP response string, or null to use default handling + */ + String handleCommand(String command, List args, String clientId); + +} diff --git a/src/test/java/redis/clients/jedis/util/server/RespResponse.java b/src/test/java/redis/clients/jedis/util/server/RespResponse.java new file mode 100644 index 0000000000..c1bf5e8e65 --- /dev/null +++ b/src/test/java/redis/clients/jedis/util/server/RespResponse.java @@ -0,0 +1,159 @@ +package redis.clients.jedis.util.server; + +import java.util.List; + +/** + * Utility class for building RESP (Redis Serialization Protocol) responses. This makes it easier to + * construct proper Redis protocol responses for testing. + */ +public class RespResponse { + + /** + * Create a simple string response (+OK, +PONG, etc.). + * @param value The string value + * @return A RESP simple string response + */ + public static String simpleString(String value) { + return "+" + value + "\r\n"; + } + + /** + * Create an error response (-ERR message). + * @param message The error message + * @return A RESP error response + */ + public static String error(String message) { + return "-" + message + "\r\n"; + } + + /** + * Create an integer response (:123). + * @param value The integer value + * @return A RESP integer response + */ + public static String integer(long value) { + return ":" + value + "\r\n"; + } + + /** + * Create a bulk string response ($5\r\nhello\r\n). + * @param value The string value (null for null bulk string) + * @return A RESP bulk string response + */ + public static String bulkString(String value) { + if (value == null) { + return "$-1\r\n"; + } + byte[] bytes = value.getBytes(); + return "$" + bytes.length + "\r\n" + value + "\r\n"; + } + + /** + * Create a bulk string response from byte array. + * @param bytes The byte array (null for null bulk string) + * @return A RESP bulk string response + */ + public static String bulkString(byte[] bytes) { + if (bytes == null) { + return "$-1\r\n"; + } + return "$" + bytes.length + "\r\n" + new String(bytes) + "\r\n"; + } + + /** + * Create an array response (*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n). + * @param elements The array elements as RESP strings + * @return A RESP array response + */ + public static String array(String... elements) { + if (elements == null) { + return "*-1\r\n"; + } + + StringBuilder sb = new StringBuilder(); + sb.append("*").append(elements.length).append("\r\n"); + + for (String element : elements) { + sb.append(element); + } + + return sb.toString(); + } + + /** + * Create an array response from a list of strings as bulk strings. + * @param strings The list of strings + * @return A RESP array response with bulk string elements + */ + public static String arrayOfBulkStrings(List strings) { + if (strings == null) { + return "*-1\r\n"; + } + + String[] elements = new String[strings.size()]; + for (int i = 0; i < strings.size(); i++) { + elements[i] = bulkString(strings.get(i)); + } + + return array(elements); + } + + /** + * Create an array response from string values (automatically converts to bulk strings). + * @param values The string values + * @return A RESP array response with bulk string elements + */ + public static String arrayOfBulkStrings(String... values) { + if (values == null) { + return "*-1\r\n"; + } + + String[] elements = new String[values.length]; + for (int i = 0; i < values.length; i++) { + elements[i] = bulkString(values[i]); + } + + return array(elements); + } + + /** + * Create an empty array response (*0\r\n). + * @return A RESP empty array response + */ + public static String emptyArray() { + return "*0\r\n"; + } + + /** + * Create a null array response (*-1\r\n). + * @return A RESP null array response + */ + public static String nullArray() { + return "*-1\r\n"; + } + + /** + * Create a null bulk string response ($-1\r\n). + * @return A RESP null bulk string response + */ + public static String nullBulkString() { + return "$-1\r\n"; + } + + /** + * Create an OK response (+OK\r\n). + * @return A RESP OK response + */ + public static String ok() { + return simpleString("OK"); + } + + /** + * Create a PONG response (+PONG\r\n). + * @return A RESP PONG response + */ + public static String pong() { + return simpleString("PONG"); + } + +} diff --git a/src/test/java/redis/clients/jedis/util/server/TcpMockServer.java b/src/test/java/redis/clients/jedis/util/server/TcpMockServer.java new file mode 100644 index 0000000000..30a4224764 --- /dev/null +++ b/src/test/java/redis/clients/jedis/util/server/TcpMockServer.java @@ -0,0 +1,354 @@ +package redis.clients.jedis.util.server; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import redis.clients.jedis.Protocol; +import redis.clients.jedis.util.RedisInputStream; +import redis.clients.jedis.util.RedisOutputStream; +import redis.clients.jedis.util.SafeEncoder; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A simple TCP mock server for testing Redis push notifications and timeout behavior. This server + * can accept connections and send predefined responses including push messages. + */ +public class TcpMockServer { + private final AtomicBoolean running = new AtomicBoolean(false); + private final ExecutorService executor = Executors.newCachedThreadPool(); + private final Map connectedClients = new ConcurrentHashMap<>(); + Logger logger = LoggerFactory.getLogger(TcpMockServer.class); + private ServerSocket serverSocket; + private int port; + private CommandHandler commandHandler; + + /** + * Start the server on an available port + */ + public void start() throws IOException { + start(0); // Use any available port + } + + /** + * Start the server on a specific port + */ + public void start(int port) throws IOException { + serverSocket = new ServerSocket(port); + this.port = serverSocket.getLocalPort(); + running.set(true); + + executor.submit(() -> { + while (running.get() && !serverSocket.isClosed()) { + try { + Socket clientSocket = serverSocket.accept(); + executor.submit(new ClientHandler(clientSocket)); + } catch (IOException e) { + if (running.get()) { + logger.error("Error accepting client connection: " + e.getMessage()); + } + } + } + }); + } + + /** + * Stop the server and close all active connections + */ + public void stop() throws IOException { + running.set(false); + + // Close all active client connections first + closeAllActiveConnections(); + + // Close the server socket + if (serverSocket != null && !serverSocket.isClosed()) { + serverSocket.close(); + } + executor.shutdownNow(); + } + + /** + * Get the port the server is running on + */ + public int getPort() { + return port; + } + + /** + * Check if the server is running + */ + public boolean isRunning() { + return running.get() && serverSocket != null && !serverSocket.isClosed(); + } + + /** + * Get the number of connected clients + */ + public int getConnectedClientCount() { + return connectedClients.size(); + } + + /** + * Generic method to send a push message to all connected clients. + * @param pushType the type of push message (e.g., "MIGRATING", "MIGRATED") + * @param args optional arguments for the push message + */ + public void sendPushMessageToAll(String pushType, String... args) { + connectedClients.values().forEach(client -> client.sendPushMessage(pushType, args)); + } + + /** + * Send a MIGRATING push message to all connected clients + */ + public void sendMigratingPushToAll() { + sendPushMessageToAll("MIGRATING", "30"); // Default slot 30 + } + + /** + * Send a MIGRATED push message to all connected clients + */ + public void sendMigratedPushToAll() { + sendPushMessageToAll("MIGRATED"); + } + + /** + * Send a FAILING_OVER push message to all connected clients + */ + public void sendFailingOverPushToAll() { + sendPushMessageToAll("FAILING_OVER", "30"); // Default slot 30 + } + + /** + * Send a FAILED_OVER push message to all connected clients + */ + public void sendFailedOverPushToAll() { + sendPushMessageToAll("FAILED_OVER"); + } + + public void sendMovingPushToAll(String targetHost) { + sendPushMessageToAll("MOVING", "30", targetHost); + } + + /** + * Get the current command handler. + * @return The current command handler, or null if none is set + */ + public CommandHandler getCommandHandler() { + return commandHandler; + } + + /** + * Set a custom command handler for processing Redis commands. + * @param commandHandler The command handler to use, or null to use only built-in handlers + */ + public void setCommandHandler(CommandHandler commandHandler) { + this.commandHandler = commandHandler; + } + + /** + * Close all active client connections + */ + private void closeAllActiveConnections() { + // Create a copy of the values to avoid ConcurrentModificationException + java.util.List clientsToClose = new java.util.ArrayList<>( + connectedClients.values()); + + for (ClientHandler client : clientsToClose) { + try { + client.forceClose(); + } catch (Exception e) { + logger.error("Error closing client connection: " + e.getMessage()); + } + } + + // Clear the map + connectedClients.clear(); + } + + /** + * Client handler for each connection + */ + private class ClientHandler implements Runnable { + private final Socket clientSocket; + private final String clientId; + private RedisOutputStream outputStream; + private volatile boolean connected = true; + + public ClientHandler(Socket clientSocket) { + this.clientSocket = clientSocket; + this.clientId = clientSocket.getRemoteSocketAddress().toString(); + } + + @Override + public void run() { + try (RedisInputStream rin = new RedisInputStream(clientSocket.getInputStream()); + RedisOutputStream out = new RedisOutputStream(clientSocket.getOutputStream())) { + + this.outputStream = out; + connectedClients.put(clientId, this); + + Object input; + while (connected && !clientSocket.isClosed()) { + try { + input = Protocol.read(rin); + if (input == null) { + connected = false; + break; + } + + List cmdArgs = (List) input; + String cmdString = SafeEncoder.encode((byte[]) cmdArgs.get(0)); + + // Convert arguments to strings (excluding command name) + List args = new java.util.ArrayList<>(); + for (int i = 1; i < cmdArgs.size(); i++) { + args.add(SafeEncoder.encode((byte[]) cmdArgs.get(i))); + } + + // Try custom handler first + String customResponse = null; + if (commandHandler != null) { + customResponse = commandHandler.handleCommand(cmdString, args, clientId); + } + + if (customResponse != null) { + out.write(customResponse.getBytes()); + out.flush(); + } else { + // Handle with default built-in handlers + handleBuiltinCommand(cmdString, out); + } + } catch (IOException e) { + logger.debug("Client " + clientId + " disconnected: " + e.getMessage()); + connected = false; + break; + } catch (Exception e) { + logger.debug("Client " + clientId + " connection error: " + e.getMessage()); + connected = false; + break; + } + } + } catch (IOException e) { + logger.error("Error handling client: " + e.getMessage()); + } finally { + cleanup(); + } + } + + private void sendHelloResponse(OutputStream out) throws IOException { + // RESP3 HELLO response + String response = "%7\r\n" + "$6\r\nserver\r\n$5\r\nredis\r\n" + + "$7\r\nversion\r\n$5\r\n7.0.0\r\n" + "$5\r\nproto\r\n:3\r\n" + "$2\r\nid\r\n:1\r\n" + + "$4\r\nmode\r\n$10\r\nstandalone\r\n" + "$4\r\nrole\r\n$6\r\nmaster\r\n" + + "$7\r\nmodules\r\n*0\r\n"; + out.write(response.getBytes()); + out.flush(); + } + + private void sendPongResponse(OutputStream out) throws IOException { + String response = "+PONG\r\n"; + out.write(response.getBytes()); + out.flush(); + } + + private void sendOkResponse(OutputStream out) throws IOException { + String response = "+OK\r\n"; + out.write(response.getBytes()); + out.flush(); + } + + /** + * Handle a command with built-in handlers. + */ + private void handleBuiltinCommand(String cmdString, OutputStream out) throws IOException { + if (cmdString.equalsIgnoreCase("HELLO")) { + sendHelloResponse(out); + } else if (cmdString.contains("PING")) { + sendPongResponse(out); + } else if (cmdString.contains("CLIENT")) { + sendOkResponse(out); + } else { + throw new RuntimeException("Unknown command: " + cmdString); + } + } + + /** + * Clean up client resources and remove from connected clients map + */ + private void cleanup() { + connected = false; + connectedClients.remove(clientId); + outputStream = null; + + try { + if (clientSocket != null && !clientSocket.isClosed()) { + clientSocket.close(); + } + } catch (IOException e) { + logger.error("Error closing client socket during cleanup: " + e.getMessage()); + } + } + + /** + * Generic method to send a push message to this client. + * @param pushType the type of push message (e.g., "MIGRATING", "MIGRATED") + * @param args optional arguments for the push message + */ + public void sendPushMessage(String pushType, String... args) { + try { + StringBuilder pushMessage = new StringBuilder(); + + // Calculate total number of elements (push type + arguments) + int elementCount = 1 + args.length; + pushMessage.append(">").append(elementCount).append("\r\n"); + + // Add push type + pushMessage.append("$").append(pushType.length()).append("\r\n").append(pushType) + .append("\r\n"); + + // Add arguments + for (String arg : args) { + pushMessage.append("$").append(arg.length()).append("\r\n").append(arg).append("\r\n"); + } + + outputStream.write(pushMessage.toString().getBytes()); + outputStream.flush(); + + } catch (IOException e) { + logger.error("Error sending " + pushType + " push to " + clientId + + " (client disconnected): " + e.getMessage()); + cleanup(); + } + } + + /** + * Force close this client connection (used when server is shutting down) + */ + public void forceClose() { + connected = false; + + try { + if (clientSocket != null && !clientSocket.isClosed()) { + clientSocket.close(); + } + } catch (IOException e) { + logger.error("Error force closing client socket: " + e.getMessage()); + } + + // Remove from connected clients map + connectedClients.remove(clientId); + outputStream = null; + } + + } + +}