Add Semaphore Support for cp.async
loads (Non-TMA Load Patterns)
#97
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR introduces semaphore support for non-TMA
load_async
operations by leveraging the PTX instructioncp.async.mbarrier.arrive.noinc.shared::cta.b64
. The change aims to simplify producer-consumer kernels with non-standard load patterns that cannot be completed by the TMA.Background and Motivation
Working with @DanFu09, I developed sparse matmul kernels that required using
cp.async
instead of TMA because of our unique memory layout. Currently, producer-consumer kernels force the producer to callcp.async.wait_all
and manually signal the semaphore (e.g. FFTConv kernel). Our tests show that manually waiting on a semaphore withcp.async.wait_all
plus an explicitarrive(bar)
is over 200 TFLOPS slower than allowingcp.async
to automatically signal the semaphore.Note on Semaphores:
The PTX instruction
cp.async.mbarrier.arrive.noinc.shared::cta.b64
ensures that once all non-committedcp.async
operations from the current thread finish, that thread automatically arrives at the semaphore. Until then, it can work on other tasks. For example, whenwarpgroup::load_async
is called with a semaphore, the expected arrival count is128
(32 threads per warp * 4 warps). Detailed explanations are provided in the updated library comments.What's New
load_async
operations can now automatically work with semaphores by accepting an optional semaphore parameter.