Skip to content

[PHI] Optimize GPUScatterNdAdd #72666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 107 additions & 17 deletions paddle/phi/kernels/funcs/scatter.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,26 @@ __global__ void ScatterCUDAKernel(const T* params,
}
}

template <typename T, typename IndexT = int>
template <typename T, typename IndexT, int VecSize>
__global__ void ScatterNdCUDAKernel(const T* update,
const IndexT* indices,
T* output,
const Dim<DDim::kMaxRank> output_dims,
size_t remain_size,
size_t slice_size,
size_t end_size) {
CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
int64_t total_size = remain_size * slice_size;
int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int64_t stride = blockDim.x * gridDim.x * VecSize;

#pragma unroll
for (; idx < total_size; idx += stride) {
int64_t indices_i = idx / slice_size;
int64_t slice_i = idx % slice_size;
int64_t gather_i = 0;
int64_t temp = slice_size;

#pragma unroll
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = indices[indices_i * end_size + j];
PADDLE_ENFORCE(
Expand All @@ -128,8 +135,17 @@ __global__ void ScatterNdCUDAKernel(const T* update,
gather_i += (index_value * temp);
temp *= output_dims[j];
}

int64_t output_i = gather_i + slice_i;
phi::CudaAtomicAdd(output + output_i, *(update + i));

using VecType = kps::details::VectorType<T, VecSize>;
const VecType* src = reinterpret_cast<const VecType*>(&update[idx]);
VecType* dst = reinterpret_cast<VecType*>(&output[output_i]);

#pragma unroll
for (int k = 0; k < VecSize; ++k) {
phi::CudaAtomicAdd(&(dst->val[k]), src->val[k]);
}
}
}

Expand Down Expand Up @@ -254,6 +270,72 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
p_index, p_output, dst_dims[0], index_size, slice_size);
}

template <typename T, typename IndexT, int VecSize>
void DispatchScatterNdKernel(
const phi::GPUContext& ctx,
const T* p_update,
const IndexT* p_index,
T* p_output,
const Dim<DDim::kMaxRank>& g_output_dims,
int64_t remain_numel,
int64_t slice_size,
int64_t end_size,
int vec_size,
const phi::backends::gpu::GpuLaunchConfig& config) {
if (vec_size == VecSize) {
auto stream = ctx.stream();
ScatterNdCUDAKernel<T, IndexT, VecSize>
<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
p_update,
p_index,
p_output,
g_output_dims,
remain_numel,
slice_size,
end_size);
} else {
DispatchScatterNdKernel<T, IndexT, VecSize / 2>(ctx,
p_update,
p_index,
p_output,
g_output_dims,
remain_numel,
slice_size,
end_size,
vec_size,
config);
}
}

template <typename T, typename IndexT>
void DispatchScatterNdKernel(
const phi::GPUContext& ctx,
const T* p_update,
const IndexT* p_index,
T* p_output,
const Dim<DDim::kMaxRank>& g_output_dims,
int64_t remain_numel,
int64_t slice_size,
int64_t end_size,
int vec_size,
const phi::backends::gpu::GpuLaunchConfig& config) {
if (vec_size == 1) {
auto stream = ctx.stream();
ScatterNdCUDAKernel<T, IndexT, 1>
<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
p_update,
p_index,
p_output,
g_output_dims,
remain_numel,
slice_size,
end_size);
} else {
PADDLE_THROW(common::errors::Unimplemented(
"Unsupported vectorized size: %d", vec_size));
}
}

template <typename T, typename IndexT = int>
void GPUScatterNdAdd(const phi::GPUContext& ctx,
const DenseTensor& update,
Expand Down Expand Up @@ -286,19 +368,27 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
g_output_dims[i] = output_dims[i];
}

int block = 512;
int64_t n = slice_size * remain_numel;
dim3 grid = dim3((n + block - 1) / block);
phi::backends::gpu::LimitGridDim(ctx, &grid);
int vec_size = 4;
vec_size = std::min(phi::GetVectorizedSize(p_update), vec_size);
vec_size = std::min(phi::GetVectorizedSize(p_output), vec_size);
while (vec_size > 1 && slice_size % vec_size != 0) {
vec_size /= 2;
}

ScatterNdCUDAKernel<T, IndexT>
<<<grid, block, 0, ctx.stream()>>>(p_update,
p_index,
p_output,
g_output_dims,
remain_numel,
slice_size,
end_size);
constexpr int loop_count = 4;
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
ctx, remain_numel * slice_size, vec_size * loop_count);

DispatchScatterNdKernel<T, IndexT, 4>(ctx,
p_update,
p_index,
p_output,
g_output_dims,
remain_numel,
slice_size,
end_size,
vec_size,
config);
}

} // namespace funcs
Expand Down