From f470b22a303f43b0d6743cdf0522585f55452bd3 Mon Sep 17 00:00:00 2001
From: zhanghonggeng
Date: Mon, 12 May 2025 02:50:39 +0000
Subject: [PATCH] [PHI] Optimize GPUScatterNdAdd

---
 paddle/phi/kernels/funcs/scatter.cu.h | 124 ++++++++++++++++++++++----
 1 file changed, 107 insertions(+), 17 deletions(-)

diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h
index 164732bc7670d9..b69633d487ff5c 100644
--- a/paddle/phi/kernels/funcs/scatter.cu.h
+++ b/paddle/phi/kernels/funcs/scatter.cu.h
@@ -96,7 +96,7 @@ __global__ void ScatterCUDAKernel(const T* params,
   }
 }
 
-template <typename T, typename IndexT = int>
+template <typename T, typename IndexT = int, int VecSize = 1>
 __global__ void ScatterNdCUDAKernel(const T* update,
                                     const IndexT* indices,
                                     T* output,
@@ -104,11 +104,18 @@ __global__ void ScatterNdCUDAKernel(const T* update,
                                     size_t remain_size,
                                     size_t slice_size,
                                     size_t end_size) {
-  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
-    int64_t indices_i = i / slice_size;
-    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
+  int64_t total_size = remain_size * slice_size;
+  int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t stride = blockDim.x * gridDim.x * VecSize;
+
+#pragma unroll
+  for (; idx < total_size; idx += stride) {
+    int64_t indices_i = idx / slice_size;
+    int64_t slice_i = idx % slice_size;
     int64_t gather_i = 0;
     int64_t temp = slice_size;
+
+#pragma unroll
     for (int64_t j = end_size - 1; j >= 0; --j) {
       IndexT index_value = indices[indices_i * end_size + j];
       PADDLE_ENFORCE(
@@ -128,8 +135,17 @@ __global__ void ScatterNdCUDAKernel(const T* update,
       gather_i += (index_value * temp);
       temp *= output_dims[j];
     }
+
     int64_t output_i = gather_i + slice_i;
-    phi::CudaAtomicAdd(output + output_i, *(update + i));
+
+    using VecType = kps::details::VectorType<T, VecSize>;
+    const VecType* src = reinterpret_cast<const VecType*>(&update[idx]);
+    VecType* dst = reinterpret_cast<VecType*>(&output[output_i]);
+
+#pragma unroll
+    for (int k = 0; k < VecSize; ++k) {
+      phi::CudaAtomicAdd(&(dst->val[k]), src->val[k]);
+    }
   }
 }
 
@@ -254,6 +270,72 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
       p_index, p_output, dst_dims[0], index_size, slice_size);
 }
 
+template <typename T, typename IndexT, typename Dim, int VecSize>
+void DispatchScatterNdKernel(
+    const phi::GPUContext& ctx,
+    const T* p_update,
+    const IndexT* p_index,
+    T* p_output,
+    const Dim& g_output_dims,
+    int64_t remain_numel,
+    int64_t slice_size,
+    int64_t end_size,
+    int vec_size,
+    const phi::backends::gpu::GpuLaunchConfig& config) {
+  if (vec_size == VecSize) {
+    auto stream = ctx.stream();
+    ScatterNdCUDAKernel<T, IndexT, VecSize>
+        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+            p_update,
+            p_index,
+            p_output,
+            g_output_dims,
+            remain_numel,
+            slice_size,
+            end_size);
+  } else {
+    DispatchScatterNdKernel<T, IndexT, Dim, VecSize / 2>(ctx,
+                                                         p_update,
+                                                         p_index,
+                                                         p_output,
+                                                         g_output_dims,
+                                                         remain_numel,
+                                                         slice_size,
+                                                         end_size,
+                                                         vec_size,
+                                                         config);
+  }
+}
+
+template <typename T, typename IndexT, typename Dim>
+void DispatchScatterNdKernel(
+    const phi::GPUContext& ctx,
+    const T* p_update,
+    const IndexT* p_index,
+    T* p_output,
+    const Dim& g_output_dims,
+    int64_t remain_numel,
+    int64_t slice_size,
+    int64_t end_size,
+    int vec_size,
+    const phi::backends::gpu::GpuLaunchConfig& config) {
+  if (vec_size == 1) {
+    auto stream = ctx.stream();
+    ScatterNdCUDAKernel<T, IndexT, 1>
+        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+            p_update,
+            p_index,
+            p_output,
+            g_output_dims,
+            remain_numel,
+            slice_size,
+            end_size);
+  } else {
+    PADDLE_THROW(common::errors::Unimplemented(
+        "Unsupported vectorized size: %d", vec_size));
+  }
+}
+
 template <typename T, typename IndexT = int>
 void GPUScatterNdAdd(const phi::GPUContext& ctx,
                      const DenseTensor& update,
                      const DenseTensor& index,
                      DenseTensor* output) {
@@ -286,19 +368,27 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
     g_output_dims[i] = output_dims[i];
   }
 
-  int block = 512;
-  int64_t n = slice_size * remain_numel;
-  dim3 grid = dim3((n + block - 1) / block);
-  phi::backends::gpu::LimitGridDim(ctx, &grid);
+  int vec_size = 4;
+  vec_size = std::min(phi::GetVectorizedSize(p_update), vec_size);
+  vec_size = std::min(phi::GetVectorizedSize(p_output), vec_size);
+  while (vec_size > 1 && slice_size % vec_size != 0) {
+    vec_size /= 2;
+  }
 
-  ScatterNdCUDAKernel<T, IndexT>
-      <<<grid, block, 0, ctx.stream()>>>(p_update,
-                                         p_index,
-                                         p_output,
-                                         g_output_dims,
-                                         remain_numel,
-                                         slice_size,
-                                         end_size);
+  constexpr int loop_count = 4;
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+      ctx, remain_numel * slice_size, vec_size * loop_count);
+
+  DispatchScatterNdKernel<T, IndexT, decltype(g_output_dims), 4>(ctx,
+                                                                 p_update,
+                                                                 p_index,
+                                                                 p_output,
+                                                                 g_output_dims,
+                                                                 remain_numel,
+                                                                 slice_size,
+                                                                 end_size,
+                                                                 vec_size,
+                                                                 config);
 }
 
 }  // namespace funcs