Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 26 additions & 33 deletions fbgemm_gpu/src/sparse_ops/sparse_async_batched_cumsum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ template <
uint32_t nthreads_per_block,
typename = std::enable_if_t<std::is_integral<val_t>::value>>
__global__ __launch_bounds__(kMaxThreads) void _batched_complete_cumsum_kernel(
const at::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> values,
const pta::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> values,
const uint32_t len,
const uint32_t items_per_thread,
at::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> out) {
pta::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> out) {
using BlockScan = cub::BlockScan<val_t, nthreads_per_block>;
__shared__ typename BlockScan::TempStorage temp_storage;

Expand All @@ -79,6 +79,18 @@ __global__ __launch_bounds__(kMaxThreads) void _batched_complete_cumsum_kernel(
}
}

#define BATCHED_COMPLETE_CUMSUM_KERNEL(NTHREADS_PER_BLOCK) \
FBGEMM_LAUNCH_KERNEL( \
(_batched_complete_cumsum_kernel<val_t, NTHREADS_PER_BLOCK>), \
B, \
NTHREADS_PER_BLOCK, \
0, \
at::cuda::getCurrentCUDAStream(), \
PTA_B(values, val_t, 2, 64), \
len, \
items_per_thread, \
PTA_B(cumsum, val_t, 2, 64));

at::Tensor asynchronous_batched_complete_cumsum_gpu(const at::Tensor& values) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
Expand All @@ -102,48 +114,29 @@ at::Tensor asynchronous_batched_complete_cumsum_gpu(const at::Tensor& values) {
AT_DISPATCH_INTEGRAL_TYPES(
values.scalar_type(), "batched_complete_cumsum_cuda_input1", [&] {
using val_t = scalar_t;

if (nthreads_per_block == 64) {
_batched_complete_cumsum_kernel<val_t, 64>
<<<B, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
len,
items_per_thread,
cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
BATCHED_COMPLETE_CUMSUM_KERNEL(64);

} else if (nthreads_per_block == 128) {
_batched_complete_cumsum_kernel<val_t, 128>
<<<B, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
len,
items_per_thread,
cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
BATCHED_COMPLETE_CUMSUM_KERNEL(128);

} else if (nthreads_per_block == 256) {
_batched_complete_cumsum_kernel<val_t, 256>
<<<B, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
len,
items_per_thread,
cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
BATCHED_COMPLETE_CUMSUM_KERNEL(256);

} else if (nthreads_per_block == 512) {
_batched_complete_cumsum_kernel<val_t, 512>
<<<B, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
len,
items_per_thread,
cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
BATCHED_COMPLETE_CUMSUM_KERNEL(512);

} else {
_batched_complete_cumsum_kernel<val_t, 1024>
<<<B, 1024, 0, at::cuda::getCurrentCUDAStream()>>>(
values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
len,
items_per_thread,
cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
BATCHED_COMPLETE_CUMSUM_KERNEL(1024);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});

return cumsum;
}

#undef BATCHED_COMPLETE_CUMSUM_KERNEL

} // namespace fbgemm_gpu

FBGEMM_OP_DISPATCH(
Expand Down
Loading