diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java index 81f4eb3f2..7f8c583f7 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java @@ -53,6 +53,8 @@ public abstract class AbstractContainerOptions, private final BackPressureMode backPressureMode; + private final BackPressureHandlerFactory backPressureHandlerFactory; + private final ListenerMode listenerMode; private final MessagingMessageConverter messageConverter; @@ -84,6 +86,7 @@ protected AbstractContainerOptions(Builder builder) { this.listenerShutdownTimeout = builder.listenerShutdownTimeout; this.acknowledgementShutdownTimeout = builder.acknowledgementShutdownTimeout; this.backPressureMode = builder.backPressureMode; + this.backPressureHandlerFactory = builder.backPressureHandlerFactory; this.listenerMode = builder.listenerMode; this.messageConverter = builder.messageConverter; this.acknowledgementMode = builder.acknowledgementMode; @@ -154,6 +157,11 @@ public BackPressureMode getBackPressureMode() { return this.backPressureMode; } + @Override + public BackPressureHandlerFactory getBackPressureHandlerFactory() { + return this.backPressureHandlerFactory; + } + @Override public ListenerMode getListenerMode() { return this.listenerMode; @@ -214,6 +222,8 @@ protected abstract static class Builder, private static final BackPressureMode DEFAULT_THROUGHPUT_CONFIGURATION = BackPressureMode.AUTO; + private static final BackPressureHandlerFactory DEFAULT_BACKPRESSURE_FACTORY = BackPressureHandlerFactory::semaphoreBackPressureHandler; + private static final ListenerMode DEFAULT_MESSAGE_DELIVERY_STRATEGY = ListenerMode.SINGLE_MESSAGE; private static final MessagingMessageConverter DEFAULT_MESSAGE_CONVERTER = new SqsMessagingMessageConverter(); @@ -234,6 +244,8 @@ protected abstract static class Builder, private BackPressureMode backPressureMode = DEFAULT_THROUGHPUT_CONFIGURATION; + private BackPressureHandlerFactory backPressureHandlerFactory = DEFAULT_BACKPRESSURE_FACTORY; + private Duration listenerShutdownTimeout = DEFAULT_LISTENER_SHUTDOWN_TIMEOUT; private Duration acknowledgementShutdownTimeout = DEFAULT_ACKNOWLEDGEMENT_SHUTDOWN_TIMEOUT; @@ -272,6 +284,7 @@ protected Builder(AbstractContainerOptions options) { this.listenerShutdownTimeout = options.listenerShutdownTimeout; this.acknowledgementShutdownTimeout = options.acknowledgementShutdownTimeout; this.backPressureMode = options.backPressureMode; + this.backPressureHandlerFactory = options.backPressureHandlerFactory; this.listenerMode = options.listenerMode; this.messageConverter = options.messageConverter; this.acknowledgementMode = options.acknowledgementMode; @@ -364,6 +377,12 @@ public B backPressureMode(BackPressureMode backPressureMode) { return self(); } + @Override + public B backPressureHandlerFactory(BackPressureHandlerFactory backPressureHandlerFactory) { + this.backPressureHandlerFactory = backPressureHandlerFactory; + return self(); + } + @Override public B acknowledgementInterval(Duration acknowledgementInterval) { Assert.notNull(acknowledgementInterval, "acknowledgementInterval cannot be null"); diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java 
b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java index 6808f647a..96a3292a2 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java @@ -225,10 +225,9 @@ private TaskExecutor validateCustomExecutor(TaskExecutor taskExecutor) { } protected BackPressureHandler createBackPressureHandler() { - return SemaphoreBackPressureHandler.builder().batchSize(getContainerOptions().getMaxMessagesPerPoll()) - .totalPermits(getContainerOptions().getMaxConcurrentMessages()) - .acquireTimeout(getContainerOptions().getMaxDelayBetweenPolls()) - .throughputConfiguration(getContainerOptions().getBackPressureMode()).build(); + O containerOptions = getContainerOptions(); + BackPressureHandlerFactory factory = containerOptions.getBackPressureHandlerFactory(); + return factory.createBackPressureHandler(containerOptions); } protected TaskExecutor createSourcesTaskExecutor() { diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java index 1d76d6589..55e5a25f0 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java @@ -29,7 +29,7 @@ public interface BackPressureHandler { /** - * Request a number of permits. Each obtained permit allows the + * Requests a number of permits. Each obtained permit allows the * {@link io.awspring.cloud.sqs.listener.source.MessageSource} to retrieve one message. * @param amount the amount of permits to request. * @return the amount of permits obtained. @@ -37,12 +37,40 @@ public interface BackPressureHandler { */ int request(int amount) throws InterruptedException; + /** + * Releases the specified amount of permits for processed messages. Each message that has been processed should + * release one permit, whether processing was successful or not. + *

+ * <p>
+ * This method is called in the following use cases:
+ * <ul>
+ * <li>{@link ReleaseReason#LIMITED}: permits were not used because another BackPressureHandler has a lower permits
+ * limit and the difference in permits needs to be returned.</li>
+ * <li>{@link ReleaseReason#NONE_FETCHED}: none of the permits were actually used because no messages were retrieved
+ * from SQS. Permits need to be returned.</li>
+ * <li>{@link ReleaseReason#PARTIAL_FETCH}: some of the permits were used (some messages were retrieved from SQS).
+ * The unused ones need to be returned. The amount to be returned might be {@literal 0}, in which case all the permits
+ * were used because the same number of messages were fetched from SQS.</li>
+ * <li>{@link ReleaseReason#PROCESSED}: the processing of a message finished, successfully or not.</li>
+ * </ul>
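+ * <p>
+ * As an illustration, a custom handler that only tracks in-flight messages (like the non-blocking limiter used in
+ * this PR's integration tests) can handle every reason the same way; {@code inFlight} is a hypothetical
+ * {@code AtomicInteger} field updated in {@code request}:
+ * <pre>{@code
+ * public void release(int amount, ReleaseReason reason) {
+ * 	this.inFlight.addAndGet(-amount);
+ * }
+ * }</pre>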
+ * @param amount the amount of permits to release. + * @param reason the reason why the permits were released. + */ + default void release(int amount, ReleaseReason reason) { + release(amount); + } + /** * Release the specified amount of permits. Each message that has been processed should release one permit, whether * processing was successful or not. * @param amount the amount of permits to release. + * + * @deprecated This method is deprecated and will not be called by the Spring Cloud AWS SQS listener anymore. + * Implement {@link #release(int, ReleaseReason)} instead. */ - void release(int amount); + @Deprecated + default void release(int amount) { + release(amount, ReleaseReason.PROCESSED); + } /** * Attempts to acquire all permits up to the specified timeout. If successful, means all permits were returned and @@ -52,4 +80,24 @@ public interface BackPressureHandler { */ boolean drain(Duration timeout); + enum ReleaseReason { + /** + * Permits were not used because another BackPressureHandler has a lower permits limit and the difference need + * to be aligned across all handlers. + */ + LIMITED, + /** + * No messages were retrieved from SQS, so all permits need to be returned. + */ + NONE_FETCHED, + /** + * Some messages were fetched from SQS. Unused permits need to be returned. + */ + PARTIAL_FETCH, + /** + * The processing of one or more messages finished, successfully or not. + */ + PROCESSED; + } + } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactory.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactory.java new file mode 100644 index 000000000..a72bd12f1 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandlerFactory.java @@ -0,0 +1,184 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +/** + * A factory for creating {@link BackPressureHandler} for managing queue consumption backpressure. Implementations can + * configure each the {@link BackPressureHandler} according to its strategies, using the provided + * {@link ContainerOptions}. + *

+ * <p>
+ * Spring Cloud AWS provides the following {@link BackPressureHandler} implementations:
+ * <ul>
+ * <li>{@link ConcurrencyLimiterBlockingBackPressureHandler}: Limits the maximum number of messages that can be
+ * processed concurrently by the application.</li>
+ * <li>{@link ThroughputBackPressureHandler}: Adapts the throughput dynamically between high and low modes in order to
+ * reduce SQS pull costs when few messages are coming in.</li>
+ * <li>{@link CompositeBackPressureHandler}: Allows combining multiple {@link BackPressureHandler}s together and
+ * ensures they cooperate.</li>
+ * </ul>
+ * <p>
+ * Below are a few examples of how common use cases can be achieved. Keep in mind you can always create your own
+ * {@link BackPressureHandler} implementation and, if needed, combine it with the provided ones thanks to the
+ * {@link CompositeBackPressureHandler}.
+ *
+ * <h3>A BackPressureHandler limiting the max concurrency with high throughput</h3>
+ *
+ * <pre>{@code
+ * containerOptionsBuilder.backPressureHandlerFactory(containerOptions -> {
+ * 	return ConcurrencyLimiterBlockingBackPressureHandler.builder()
+ * 		.batchSize(containerOptions.getMaxMessagesPerPoll())
+ * 		.totalPermits(containerOptions.getMaxConcurrentMessages())
+ * 		.acquireTimeout(containerOptions.getMaxDelayBetweenPolls())
+ * 		.throughputConfiguration(BackPressureMode.FIXED_HIGH_THROUGHPUT)
+ * 		.build();
+ * });
+ * }</pre>
+ *
+ * <h3>A BackPressureHandler limiting the max concurrency with dynamic throughput</h3>
+ *
+ * <pre>{@code
+ * containerOptionsBuilder.backPressureHandlerFactory(containerOptions -> {
+ * 	int batchSize = containerOptions.getMaxMessagesPerPoll();
+ * 	var concurrencyLimiterBlockingBackPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler.builder()
+ * 		.batchSize(batchSize)
+ * 		.totalPermits(containerOptions.getMaxConcurrentMessages())
+ * 		.acquireTimeout(containerOptions.getMaxDelayBetweenPolls())
+ * 		.throughputConfiguration(BackPressureMode.AUTO)
+ * 		.build();
+ * 	var throughputBackPressureHandler = ThroughputBackPressureHandler.builder()
+ * 		.batchSize(batchSize)
+ * 		.totalPermits(containerOptions.getMaxConcurrentMessages())
+ * 		.build();
+ * 	return new CompositeBackPressureHandler(List.of(
+ * 			concurrencyLimiterBlockingBackPressureHandler,
+ * 			throughputBackPressureHandler
+ * 		),
+ * 		batchSize,
+ * 		standbyLimitPollingInterval
+ * 	);
+ * });
+ * }</pre>
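+ *
+ * <h3>Combining a custom BackPressureHandler with the provided ones</h3>
+ *
+ * The following sketch is adapted from this PR's integration tests. {@code myCustomLimiter} stands for any custom
+ * {@link BackPressureHandler} implementation, and the {@code Duration} is an arbitrary wait time applied when no
+ * permits could be obtained:
+ *
+ * <pre>{@code
+ * containerOptionsBuilder.backPressureHandlerFactory(containerOptions ->
+ * 	BackPressureHandlerFactory.compositeBackPressureHandler(containerOptions, Duration.ofMillis(50),
+ * 		List.of(myCustomLimiter,
+ * 			BackPressureHandlerFactory.concurrencyLimiterBackPressureHandler(containerOptions))));
+ * }</pre>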
+ */ +public interface BackPressureHandlerFactory { + + /** + * Creates a new {@link BackPressureHandler} instance based on the provided {@link ContainerOptions}. + *

+ * NOTE: it is important for the factory to always return a new instance as otherwise it might + * result in a BackPressureHandler internal resources (counters, semaphores, ...) to be shared by multiple + * containers which is very likely not the desired behavior. + * + * @param containerOptions the container options to use for creating the BackPressureHandler. + * @return the created BackPressureHandler + */ + BackPressureHandler createBackPressureHandler(ContainerOptions containerOptions); + + /** + * Creates a new {@link SemaphoreBackPressureHandler} instance based on the provided {@link ContainerOptions}. + * + * @param options the container options. + * @return the created SemaphoreBackPressureHandler. + */ + static BatchAwareBackPressureHandler semaphoreBackPressureHandler(ContainerOptions options) { + return SemaphoreBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()) + .totalPermits(options.getMaxConcurrentMessages()).acquireTimeout(options.getMaxDelayBetweenPolls()) + .throughputConfiguration(options.getBackPressureMode()).build(); + } + + /** + * Creates a new {@link BackPressureHandler} instance based on the provided {@link ContainerOptions} combining a + * {@link ConcurrencyLimiterBlockingBackPressureHandler}, a {@link ThroughputBackPressureHandler} and a + * {@link FullBatchBackPressureHandler}. The exact combination of depends on the given {@link ContainerOptions}. + * + * @param options the container options. + * @param maxIdleWaitTime the maximum amount of time to wait for a permit to be released in case no permits were + * obtained. + * @return the created SemaphoreBackPressureHandler. + */ + static BatchAwareBackPressureHandler adaptativeThroughputBackPressureHandler(ContainerOptions options, + Duration maxIdleWaitTime) { + BackPressureMode backPressureMode = options.getBackPressureMode(); + + var concurrencyLimiterBlockingBackPressureHandler = concurrencyLimiterBackPressureHandler(options); + if (backPressureMode == BackPressureMode.FIXED_HIGH_THROUGHPUT) { + return concurrencyLimiterBlockingBackPressureHandler; + } + var backPressureHandlers = new ArrayList(); + backPressureHandlers.add(concurrencyLimiterBlockingBackPressureHandler); + + // The ThroughputBackPressureHandler should run second in the chain as it is non-blocking. + // Running it first would result in more polls as it would potentially limit the + // ConcurrencyLimiterBlockingBackPressureHandler to a lower amount of requested permits + // which means the ConcurrencyLimiterBlockingBackPressureHandler blocking behavior would + // not be optimally leveraged. + if (backPressureMode == BackPressureMode.AUTO + || backPressureMode == BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) { + backPressureHandlers.add(throughputBackPressureHandler(options)); + } + + // The FullBatchBackPressureHandler should run last in the chain to ensure that a full batch is requested or not + if (backPressureMode == BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) { + backPressureHandlers.add(fullBatchBackPressureHandler(options)); + } + return compositeBackPressureHandler(options, maxIdleWaitTime, backPressureHandlers); + } + + /** + * Creates a new {@link ConcurrencyLimiterBlockingBackPressureHandler} instance based on the provided + * {@link ContainerOptions}. + * + * @param options the container options. + * @return the created ConcurrencyLimiterBlockingBackPressureHandler. 
+ */ + static CompositeBackPressureHandler compositeBackPressureHandler(ContainerOptions options, + Duration maxIdleWaitTime, List backPressureHandlers) { + return new CompositeBackPressureHandler(List.copyOf(backPressureHandlers), options.getMaxMessagesPerPoll(), + maxIdleWaitTime); + } + + /** + * Creates a new {@link ConcurrencyLimiterBlockingBackPressureHandler} instance based on the provided + * {@link ContainerOptions}. + * @param options the container options. + * @return the created ConcurrencyLimiterBlockingBackPressureHandler. + */ + static ConcurrencyLimiterBlockingBackPressureHandler concurrencyLimiterBackPressureHandler( + ContainerOptions options) { + return ConcurrencyLimiterBlockingBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()) + .totalPermits(options.getMaxConcurrentMessages()).throughputConfiguration(options.getBackPressureMode()) + .acquireTimeout(options.getMaxDelayBetweenPolls()).build(); + } + + /** + * Creates a new {@link ThroughputBackPressureHandler} instance based on the provided {@link ContainerOptions}. + * @param options the container options. + * @return the created ThroughputBackPressureHandler. + */ + static ThroughputBackPressureHandler throughputBackPressureHandler(ContainerOptions options) { + return ThroughputBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()) + .totalPermits(options.getMaxConcurrentMessages()).build(); + } + + /** + * Creates a new {@link FullBatchBackPressureHandler} instance based on the provided {@link ContainerOptions}. + * @param options the container options. + * @return the created FullBatchBackPressureHandler. + */ + static FullBatchBackPressureHandler fullBatchBackPressureHandler(ContainerOptions options) { + return FullBatchBackPressureHandler.builder().batchSize(options.getMaxMessagesPerPoll()).build(); + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java index 51e12e0a0..661b7731b 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java @@ -35,13 +35,34 @@ public interface BatchAwareBackPressureHandler extends BackPressureHandler { * Release a batch of permits. This has the semantics of letting the {@link BackPressureHandler} know that all * permits from a batch are being released, in opposition to {@link #release(int)} in which any number of permits * can be specified. + * + * @deprecated This method is deprecated and will not be called by the Spring Cloud AWS SQS listener anymore. + * Implement {@link BackPressureHandler#release(int, ReleaseReason)} instead. */ - void releaseBatch(); + @Deprecated + default void releaseBatch() { + release(getBatchSize(), ReleaseReason.NONE_FETCHED); + } + + @Override + default void release(int amount, ReleaseReason reason) { + if (amount == getBatchSize() && reason == ReleaseReason.NONE_FETCHED) { + releaseBatch(); + } + else { + release(amount); + } + } /** * Return the configured batch size for this handler. * @return the batch size. + * + * @deprecated This method is deprecated and will not be used by the Spring Cloud AWS SQS listener anymore. 
*/ - int getBatchSize(); + @Deprecated + default int getBatchSize() { + return 0; + } } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java new file mode 100644 index 000000000..a53722f17 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java @@ -0,0 +1,171 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Composite {@link BackPressureHandler} implementation that delegates the back-pressure handling to a list of + * {@link BackPressureHandler}s. + *

+ * <p>
+ * This class is used to combine multiple back-pressure handlers into a single one. It allows for more complex
+ * back-pressure handling strategies by combining different implementations.
+ * <p>
+ * The order in which the back-pressure handlers are registered in the {@link CompositeBackPressureHandler} is
+ * important as it will affect the blocking and limiting behaviour of the back-pressure handling.
+ * <p>
+ * When {@link #request(int)} is called, the first back-pressure handler in the list is called with {@code amount} as
+ * the requested amount of permits. The returned amount of permits (which is less than or equal to the initial amount)
+ * is then passed to the next back-pressure handler in the list. This process of reducing the amount to request from
+ * the next handlers in the chain is called "limiting", and it continues until all back-pressure handlers have been
+ * called or {@literal 0} permits have been returned.
+ * <p>
+ * Once the final amount of available permits has been computed, unused acquired permits on back-pressure handlers
+ * (due to later limiting happening in the chain) are released.
+ * <p>
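+ * For example, with two handlers and {@code request(10)} (illustrative numbers): if the first handler grants 10
+ * permits but the second grants only 4, the overall result is 4 and the first handler immediately receives
+ * {@code release(6, ReleaseReason.LIMITED)}, so that both handlers end up accounting for the same 4 outstanding
+ * permits.
+ * <p>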

+ * If no permits were obtained, the {@link #request(int)} method will wait up to {@code noPermitsReturnedWaitTimeout} + * for a release of permits before returning. + */ +public class CompositeBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(CompositeBackPressureHandler.class); + + private final List backPressureHandlers; + + private final int batchSize; + + private final ReentrantLock noPermitsReturnedWaitLock = new ReentrantLock(); + + private final Condition permitsReleasedCondition = noPermitsReturnedWaitLock.newCondition(); + + private final Duration noPermitsReturnedWaitTimeout; + + private String id; + + public CompositeBackPressureHandler(List backPressureHandlers, int batchSize, + Duration noPermitsReturnedWaitTimeout) { + this.backPressureHandlers = backPressureHandlers; + this.batchSize = batchSize; + this.noPermitsReturnedWaitTimeout = noPermitsReturnedWaitTimeout; + } + + @Override + public void setId(String id) { + this.id = id; + backPressureHandlers.stream().filter(IdentifiableContainerComponent.class::isInstance) + .map(IdentifiableContainerComponent.class::cast) + .forEach(bph -> bph.setId(bph.getClass().getSimpleName() + "-" + id)); + } + + @Override + public String getId() { + return id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + logger.debug("[{}] Requesting {} permits", this.id, amount); + int obtained = amount; + int[] obtainedPerBph = new int[backPressureHandlers.size()]; + for (int i = 0; i < backPressureHandlers.size() && obtained > 0; i++) { + obtainedPerBph[i] = backPressureHandlers.get(i).request(obtained); + obtained = Math.min(obtained, obtainedPerBph[i]); + } + for (int i = 0; i < backPressureHandlers.size(); i++) { + int obtainedForBph = obtainedPerBph[i]; + if (obtainedForBph > obtained) { + backPressureHandlers.get(i).release(obtainedForBph - obtained, ReleaseReason.LIMITED); + } + } + if (obtained == 0) { + waitForPermitsToBeReleased(); + } + logger.debug("[{}] Obtained {} permits ({} requested)", this.id, obtained, amount); + return obtained; + } + + @Override + public void release(int amount, ReleaseReason reason) { + logger.debug("[{}] Releasing {} permits ({})", this.id, amount, reason); + for (BackPressureHandler handler : backPressureHandlers) { + handler.release(amount, reason); + } + if (amount > 0) { + signalPermitsWereReleased(); + } + } + + /** + * Waits for permits to be released up to {@link #noPermitsReturnedWaitTimeout}. If no permits were released within + * the configured {@link #noPermitsReturnedWaitTimeout}, returns immediately. This allows {@link #request(int)} to + * return {@code 0} permits and will trigger another round of back-pressure handling. + * + * @throws InterruptedException if the Thread is interrupted while waiting for permits. 
+ */ + @SuppressWarnings({ "java:S899" // we are not interested in the await return value here + }) + private void waitForPermitsToBeReleased() throws InterruptedException { + noPermitsReturnedWaitLock.lock(); + try { + logger.trace("[{}] No permits were obtained, waiting for a release up to {}", this.id, + noPermitsReturnedWaitTimeout); + permitsReleasedCondition.await(noPermitsReturnedWaitTimeout.toMillis(), TimeUnit.MILLISECONDS); + } + finally { + noPermitsReturnedWaitLock.unlock(); + } + } + + private void signalPermitsWereReleased() { + noPermitsReturnedWaitLock.lock(); + try { + permitsReleasedCondition.signal(); + } + finally { + noPermitsReturnedWaitLock.unlock(); + } + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Draining back-pressure handlers initiated", this.id); + boolean result = true; + Instant start = Instant.now(); + for (BackPressureHandler handler : backPressureHandlers) { + Duration remainingTimeout = maxDuration(timeout.minus(Duration.between(start, Instant.now())), + Duration.ZERO); + result &= handler.drain(remainingTimeout); + } + logger.debug("[{}] Draining back-pressure handlers completed", this.id); + return result; + } + + private static Duration maxDuration(Duration first, Duration second) { + return first.compareTo(second) > 0 ? first : second; + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java new file mode 100644 index 000000000..51129f183 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java @@ -0,0 +1,160 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import io.awspring.cloud.sqs.listener.source.PollingMessageSource; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * Blocking {@link BackPressureHandler} implementation that uses a {@link Semaphore} for handling the number of + * concurrent messages being processed. 
+ * + * @see PollingMessageSource + */ +public class ConcurrencyLimiterBlockingBackPressureHandler + implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(ConcurrencyLimiterBlockingBackPressureHandler.class); + + private final Semaphore semaphore; + + private final int batchSize; + + private final int totalPermits; + + private final Duration acquireTimeout; + + private String id = getClass().getSimpleName(); + + private ConcurrencyLimiterBlockingBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + this.totalPermits = builder.totalPermits; + this.acquireTimeout = builder.acquireTimeout; + logger.debug( + "ConcurrencyLimiterBlockingBackPressureHandler created with configuration " + + "totalPermits: {}, batchSize: {}, acquireTimeout: {}", + this.totalPermits, this.batchSize, this.acquireTimeout); + this.semaphore = new Semaphore(totalPermits); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(this.batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + int acquiredPermits = tryAcquire(amount, this.acquireTimeout); + if (acquiredPermits > 0) { + return acquiredPermits; + } + int availablePermits = Math.min(this.semaphore.availablePermits(), amount); + if (availablePermits > 0) { + return tryAcquire(availablePermits, this.acquireTimeout); + } + return 0; + } + + private int tryAcquire(int amount, Duration duration) throws InterruptedException { + if (this.semaphore.tryAcquire(amount, duration.toMillis(), TimeUnit.MILLISECONDS)) { + logger.debug("[{}] Acquired {} permits ({} / {} available)", this.id, amount, + this.semaphore.availablePermits(), this.totalPermits); + return amount; + } + return 0; + } + + @Override + public void release(int amount, ReleaseReason reason) { + this.semaphore.release(amount); + logger.debug("[{}] Released {} permits ({}) ({} / {} available)", this.id, amount, reason, + this.semaphore.availablePermits(), this.totalPermits); + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Waiting for up to {} for approx. 
{} permits to be released", this.id, timeout, + this.totalPermits - this.semaphore.availablePermits()); + try { + return tryAcquire(this.totalPermits, timeout) > 0; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.debug("[{}] Draining interrupted", this.id); + return false; + } + } + + public static class Builder { + + private int batchSize; + + private int totalPermits; + + private Duration acquireTimeout; + + private BackPressureMode backPressureMode; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public Builder totalPermits(int totalPermits) { + this.totalPermits = totalPermits; + return this; + } + + public Builder acquireTimeout(Duration acquireTimeout) { + this.acquireTimeout = acquireTimeout; + return this; + } + + public Builder throughputConfiguration(BackPressureMode backPressureConfiguration) { + this.backPressureMode = backPressureConfiguration; + return this; + } + + public ConcurrencyLimiterBlockingBackPressureHandler build() { + Assert.noNullElements( + Arrays.asList(this.batchSize, this.totalPermits, this.acquireTimeout, this.backPressureMode), + "Missing configuration"); + Assert.isTrue(this.batchSize > 0, "The batch size must be greater than 0"); + Assert.isTrue(this.totalPermits >= this.batchSize, "Total permits must be greater than the batch size"); + return new ConcurrencyLimiterBlockingBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java index ad7313cf6..8e7006bfb 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java @@ -59,7 +59,7 @@ public interface ContainerOptions, B extends Co boolean isAutoStartup(); /** - * Set the maximum time the polling thread should wait for a full batch of permits to be available before trying to + * Sets the maximum time the polling thread should wait for a full batch of permits to be available before trying to * acquire a partial batch if so configured. A poll is only actually executed if at least one permit is available. * Default is 10 seconds. * @@ -127,6 +127,12 @@ default BackOffPolicy getPollBackOffPolicy() { */ BackPressureMode getBackPressureMode(); + /** + * Return the a {@link BackPressureHandlerFactory} to create a {@link BackPressureHandler} for this container. + * @return the BackPressureHandlerFactory. + */ + BackPressureHandlerFactory getBackPressureHandlerFactory(); + /** * Return the {@link ListenerMode} mode for this container. * @return the listener mode. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java index 9d03b7964..9ae2e32b2 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java @@ -145,6 +145,16 @@ default B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) { */ B backPressureMode(BackPressureMode backPressureMode); + /** + * Sets the {@link BackPressureHandlerFactory} for this container. 
Default is + * {@code AbstractContainerOptions.DEFAULT_BACKPRESSURE_FACTORY} which results in a default + * {@link SemaphoreBackPressureHandler} to be instantiated. + * + * @param backPressureHandlerFactory the BackPressureHandler supplier. + * @return this instance. + */ + B backPressureHandlerFactory(BackPressureHandlerFactory backPressureHandlerFactory); + /** * Set the maximum interval between acknowledgements for batch acknowledgements. The default depends on the specific * {@link ContainerComponentFactory} implementation. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandler.java new file mode 100644 index 000000000..aa83921ab --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/FullBatchBackPressureHandler.java @@ -0,0 +1,100 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import io.awspring.cloud.sqs.listener.source.PollingMessageSource; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * Non-blocking {@link BackPressureHandler} implementation that ensures the exact batch size is requested. + *

+ * If the amount of permits being requested is not equal to the batch size, permits will be limited to {@literal 0}. For + * this limiting mechanism to work, the {@link FullBatchBackPressureHandler} must be used in combination with another + * {@link BackPressureHandler} and be the last one in the chain of the {@link CompositeBackPressureHandler} + * + * @see PollingMessageSource + */ +public class FullBatchBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(FullBatchBackPressureHandler.class); + + private final int batchSize; + + private String id = getClass().getSimpleName(); + + private FullBatchBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + logger.debug("FullBatchBackPressureHandler created with configuration: batchSize: {}", this.batchSize); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(this.batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + if (amount == batchSize) { + return amount; + } + logger.warn("[{}] Could not acquire a full batch ({} / {}), cancelling current poll", this.id, amount, + this.batchSize); + return 0; + } + + @Override + public void release(int amount, ReleaseReason reason) { + // NO-OP + } + + @Override + public boolean drain(Duration timeout) { + return true; + } + + public static class Builder { + + private int batchSize; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public FullBatchBackPressureHandler build() { + Assert.notNull(this.batchSize, "Missing configuration for batch size"); + Assert.isTrue(this.batchSize > 0, "The batch size must be greater than 0"); + return new FullBatchBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java new file mode 100644 index 000000000..ec2525ef4 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java @@ -0,0 +1,178 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.awspring.cloud.sqs.listener; + +import io.awspring.cloud.sqs.listener.source.PollingMessageSource; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * Non-blocking {@link BackPressureHandler} implementation that uses a switch between high and low throughput modes. + *

+ * <p>
+ * <strong>Throughput modes</strong>
+ * <ul>
+ * <li>In low-throughput mode, a single batch can be requested at a time. The number of permits that will be delivered
+ * is adjusted so that the number of in flight messages will not exceed the batch size.</li>
+ * <li>In high-throughput mode, multiple batches can be requested at a time. The number of permits that will be
+ * delivered is adjusted so that the number of in flight messages will not exceed the maximum number of concurrent
+ * messages. Note that for a single poll the maximum number of permits that will be delivered will not exceed the
+ * batch size.</li>
+ * </ul>
+ * <p>
+ * <strong>Throughput mode switch:</strong> The initial throughput mode is the low-throughput mode. If some messages
+ * are fetched, the throughput mode is switched to high-throughput mode. If no messages are fetched by a poll, the
+ * throughput mode is switched back to low-throughput mode.
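+ * <p>
+ * A minimal construction sketch (the values are illustrative):
+ * <pre>{@code
+ * ThroughputBackPressureHandler throughputHandler = ThroughputBackPressureHandler.builder()
+ * 	.batchSize(10)
+ * 	.totalPermits(50)
+ * 	.build();
+ * }</pre>
+ * <p>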

+ * This {@link BackPressureHandler} is designed to be used in combination with another {@link BackPressureHandler} like + * the {@link ConcurrencyLimiterBlockingBackPressureHandler} that will handle the maximum concurrency level within the + * application in a blocking way. + * + * @see PollingMessageSource + */ +public class ThroughputBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(ThroughputBackPressureHandler.class); + + private final int batchSize; + private final int maxConcurrentMessages; + + private final AtomicReference currentThroughputMode = new AtomicReference<>( + CurrentThroughputMode.LOW); + + private final AtomicInteger inFlightRequests = new AtomicInteger(0); + + private final AtomicBoolean drained = new AtomicBoolean(false); + + private String id = getClass().getSimpleName(); + + private ThroughputBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + this.maxConcurrentMessages = builder.maxConcurrentMessages; + logger.debug("ThroughputBackPressureHandler created with batchSize {}", this.batchSize); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(this.batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + if (drained.get()) { + return 0; + } + int amountCappedAtBatchSize = Math.min(amount, this.batchSize); + int permits; + int inFlight = inFlightRequests.get(); + if (CurrentThroughputMode.LOW == this.currentThroughputMode.get()) { + // In low-throughput mode, we only acquire one batch at a time, + // so we need to limit the available permits to the batchSize - inFlight messages. + permits = Math.max(0, Math.min(amountCappedAtBatchSize, this.batchSize - inFlight)); + logger.debug("[{}] Acquired {} permits (low-throughput mode), requested: {}, in flight: {}", this.id, + permits, amount, inFlight); + } + else { + // In high-throughput mode, we can acquire more permits than the batch size, + // but we need to limit the available permits to the maxConcurrentMessages - inFlight messages. 
+ permits = Math.max(0, Math.min(amountCappedAtBatchSize, this.maxConcurrentMessages - inFlight)); + logger.debug("[{}] Acquired {} permits (high-throughput mode), requested: {}, in flight: {}", this.id, + permits, amount, inFlight); + } + inFlightRequests.addAndGet(permits); + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + if (drained.get()) { + return; + } + logger.debug("[{}] Releasing {} permits ({})", this.id, amount, reason); + inFlightRequests.addAndGet(-amount); + switch (reason) { + case NONE_FETCHED -> updateThroughputMode(CurrentThroughputMode.HIGH, CurrentThroughputMode.LOW); + case PARTIAL_FETCH -> updateThroughputMode(CurrentThroughputMode.LOW, CurrentThroughputMode.HIGH); + case LIMITED, PROCESSED -> { + // No need to switch throughput mode + } + } + } + + private void updateThroughputMode(CurrentThroughputMode currentTarget, CurrentThroughputMode newTarget) { + if (this.currentThroughputMode.compareAndSet(currentTarget, newTarget)) { + logger.debug("[{}] throughput mode updated to {}", this.id, newTarget); + } + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Draining", this.id); + drained.set(true); + return true; + } + + private enum CurrentThroughputMode { + + HIGH, + + LOW; + + } + + public static class Builder { + + private int batchSize; + private int maxConcurrentMessages; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public Builder totalPermits(int maxConcurrentMessages) { + this.maxConcurrentMessages = maxConcurrentMessages; + return this; + } + + public ThroughputBackPressureHandler build() { + Assert.notNull(this.batchSize, "Missing batchSize configuration"); + Assert.isTrue(this.batchSize > 0, "batch size must be greater than 0"); + Assert.notNull(this.maxConcurrentMessages, "Missing maxConcurrentMessages configuration"); + Assert.notNull(this.maxConcurrentMessages >= this.batchSize, + "maxConcurrentMessages must be greater than or equal to batchSize"); + return new ThroughputBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java index e71dc4319..9041cd9c8 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java @@ -17,6 +17,7 @@ import io.awspring.cloud.sqs.ConfigUtils; import io.awspring.cloud.sqs.listener.BackPressureHandler; +import io.awspring.cloud.sqs.listener.BackPressureHandler.ReleaseReason; import io.awspring.cloud.sqs.listener.BatchAwareBackPressureHandler; import io.awspring.cloud.sqs.listener.ContainerOptions; import io.awspring.cloud.sqs.listener.IdentifiableContainerComponent; @@ -214,7 +215,7 @@ private void pollAndEmitMessages() { if (!isRunning()) { logger.debug("MessageSource was stopped after permits where acquired. 
Returning {} permits", acquiredPermits); - this.backPressureHandler.release(acquiredPermits); + this.backPressureHandler.release(acquiredPermits, ReleaseReason.NONE_FETCHED); continue; } // @formatter:off @@ -252,15 +253,12 @@ private void handlePollBackOff() { protected abstract CompletableFuture> doPollForMessages(int messagesToRequest); public Collection> releaseUnusedPermits(int permits, Collection> msgs) { - if (msgs.isEmpty() && permits == this.backPressureHandler.getBatchSize()) { - this.backPressureHandler.releaseBatch(); - logger.trace("Released batch of unused permits for queue {}", this.pollingEndpointName); - } - else { - int permitsToRelease = permits - msgs.size(); - this.backPressureHandler.release(permitsToRelease); - logger.trace("Released {} unused permits for queue {}", permitsToRelease, this.pollingEndpointName); - } + int polledMessages = msgs.size(); + int permitsToRelease = permits - polledMessages; + ReleaseReason releaseReason = polledMessages == 0 ? ReleaseReason.NONE_FETCHED : ReleaseReason.PARTIAL_FETCH; + this.backPressureHandler.release(permitsToRelease, releaseReason); + logger.trace("Released {} unused ({}) permits for queue {} (messages polled {})", permitsToRelease, + releaseReason, this.pollingEndpointName, polledMessages); return msgs; } @@ -285,7 +283,7 @@ protected AcknowledgementCallback getAcknowledgementCallback() { private void releaseBackPressure() { logger.debug("Releasing permit for queue {}", this.pollingEndpointName); - this.backPressureHandler.release(1); + this.backPressureHandler.release(1, ReleaseReason.PROCESSED); } private Void handleSinkException(Throwable t) { diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java new file mode 100644 index 000000000..8fc1fec03 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java @@ -0,0 +1,528 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.awspring.cloud.sqs.integration; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.awspring.cloud.sqs.config.SqsBootstrapConfiguration; +import io.awspring.cloud.sqs.listener.*; +import io.awspring.cloud.sqs.operations.SqsTemplate; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntUnaryOperator; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +/** + * Integration tests for SQS containers back pressure management. + * + * @author Loïc Rouchon + */ +@SpringBootTest +class SqsBackPressureIntegrationTests extends BaseSqsIntegrationTest { + + private static final Logger logger = LoggerFactory.getLogger(SqsBackPressureIntegrationTests.class); + + @Autowired + SqsTemplate sqsTemplate; + + static final class NonBlockingExternalConcurrencyLimiterBackPressureHandler implements BackPressureHandler { + private final AtomicInteger limit; + private final AtomicInteger inFlight = new AtomicInteger(0); + private final AtomicBoolean draining = new AtomicBoolean(false); + + NonBlockingExternalConcurrencyLimiterBackPressureHandler(int max) { + limit = new AtomicInteger(max); + } + + public void setLimit(int value) { + logger.info("adjusting limit from {} to {}", limit.get(), value); + limit.set(value); + } + + @Override + public int request(int amount) { + if (draining.get()) { + return 0; + } + int permits = Math.max(0, Math.min(limit.get() - inFlight.get(), amount)); + inFlight.addAndGet(permits); + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + inFlight.addAndGet(-amount); + } + + @Override + public boolean drain(Duration timeout) { + Duration drainingTimeout = Duration.ofSeconds(10L); + Duration drainingPollingIntervalCheck = Duration.ofMillis(50L); + draining.set(true); + limit.set(0); + Instant start = Instant.now(); + while (Duration.between(start, Instant.now()).compareTo(drainingTimeout) < 0) { + if (inFlight.get() == 0) { + return true; + } + sleep(drainingPollingIntervalCheck.toMillis()); + } + return false; + } + } + + @ParameterizedTest + @CsvSource({ "2,2", "4,4", "5,5", "20,5" }) + void staticBackPressureLimitShouldCapQueueProcessingCapacity(int staticLimit, int expectedMaxConcurrentRequests) + throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new 
AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + staticLimit); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_STATIC_LIMIT_" + staticLimit; + IntStream.range(0, 10).forEach(index -> { + List> messages = create10Messages("staticBackPressureLimit" + staticLimit); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent 100 messages to queue {}", queueName); + var latch = new CountDownLatch(100); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(5).maxConcurrentMessages(5) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> BackPressureHandlerFactory + .compositeBackPressureHandler(containerOptions, Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactory + .concurrencyLimiterBackPressureHandler(containerOptions))))) + .messageListener(msg -> { + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + sleep(50L); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + latch.countDown(); + concurrentRequest.decrementAndGet(); + }).build(); + container.start(); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(maxConcurrentRequest.get()).isEqualTo(expectedMaxConcurrentRequests); + container.stop(); + } + + @Test + void zeroBackPressureLimitShouldStopQueueProcessing() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 0); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_STATIC_LIMIT_0"; + IntStream.range(0, 10).forEach(index -> { + List> messages = create10Messages("staticBackPressureLimit0"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent 100 messages to queue {}", queueName); + var latch = new CountDownLatch(100); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(5).maxConcurrentMessages(5) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> BackPressureHandlerFactory + .compositeBackPressureHandler(containerOptions, Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactory + .concurrencyLimiterBackPressureHandler(containerOptions))))) + .messageListener(msg -> { + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + sleep(50L); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + latch.countDown(); + concurrentRequest.decrementAndGet(); + }).build(); + container.start(); + assertThat(latch.await(2, TimeUnit.SECONDS)).isFalse(); + assertThat(maxConcurrentRequest.get()).isZero(); + 
assertThat(latch.getCount()).isEqualTo(100L); + container.stop(); + } + + @Test + void changeInBackPressureLimitShouldAdaptQueueProcessingCapacity() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 5); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_SYNC_ADAPTIVE_LIMIT"; + int nbMessages = 280; + IntStream.range(0, nbMessages / 10).forEach(index -> { + List> messages = create10Messages("syncAdaptiveBackPressureLimit"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent {} messages to queue {}", nbMessages, queueName); + var latch = new CountDownLatch(nbMessages); + var controlSemaphore = new Semaphore(0); + var advanceSemaphore = new Semaphore(0); + var processingFailed = new AtomicBoolean(false); + var isDraining = new AtomicBoolean(false); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(5).maxConcurrentMessages(5) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> BackPressureHandlerFactory + .compositeBackPressureHandler(containerOptions, Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactory + .concurrencyLimiterBackPressureHandler(containerOptions))))) + .messageListener(msg -> { + try { + if (!controlSemaphore.tryAcquire(5, TimeUnit.SECONDS) && !isDraining.get()) { + processingFailed.set(true); + throw new IllegalStateException("Failed to wait for control semaphore"); + } + } + catch (InterruptedException e) { + if (!isDraining.get()) { + processingFailed.set(true); + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + latch.countDown(); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + sleep(10L); + concurrentRequest.decrementAndGet(); + advanceSemaphore.release(); + }).build(); + class Controller { + private final Semaphore advanceSemaphore; + private final Semaphore controlSemaphore; + private final NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter; + private final AtomicInteger maxConcurrentRequest; + private final AtomicBoolean processingFailed; + + Controller(Semaphore advanceSemaphore, Semaphore controlSemaphore, + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter, + AtomicInteger maxConcurrentRequest, AtomicBoolean processingFailed) { + this.advanceSemaphore = advanceSemaphore; + this.controlSemaphore = controlSemaphore; + this.limiter = limiter; + this.maxConcurrentRequest = maxConcurrentRequest; + this.processingFailed = processingFailed; + } + + public void updateLimit(int newLimit) { + limiter.setLimit(newLimit); + } + + void updateLimitAndWaitForReset(int newLimit) throws InterruptedException { + updateLimit(newLimit); + int atLeastTwoPollingCycles = 2 * 5; + controlSemaphore.release(atLeastTwoPollingCycles); + waitForAdvance(atLeastTwoPollingCycles); + maxConcurrentRequest.set(0); + } + + void advance(int permits) { + controlSemaphore.release(permits); + } + + 
void waitForAdvance(int permits) throws InterruptedException { + assertThat(advanceSemaphore.tryAcquire(permits, 5, TimeUnit.SECONDS)) + .withFailMessage(() -> "Waiting for %d permits timed out. Only %d permits available" + .formatted(permits, advanceSemaphore.availablePermits())) + .isTrue(); + assertThat(processingFailed.get()).isFalse(); + } + } + var controller = new Controller(advanceSemaphore, controlSemaphore, limiter, maxConcurrentRequest, + processingFailed); + try { + container.start(); + + controller.advance(50); + controller.waitForAdvance(50); + // not limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(5); + controller.updateLimitAndWaitForReset(2); + controller.advance(50); + + controller.waitForAdvance(50); + // limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(2); + controller.updateLimitAndWaitForReset(7); + controller.advance(50); + + controller.waitForAdvance(50); + // not limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(5); + controller.updateLimitAndWaitForReset(3); + controller.advance(50); + sleep(10L); + limiter.setLimit(1); + sleep(10L); + limiter.setLimit(2); + sleep(10L); + limiter.setLimit(3); + + controller.waitForAdvance(50); + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(3); + // stopping processing of the queue + controller.updateLimit(0); + controller.advance(50); + assertThat(advanceSemaphore.tryAcquire(10, 5, TimeUnit.SECONDS)) + .withFailMessage("Acquiring semaphore should have timed out as limit was set to 0").isFalse(); + + // resume queue processing + controller.updateLimit(6); + + controller.waitForAdvance(50); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(5); + assertThat(processingFailed.get()).isFalse(); + } + finally { + isDraining.set(true); + container.stop(); + } + } + + static class EventsCsvWriter { + private final Queue events = new ConcurrentLinkedQueue<>(List.of("event,time,value")); + + void registerEvent(String event, int value) { + events.add("%s,%s,%d".formatted(event, Instant.now(), value)); + } + + void write(Path path) throws Exception { + Files.writeString(path, String.join("\n", events), StandardCharsets.UTF_8, StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING); + } + } + + static class StatisticsBphDecorator implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + private final BatchAwareBackPressureHandler delegate; + private final EventsCsvWriter eventCsv; + private String id; + + StatisticsBphDecorator(BatchAwareBackPressureHandler delegate, EventsCsvWriter eventsCsvWriter) { + this.delegate = delegate; + this.eventCsv = eventsCsvWriter; + } + + @Override + public int requestBatch() throws InterruptedException { + int permits = delegate.requestBatch(); + if (permits > 0) { + eventCsv.registerEvent("obtained_permits", permits); + } + return permits; + } + + @Override + public int request(int amount) throws InterruptedException { + int permits = delegate.request(amount); + if (permits > 0) { + eventCsv.registerEvent("obtained_permits", permits); + } + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + if (amount > 0) { + eventCsv.registerEvent("release_" + reason, amount); + } + delegate.release(amount, reason); + } + + @Override + public boolean drain(Duration timeout) { + eventCsv.registerEvent("drain", 1); + return 
delegate.drain(timeout); + } + + @Override + public void setId(String id) { + this.id = id; + if (delegate instanceof IdentifiableContainerComponent icc) { + icc.setId("delegate-" + id); + } + } + + @Override + public String getId() { + return id; + } + } + + /** + * This test simulates a progressive change in the back pressure limit. Unlike + * {@link #changeInBackPressureLimitShouldAdaptQueueProcessingCapacity()}, this test does not block message + * consumption while updating the limit. + *

+ * The limit is updated in a loop until all messages are consumed. The update follows a triangle wave pattern with a + * minimum of 0, a maximum of 15, and a period of 30 iterations. After each update of the limit, the test waits up + * to 20ms and samples the maximum number of concurrent messages that were processed since the update. This number + * can be higher than the defined limit during the adaptation period of the decreasing limit wave. For the + * increasing limit wave, it is usually lower due to the adaptation delay. In both cases, the maximum number of + * concurrent messages being processed rapidly converges toward the defined limit. + *

+ * The test passes if the sum of the sampled maximum number of concurrently processed messages is lower than the sum + * of the limits at those points in time. + */ + @Test + void unsynchronizedChangesInBackPressureLimitShouldAdaptQueueProcessingCapacity() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 0); + String queueName = "REACTIVE_BACK_PRESSURE_LIMITER_QUEUE_NAME_ADAPTIVE_LIMIT"; + int nbMessages = 1000; + Semaphore advanceSemaphore = new Semaphore(0); + IntStream.range(0, nbMessages / 10).forEach(index -> { + List> messages = create10Messages("reactAdaptiveBackPressureLimit"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent {} messages to queue {}", nbMessages, queueName); + var latch = new CountDownLatch(nbMessages); + EventsCsvWriter eventsCsvWriter = new EventsCsvWriter(); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure( + options -> options.maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofSeconds(1)) + .pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerFactory(containerOptions -> new StatisticsBphDecorator( + BackPressureHandlerFactory.compositeBackPressureHandler(containerOptions, + Duration.ofMillis(50L), + List.of(limiter, BackPressureHandlerFactory + .concurrencyLimiterBackPressureHandler(containerOptions))), + eventsCsvWriter))) + .messageListener(msg -> { + int currentConcurrentRq = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, currentConcurrentRq)); + sleep(ThreadLocalRandom.current().nextInt(10)); + latch.countDown(); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + concurrentRequest.decrementAndGet(); + advanceSemaphore.release(); + }).build(); + IntUnaryOperator progressiveLimitChange = (int x) -> { + int period = 30; + int halfPeriod = period / 2; + if (x % period < halfPeriod) { + return (x % halfPeriod); + } + else { + return (halfPeriod - (x % halfPeriod)); + } + }; + try { + container.start(); + Random random = new Random(); + int limitsSum = 0; + int maxConcurrentRqSum = 0; + int changeLimitCount = 0; + while (latch.getCount() > 0 && changeLimitCount < nbMessages) { + changeLimitCount++; + int limit = progressiveLimitChange.applyAsInt(changeLimitCount); + int expectedMax = Math.min(10, limit); + limiter.setLimit(limit); + maxConcurrentRequest.set(0); + sleep(random.nextInt(20)); + int actualLimit = Math.min(10, limit); + int max = maxConcurrentRequest.get(); + if (max > 0) { + // Ignore iterations where nothing was polled (messages consumption slower than iteration) + limitsSum += actualLimit; + maxConcurrentRqSum += max; + } + eventsCsvWriter.registerEvent("max_concurrent_rq", max); + eventsCsvWriter.registerEvent("concurrent_rq", concurrentRequest.get()); + eventsCsvWriter.registerEvent("limit", limit); + eventsCsvWriter.registerEvent("in_flight", limiter.inFlight.get()); + eventsCsvWriter.registerEvent("expected_max", expectedMax); + eventsCsvWriter.registerEvent("max_minus_expected_max", max - expectedMax); + } + eventsCsvWriter.write(Path.of("target/stats-%s.csv".formatted(queueName))); 
+ assertThat(maxConcurrentRqSum).isLessThanOrEqualTo(limitsSum); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + } + finally { + container.stop(); + } + } + + private static void sleep(long millis) { + try { + Thread.sleep(millis); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private List> create10Messages(String testName) { + return IntStream.range(0, 10).mapToObj(index -> testName + "-payload-" + index) + .map(payload -> MessageBuilder.withPayload(payload).build()).collect(Collectors.toList()); + } + + @Import(SqsBootstrapConfiguration.class) + @Configuration + static class SQSConfiguration { + + @Bean + SqsTemplate sqsTemplate() { + return SqsTemplate.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()).build(); + } + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java index 50bded839..76a7a65f7 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java @@ -269,6 +269,7 @@ void manuallyCreatesInactiveContainer() throws Exception { logger.debug("Sent message to queue {} with messageBody {}", MANUALLY_CREATE_INACTIVE_CONTAINER_QUEUE_NAME, messageBody); assertThat(latchContainer.manuallyInactiveCreatedContainerLatch.await(10, TimeUnit.SECONDS)).isTrue(); + inactiveMessageListenerContainer.stop(); } // @formatter:off diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java index b03b308c6..df3b5a1bc 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java @@ -23,27 +23,17 @@ import static org.mockito.Mockito.times; import io.awspring.cloud.sqs.MessageExecutionThreadFactory; -import io.awspring.cloud.sqs.listener.BackPressureMode; -import io.awspring.cloud.sqs.listener.SemaphoreBackPressureHandler; -import io.awspring.cloud.sqs.listener.SqsContainerOptions; +import io.awspring.cloud.sqs.listener.*; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementCallback; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor; import io.awspring.cloud.sqs.support.converter.MessageConversionContext; import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter; import java.time.Duration; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Semaphore; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.TimeUnit; +import java.util.*; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import org.assertj.core.api.InstanceOfAssertFactories; import org.awaitility.Awaitility; @@ 
-68,10 +58,12 @@ class AbstractPollingMessageSourceTests { @Test void shouldAcquireAndReleaseFullPermits() { String testName = "shouldAcquireAndReleaseFullPermits"; + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactory + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(200)).batchSize(10).totalPermits(10) - .throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES).build(); ExecutorService threadPool = Executors.newCachedThreadPool(); CountDownLatch pollingCounter = new CountDownLatch(3); CountDownLatch processingCounter = new CountDownLatch(1); @@ -80,8 +72,6 @@ void shouldAcquireAndReleaseFullPermits() { private final AtomicBoolean hasReceived = new AtomicBoolean(false); - private final AtomicBoolean hasMadeSecondPoll = new AtomicBoolean(false); - @Override protected CompletableFuture> doPollForMessages(int messagesToRequest) { return CompletableFuture.supplyAsync(() -> { @@ -90,21 +80,6 @@ protected CompletableFuture> doPollForMessages(int messagesT assertThat(messagesToRequest).isEqualTo(10); assertAvailablePermits(backPressureHandler, 0); boolean firstPoll = hasReceived.compareAndSet(false, true); - if (firstPoll) { - logger.debug("First poll"); - // No permits released yet, should be TM low - assertThroughputMode(backPressureHandler, "low"); - } - else if (hasMadeSecondPoll.compareAndSet(false, true)) { - logger.debug("Second poll"); - // Permits returned, should be high - assertThroughputMode(backPressureHandler, "high"); - } - else { - logger.debug("Third poll"); - // Already returned full permits, should be low - assertThroughputMode(backPressureHandler, "low"); - } return firstPoll ? 
(Collection) List.of(Message.builder() .messageId(UUID.randomUUID().toString()).body("message").build()) @@ -130,7 +105,7 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) { }); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().build()); + source.configure(options); source.setTaskExecutor(createTaskExecutor(testName)); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); source.start(); @@ -138,14 +113,101 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) { assertThat(doAwait(processingCounter)).isTrue(); } + @Test + void shouldAdaptThroughputMode() { + String testName = "shouldAdaptThroughputMode"; + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(150)).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactory + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); + + ExecutorService threadPool = Executors.newCachedThreadPool(); + CountDownLatch pollingCounter = new CountDownLatch(3); + CountDownLatch processingCounter = new CountDownLatch(1); + Collection errors = new ConcurrentLinkedQueue<>(); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + private final AtomicInteger pollAttemptCounter = new AtomicInteger(0); + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + return CompletableFuture.supplyAsync(() -> { + try { + int pollAttempt = pollAttemptCounter.incrementAndGet(); + logger.warn("Poll attempt {}", pollAttempt); + if (pollAttempt == 1) { + // Initial poll; throughput mode should be low + assertThroughputMode(backPressureHandler, "low"); + // Since no permits were acquired yet, should be 10 + assertThat(messagesToRequest).isEqualTo(10); + return (Collection) List.of( + Message.builder().messageId(UUID.randomUUID().toString()).body("message").build()); + } + else if (pollAttempt == 2) { + // Messages returned in the previous poll; throughput mode should be high + assertThroughputMode(backPressureHandler, "high"); + // Since throughput mode is high, should be 10 + assertThat(messagesToRequest).isEqualTo(10); + return Collections. emptyList(); + } + else { + // No Messages returned in the previous poll; throughput mode should be low + assertThroughputMode(backPressureHandler, "low"); + return Collections. 
emptyList(); + } + } + catch (Throwable t) { + logger.error("Error (not expecting it)", t); + errors.add(t); + throw new RuntimeException(t); + } + }, threadPool).whenComplete((v, t) -> { + if (t == null) { + logger.warn("Polling succeeded", t); + pollingCounter.countDown(); + } + else { + logger.warn("Polling failed with error", t); + errors.add(t); + } + }); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.runAsync(processingCounter::countDown); + }); + + source.setId(testName + " source"); + source.configure(options); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + try { + source.start(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(doAwait(processingCounter)).isTrue(); + assertThat(errors).isEmpty(); + } + finally { + source.stop(); + threadPool.shutdownNow(); + } + } + private static final AtomicInteger testCounter = new AtomicInteger(); @Test void shouldAcquireAndReleasePartialPermits() { String testName = "shouldAcquireAndReleasePartialPermits"; - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) - .throughputConfiguration(BackPressureMode.AUTO).build(); + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofMillis(150)).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactory + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(200L)); + ExecutorService threadPool = Executors .newCachedThreadPool(new MessageExecutionThreadFactory("test " + testCounter.incrementAndGet())); CountDownLatch pollingCounter = new CountDownLatch(4); @@ -155,60 +217,34 @@ void shouldAcquireAndReleasePartialPermits() { AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { - private final AtomicBoolean hasReceived = new AtomicBoolean(false); - - private final AtomicBoolean hasAcquired9 = new AtomicBoolean(false); - - private final AtomicBoolean hasMadeThirdPoll = new AtomicBoolean(false); + private final AtomicInteger pollAttemptCounter = new AtomicInteger(0); @Override protected CompletableFuture> doPollForMessages(int messagesToRequest) { return CompletableFuture.supplyAsync(() -> { try { - // Give it some time between returning empty and polling again - // doSleep(100); - - // Will only be true the first time it sets hasReceived to true - boolean shouldReturnMessage = hasReceived.compareAndSet(false, true); - if (shouldReturnMessage) { + int pollAttempt = pollAttemptCounter.incrementAndGet(); + if (pollAttempt == 1) { // First poll, should have 10 logger.debug("First poll - should request 10 messages"); assertThat(messagesToRequest).isEqualTo(10); - assertAvailablePermits(backPressureHandler, 0); - // No permits have been released yet - assertThroughputMode(backPressureHandler, "low"); + Message message = Message.builder().messageId(UUID.randomUUID().toString()).body("message") + .build(); + return (Collection) List.of(message); } - else if (hasAcquired9.compareAndSet(false, true)) { + else if (pollAttempt == 2) { // Second poll, should have 9 logger.debug("Second poll - should request 9 messages"); assertThat(messagesToRequest).isEqualTo(9); - 
assertAvailablePermitsLessThanOrEqualTo(backPressureHandler, 1); - // Has released 9 permits, should be TM HIGH - assertThroughputMode(backPressureHandler, "high"); processingLatch.countDown(); // Release processing now + return Collections. emptyList(); } else { - boolean thirdPoll = hasMadeThirdPoll.compareAndSet(false, true); // Third poll or later, should have 10 again - logger.debug("Third poll - should request 10 messages"); + logger.debug("Third (or later) poll - should request 10 messages"); assertThat(messagesToRequest).isEqualTo(10); - assertAvailablePermits(backPressureHandler, 0); - if (thirdPoll) { - // Hasn't yet returned a full batch, should be TM High - assertThroughputMode(backPressureHandler, "high"); - } - else { - // Has returned all permits in third poll - assertThroughputMode(backPressureHandler, "low"); - } - } - if (shouldReturnMessage) { - logger.debug("shouldReturnMessage, returning one message"); - return (Collection) List.of( - Message.builder().messageId(UUID.randomUUID().toString()).body("message").build()); + return Collections. emptyList(); } - logger.debug("should not return message, returning empty list"); - return Collections. emptyList(); } catch (Error e) { hasThrownError.set(true); @@ -228,7 +264,7 @@ else if (hasAcquired9.compareAndSet(false, true)) { return CompletableFuture.completedFuture(null).thenRun(processingCounter::countDown); }); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().build()); + source.configure(options); source.setTaskExecutor(createTaskExecutor(testName)); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); source.start(); @@ -236,19 +272,14 @@ else if (hasAcquired9.compareAndSet(false, true)) { assertThat(doAwait(pollingCounter)).isTrue(); source.stop(); assertThat(hasThrownError.get()).isFalse(); + threadPool.shutdownNow(); } @Test void shouldReleasePermitsOnConversionErrors() { String testName = "shouldReleasePermitsOnConversionErrors"; - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) - .throughputConfiguration(BackPressureMode.AUTO).build(); AtomicInteger convertedMessages = new AtomicInteger(0); - AtomicInteger messagesInSink = new AtomicInteger(0); - AtomicBoolean hasFailed = new AtomicBoolean(false); - var converter = new SqsMessagingMessageConverter() { @Override public org.springframework.messaging.Message toMessagingMessage(Message source, @@ -262,6 +293,15 @@ public org.springframework.messaging.Message toMessagingMessage(Message sourc } }; + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(150)).messageConverter(converter).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactory + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); + + AtomicInteger messagesInSink = new AtomicInteger(0); + AtomicBoolean hasFailed = new AtomicBoolean(false); + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { @Override @@ -288,7 +328,7 @@ private Collection create10Messages() { return CompletableFuture.completedFuture(null); }); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().messageConverter(converter).build()); + source.configure(options); 
source.setPollingEndpointName("shouldReleasePermitsOnConversionErrors-queue"); source.setTaskExecutor(createTaskExecutor(testName)); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); @@ -301,12 +341,17 @@ private Collection create10Messages() { @Test void shouldBackOffIfPollingThrowsAnError() { - var testName = "shouldBackOffIfPollingThrowsAnError"; - var backPressureHandler = SemaphoreBackPressureHandler.builder().acquireTimeout(Duration.ofMillis(200)) - .batchSize(10).totalPermits(40).throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) - .build(); + var policy = mock(BackOffPolicy.class); + var backOffContext = mock(BackOffContext.class); + given(policy.start(null)).willReturn(backOffContext); + SqsContainerOptions options = SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(40) + .backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).pollBackOffPolicy(policy).build(); + BackPressureHandler backPressureHandler = BackPressureHandlerFactory + .adaptativeThroughputBackPressureHandler(options, Duration.ofMillis(100L)); + var currentPoll = new AtomicInteger(0); var waitThirdPollLatch = new CountDownLatch(4); @@ -333,14 +378,10 @@ else if (currentPoll.compareAndSet(2, 3)) { } }; - var policy = mock(BackOffPolicy.class); - var backOffContext = mock(BackOffContext.class); - given(policy.start(null)).willReturn(backOffContext); - source.setBackPressureHandler(backPressureHandler); source.setMessageSink((msgs, context) -> CompletableFuture.completedFuture(null)); source.setId(testName + " source"); - source.configure(SqsContainerOptions.builder().pollBackOffPolicy(policy).build()); + source.configure(options); source.setTaskExecutor(createTaskExecutor(testName)); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); @@ -363,24 +404,45 @@ private static boolean doAwait(CountDownLatch processingLatch) { } } - private void assertThroughputMode(SemaphoreBackPressureHandler backPressureHandler, String expectedThroughputMode) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "currentThroughputMode")) - .extracting(Object::toString).extracting(String::toLowerCase) + private void assertThroughputMode(BackPressureHandler backPressureHandler, String expectedThroughputMode) { + var bph = extractBackPressureHandler(backPressureHandler, ThroughputBackPressureHandler.class); + assertThat(getThroughputModeValue(bph, "currentThroughputMode")) .isEqualTo(expectedThroughputMode.toLowerCase()); } - private void assertAvailablePermits(SemaphoreBackPressureHandler backPressureHandler, int expectedPermits) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + private static String getThroughputModeValue(ThroughputBackPressureHandler bph, String targetThroughputMode) { + return ((AtomicReference) ReflectionTestUtils.getField(bph, targetThroughputMode)).get().toString() + .toLowerCase(Locale.ROOT); + } + + private void assertAvailablePermits(BackPressureHandler backPressureHandler, int expectedPermits) { + var bph = extractBackPressureHandler(backPressureHandler, ConcurrencyLimiterBlockingBackPressureHandler.class); + assertThat(ReflectionTestUtils.getField(bph, "semaphore")).asInstanceOf(type(Semaphore.class)) .extracting(Semaphore::availablePermits).isEqualTo(expectedPermits); } - private void assertAvailablePermitsLessThanOrEqualTo(SemaphoreBackPressureHandler backPressureHandler, + private void 
assertAvailablePermitsLessThanOrEqualTo(BackPressureHandler backPressureHandler, int maxExpectedPermits) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + var bph = extractBackPressureHandler(backPressureHandler, ConcurrencyLimiterBlockingBackPressureHandler.class); + assertThat(ReflectionTestUtils.getField(bph, "semaphore")).asInstanceOf(type(Semaphore.class)) .extracting(Semaphore::availablePermits).asInstanceOf(InstanceOfAssertFactories.INTEGER) .isLessThanOrEqualTo(maxExpectedPermits); } + private T extractBackPressureHandler(BackPressureHandler bph, Class type) { + if (type.isInstance(bph)) { + return type.cast(bph); + } + if (bph instanceof CompositeBackPressureHandler cbph) { + List backPressureHandlers = (List) ReflectionTestUtils + .getField(cbph, "backPressureHandlers"); + return extractBackPressureHandler( + backPressureHandlers.stream().filter(type::isInstance).map(type::cast).findFirst().orElseThrow(), + type); + } + throw new NoSuchElementException("%s not found in %s".formatted(type.getSimpleName(), bph)); + } + // Used to slow down tests while developing private void doSleep(int time) { try { diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests.java new file mode 100644 index 000000000..94cb76959 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests.java @@ -0,0 +1,445 @@ +/* + * Copyright 2013-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.awspring.cloud.sqs.listener.source; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.then; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; + +import io.awspring.cloud.sqs.MessageExecutionThreadFactory; +import io.awspring.cloud.sqs.listener.*; +import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementCallback; +import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor; +import io.awspring.cloud.sqs.support.converter.MessageConversionContext; +import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.task.TaskExecutor; +import org.springframework.lang.Nullable; +import org.springframework.retry.backoff.BackOffContext; +import org.springframework.retry.backoff.BackOffPolicy; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.test.util.ReflectionTestUtils; +import software.amazon.awssdk.services.sqs.model.Message; + +/** + * @author Tomaz Fernandes + */ +class SemaphoreBackPressureHandlerAbstractPollingMessageSourceTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractPollingMessageSourceTests.class); + + @Test + void shouldAcquireAndReleaseFullPermits() { + String testName = "shouldAcquireAndReleaseFullPermits"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactory + .semaphoreBackPressureHandler(SqsContainerOptions.builder().maxMessagesPerPoll(10) + .maxConcurrentMessages(10).backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).build()); + + ExecutorService threadPool = Executors.newCachedThreadPool(); + CountDownLatch pollingCounter = new CountDownLatch(3); + CountDownLatch processingCounter = new CountDownLatch(1); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + private final AtomicBoolean hasReceived = new AtomicBoolean(false); + + private final AtomicBoolean hasMadeSecondPoll = new AtomicBoolean(false); + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + return CompletableFuture.supplyAsync(() -> { + try { + // Since BackPressureMode.ALWAYS_POLL_MAX_MESSAGES, should always be 10. 
+ assertThat(messagesToRequest).isEqualTo(10); + assertAvailablePermits(backPressureHandler, 0); + boolean firstPoll = hasReceived.compareAndSet(false, true); + if (firstPoll) { + logger.debug("First poll"); + // No permits released yet, should be TM low + assertThroughputMode(backPressureHandler, "low"); + } + else if (hasMadeSecondPoll.compareAndSet(false, true)) { + logger.debug("Second poll"); + // Permits returned, should be high + assertThroughputMode(backPressureHandler, "high"); + } + else { + logger.debug("Third poll"); + // Already returned full permits, should be low + assertThroughputMode(backPressureHandler, "low"); + } + return firstPoll + ? (Collection) List.of(Message.builder() + .messageId(UUID.randomUUID().toString()).body("message").build()) + : Collections. emptyList(); + } + catch (Throwable t) { + logger.error("Error", t); + throw new RuntimeException(t); + } + }, threadPool).whenComplete((v, t) -> { + if (t == null) { + pollingCounter.countDown(); + } + }); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + assertAvailablePermits(backPressureHandler, 9); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.runAsync(processingCounter::countDown); + }); + + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().build()); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(doAwait(processingCounter)).isTrue(); + } + + private static final AtomicInteger testCounter = new AtomicInteger(); + + @Test + void shouldAcquireAndReleasePartialPermits() { + String testName = "shouldAcquireAndReleasePartialPermits"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactory.semaphoreBackPressureHandler( + SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofMillis(150)).build()); + + ExecutorService threadPool = Executors + .newCachedThreadPool(new MessageExecutionThreadFactory("test " + testCounter.incrementAndGet())); + CountDownLatch pollingCounter = new CountDownLatch(4); + CountDownLatch processingCounter = new CountDownLatch(1); + CountDownLatch processingLatch = new CountDownLatch(1); + AtomicBoolean hasThrownError = new AtomicBoolean(false); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + private final AtomicBoolean hasReceived = new AtomicBoolean(false); + + private final AtomicBoolean hasAcquired9 = new AtomicBoolean(false); + + private final AtomicBoolean hasMadeThirdPoll = new AtomicBoolean(false); + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + return CompletableFuture.supplyAsync(() -> { + try { + // Give it some time between returning empty and polling again + // doSleep(100); + + // Will only be true the first time it sets hasReceived to true + boolean shouldReturnMessage = hasReceived.compareAndSet(false, true); + if (shouldReturnMessage) { + // First poll, should have 10 + logger.debug("First poll - should request 10 messages"); + assertThat(messagesToRequest).isEqualTo(10); + assertAvailablePermits(backPressureHandler, 0); + // No permits have been released yet + assertThroughputMode(backPressureHandler, "low"); + } + else if (hasAcquired9.compareAndSet(false, true)) { + // Second poll, should 
have 9 + logger.debug("Second poll - should request 9 messages"); + assertThat(messagesToRequest).isEqualTo(9); + assertAvailablePermitsLessThanOrEqualTo(backPressureHandler, 1); + // Has released 9 permits, should be TM HIGH + assertThroughputMode(backPressureHandler, "high"); + processingLatch.countDown(); // Release processing now + } + else { + boolean thirdPoll = hasMadeThirdPoll.compareAndSet(false, true); + // Third poll or later, should have 10 again + logger.debug("Third poll - should request 10 messages"); + assertThat(messagesToRequest).isEqualTo(10); + assertAvailablePermits(backPressureHandler, 0); + if (thirdPoll) { + // Hasn't yet returned a full batch, should be TM High + assertThroughputMode(backPressureHandler, "high"); + } + else { + // Has returned all permits in third poll + assertThroughputMode(backPressureHandler, "low"); + } + } + if (shouldReturnMessage) { + logger.debug("shouldReturnMessage, returning one message"); + return (Collection) List.of( + Message.builder().messageId(UUID.randomUUID().toString()).body("message").build()); + } + logger.debug("should not return message, returning empty list"); + return Collections. emptyList(); + } + catch (Error e) { + hasThrownError.set(true); + throw new RuntimeException("Error polling for messages", e); + } + }, threadPool).whenComplete((v, t) -> pollingCounter.countDown()); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + logger.debug("Processing {} messages", msgs.size()); + assertAvailablePermits(backPressureHandler, 9); + assertThat(doAwait(processingLatch)).isTrue(); + logger.debug("Finished processing {} messages", msgs.size()); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.completedFuture(null).thenRun(processingCounter::countDown); + }); + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().build()); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + assertThat(doAwait(processingCounter)).isTrue(); + assertThat(doAwait(pollingCounter)).isTrue(); + source.stop(); + assertThat(hasThrownError.get()).isFalse(); + } + + @Test + void shouldReleasePermitsOnConversionErrors() { + String testName = "shouldReleasePermitsOnConversionErrors"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactory.semaphoreBackPressureHandler( + SqsContainerOptions.builder().maxMessagesPerPoll(10).maxConcurrentMessages(10) + .backPressureMode(BackPressureMode.AUTO).maxDelayBetweenPolls(Duration.ofMillis(150)).build()); + + AtomicInteger convertedMessages = new AtomicInteger(0); + AtomicInteger messagesInSink = new AtomicInteger(0); + AtomicBoolean hasFailed = new AtomicBoolean(false); + + var converter = new SqsMessagingMessageConverter() { + @Override + public org.springframework.messaging.Message toMessagingMessage(Message source, + @Nullable MessageConversionContext context) { + var converted = convertedMessages.incrementAndGet(); + logger.trace("Messages converted: {}", converted); + if (converted % 9 == 0) { + throw new RuntimeException("Expected error"); + } + return super.toMessagingMessage(source, context); + } + }; + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + if (messagesToRequest != 10) { + logger.error("Expected 10 messages to requesst, received {}", 
messagesToRequest); + hasFailed.set(true); + } + return convertedMessages.get() < 30 ? CompletableFuture.completedFuture(create10Messages()) + : CompletableFuture.completedFuture(List.of()); + } + + private Collection create10Messages() { + return IntStream.range(0, 10).mapToObj( + index -> Message.builder().messageId(UUID.randomUUID().toString()).body("test-message").build()) + .toList(); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + msgs.forEach(message -> messagesInSink.incrementAndGet()); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.completedFuture(null); + }); + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().messageConverter(converter).build()); + source.setPollingEndpointName("shouldReleasePermitsOnConversionErrors-queue"); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + Awaitility.waitAtMost(Duration.ofSeconds(10)).until(() -> convertedMessages.get() == 30); + assertThat(hasFailed).isFalse(); + assertThat(messagesInSink).hasValue(27); + source.stop(); + } + + @Test + void shouldBackOffIfPollingThrowsAnError() { + var testName = "shouldBackOffIfPollingThrowsAnError"; + BackPressureHandler backPressureHandler = BackPressureHandlerFactory + .semaphoreBackPressureHandler(SqsContainerOptions.builder().maxMessagesPerPoll(10) + .maxConcurrentMessages(40).backPressureMode(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .maxDelayBetweenPolls(Duration.ofMillis(200)).build()); + + var currentPoll = new AtomicInteger(0); + var waitThirdPollLatch = new CountDownLatch(4); + + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + waitThirdPollLatch.countDown(); + if (currentPoll.compareAndSet(0, 1)) { + logger.debug("First poll - returning empty list"); + return CompletableFuture.completedFuture(List.of()); + } + else if (currentPoll.compareAndSet(1, 2)) { + logger.debug("Second poll - returning error"); + return CompletableFuture.failedFuture(new RuntimeException("Expected exception on second poll")); + } + else if (currentPoll.compareAndSet(2, 3)) { + logger.debug("Third poll - returning error"); + return CompletableFuture.failedFuture(new RuntimeException("Expected exception on third poll")); + } + else { + logger.debug("Fourth poll - returning empty list"); + return CompletableFuture.completedFuture(List.of()); + } + } + }; + + var policy = mock(BackOffPolicy.class); + var backOffContext = mock(BackOffContext.class); + given(policy.start(null)).willReturn(backOffContext); + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> CompletableFuture.completedFuture(null)); + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().pollBackOffPolicy(policy).build()); + + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + + doAwait(waitThirdPollLatch); + + then(policy).should().start(null); + then(policy).should(times(2)).backOff(backOffContext); + + } + + private static boolean doAwait(CountDownLatch processingLatch) { + try { + return processingLatch.await(4, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new 
RuntimeException("Interrupted while waiting for latch", e); + } + } + + private void assertThroughputMode(BackPressureHandler backPressureHandler, String expectedThroughputMode) { + assertThat(ReflectionTestUtils.getField(backPressureHandler, "currentThroughputMode")) + .extracting(Object::toString).extracting(String::toLowerCase) + .isEqualTo(expectedThroughputMode.toLowerCase()); + } + + private void assertAvailablePermits(BackPressureHandler backPressureHandler, int expectedPermits) { + assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + .extracting(Semaphore::availablePermits).isEqualTo(expectedPermits); + } + + private void assertAvailablePermitsLessThanOrEqualTo(BackPressureHandler backPressureHandler, + int maxExpectedPermits) { + assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + .extracting(Semaphore::availablePermits).asInstanceOf(InstanceOfAssertFactories.INTEGER) + .isLessThanOrEqualTo(maxExpectedPermits); + } + + // Used to slow down tests while developing + private void doSleep(int time) { + try { + Thread.sleep(time); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + protected TaskExecutor createTaskExecutor(String testName) { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + int poolSize = 10; + executor.setMaxPoolSize(poolSize); + executor.setCorePoolSize(10); + executor.setQueueCapacity(poolSize); + executor.setAllowCoreThreadTimeOut(true); + executor.setThreadFactory(createThreadFactory(testName)); + executor.afterPropertiesSet(); + return executor; + } + + protected ThreadFactory createThreadFactory(String testName) { + MessageExecutionThreadFactory threadFactory = new MessageExecutionThreadFactory(); + threadFactory.setThreadNamePrefix(testName + "-thread" + "-"); + return threadFactory; + } + + private AcknowledgementProcessor getNoOpsAcknowledgementProcessor() { + return new AcknowledgementProcessor<>() { + @Override + public AcknowledgementCallback getAcknowledgementCallback() { + return new AcknowledgementCallback<>() { + }; + } + + @Override + public void setId(String id) { + } + + @Override + public String getId() { + return "test processor"; + } + + @Override + public void start() { + } + + @Override + public void stop() { + } + + @Override + public boolean isRunning() { + return false; + } + }; + } + +}