Commit 6428d9c

[PHI] Optimize GPUScatterNdAdd (#72666)
1 parent 32d5f4b commit 6428d9c

File tree

1 file changed: +107 -17 lines changed

paddle/phi/kernels/funcs/scatter.cu.h

+107 -17
@@ -96,19 +96,26 @@ __global__ void ScatterCUDAKernel(const T* params,
   }
 }
 
-template <typename T, typename IndexT = int>
+template <typename T, typename IndexT, int VecSize>
 __global__ void ScatterNdCUDAKernel(const T* update,
                                     const IndexT* indices,
                                     T* output,
                                     const Dim<DDim::kMaxRank> output_dims,
                                     size_t remain_size,
                                     size_t slice_size,
                                     size_t end_size) {
-  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
-    int64_t indices_i = i / slice_size;
-    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
+  int64_t total_size = remain_size * slice_size;
+  int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t stride = blockDim.x * gridDim.x * VecSize;
+
+#pragma unroll
+  for (; idx < total_size; idx += stride) {
+    int64_t indices_i = idx / slice_size;
+    int64_t slice_i = idx % slice_size;
     int64_t gather_i = 0;
     int64_t temp = slice_size;
+
+#pragma unroll
     for (int64_t j = end_size - 1; j >= 0; --j) {
       IndexT index_value = indices[indices_i * end_size + j];
       PADDLE_ENFORCE(
@@ -128,8 +135,17 @@ __global__ void ScatterNdCUDAKernel(const T* update,
       gather_i += (index_value * temp);
       temp *= output_dims[j];
     }
+
     int64_t output_i = gather_i + slice_i;
-    phi::CudaAtomicAdd(output + output_i, *(update + i));
+
+    using VecType = kps::details::VectorType<T, VecSize>;
+    const VecType* src = reinterpret_cast<const VecType*>(&update[idx]);
+    VecType* dst = reinterpret_cast<VecType*>(&output[output_i]);
+
+#pragma unroll
+    for (int k = 0; k < VecSize; ++k) {
+      phi::CudaAtomicAdd(&(dst->val[k]), src->val[k]);
+    }
   }
 }
 
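Note on the kernel change above: the per-element CUDA_KERNEL_LOOP is replaced by a grid-stride loop in which each thread owns VecSize consecutive update elements, reads them through one aligned vector load, and still accumulates with per-lane atomic adds, since different index tuples can scatter into the same output slot. Below is a minimal self-contained CUDA sketch of that access pattern; AlignedVec, dst_base, and vectorized_scatter_add are illustrative names standing in for kps::details::VectorType and the on-the-fly offset computation, not part of the commit.

#include <cstdint>
#include <cuda_runtime.h>

// Illustrative stand-in for kps::details::VectorType<T, VecSize>: VecSize
// elements packed so one aligned load moves them together.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVec {
  T val[VecSize];
};

// Grid-stride loop, VecSize elements per thread per step. dst_base[i] plays
// the role of gather_i + slice_i in the real kernel (assumed precomputed
// here); consecutive lanes of one vector land in consecutive output slots.
// T is assumed to have a native atomicAdd overload (e.g. float).
template <typename T, int VecSize>
__global__ void vectorized_scatter_add(const T* update,
                                       const int64_t* dst_base,
                                       T* output,
                                       int64_t total_size) {
  int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  for (; idx < total_size; idx += stride) {
    // One aligned vector load of VecSize consecutive update elements.
    AlignedVec<T, VecSize> v =
        *reinterpret_cast<const AlignedVec<T, VecSize>*>(&update[idx]);
    // The accumulation stays element-wise and atomic: other threads may be
    // adding into the same destination slots.
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      atomicAdd(&output[dst_base[idx] + k], v.val[k]);
    }
  }
}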

@@ -254,6 +270,72 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
       p_index, p_output, dst_dims[0], index_size, slice_size);
 }
 
+template <typename T, typename IndexT, int VecSize>
+void DispatchScatterNdKernel(
+    const phi::GPUContext& ctx,
+    const T* p_update,
+    const IndexT* p_index,
+    T* p_output,
+    const Dim<DDim::kMaxRank>& g_output_dims,
+    int64_t remain_numel,
+    int64_t slice_size,
+    int64_t end_size,
+    int vec_size,
+    const phi::backends::gpu::GpuLaunchConfig& config) {
+  if (vec_size == VecSize) {
+    auto stream = ctx.stream();
+    ScatterNdCUDAKernel<T, IndexT, VecSize>
+        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+            p_update,
+            p_index,
+            p_output,
+            g_output_dims,
+            remain_numel,
+            slice_size,
+            end_size);
+  } else {
+    DispatchScatterNdKernel<T, IndexT, VecSize / 2>(ctx,
+                                                    p_update,
+                                                    p_index,
+                                                    p_output,
+                                                    g_output_dims,
+                                                    remain_numel,
+                                                    slice_size,
+                                                    end_size,
+                                                    vec_size,
+                                                    config);
+  }
+}
+
+template <typename T, typename IndexT>
+void DispatchScatterNdKernel(
+    const phi::GPUContext& ctx,
+    const T* p_update,
+    const IndexT* p_index,
+    T* p_output,
+    const Dim<DDim::kMaxRank>& g_output_dims,
+    int64_t remain_numel,
+    int64_t slice_size,
+    int64_t end_size,
+    int vec_size,
+    const phi::backends::gpu::GpuLaunchConfig& config) {
+  if (vec_size == 1) {
+    auto stream = ctx.stream();
+    ScatterNdCUDAKernel<T, IndexT, 1>
+        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+            p_update,
+            p_index,
+            p_output,
+            g_output_dims,
+            remain_numel,
+            slice_size,
+            end_size);
+  } else {
+    PADDLE_THROW(common::errors::Unimplemented(
+        "Unsupported vectorized size: %d", vec_size));
+  }
+}
+
 template <typename T, typename IndexT = int>
 void GPUScatterNdAdd(const phi::GPUContext& ctx,
                      const DenseTensor& update,
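On the dispatch helpers above: vec_size is only known at run time, but ScatterNdCUDAKernel needs it as the compile-time VecSize template argument, so each DispatchScatterNdKernel instantiation either launches when the two match or recurses with VecSize / 2; the two-parameter overload covers vec_size == 1 and rejects anything else via PADDLE_THROW. Below is a condensed standalone C++17 sketch of the same idea, folding the terminus into one function with if constexpr; Dispatch and LaunchWithVecSize are illustrative names, not Paddle APIs.

#include <cstdio>

// Stand-in for the real launch: in the commit this is the
// ScatterNdCUDAKernel<T, IndexT, VecSize> <<<...>>> call.
template <int VecSize>
void LaunchWithVecSize() {
  std::printf("launch kernel compiled for VecSize = %d\n", VecSize);
}

// Walk down 4 -> 2 -> 1 at compile time until the instantiation's VecSize
// equals the runtime vec_size, then launch that instantiation.
template <int VecSize = 4>
void Dispatch(int vec_size) {
  if constexpr (VecSize == 1) {
    LaunchWithVecSize<1>();  // recursion terminus (the vec_size == 1
                             // overload in the commit)
  } else {
    if (vec_size == VecSize) {
      LaunchWithVecSize<VecSize>();
    } else {
      Dispatch<VecSize / 2>(vec_size);
    }
  }
}

int main() {
  Dispatch(4);  // -> VecSize = 4
  Dispatch(2);  // -> VecSize = 2
  Dispatch(1);  // -> VecSize = 1
}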
@@ -286,19 +368,27 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
     g_output_dims[i] = output_dims[i];
   }
 
-  int block = 512;
-  int64_t n = slice_size * remain_numel;
-  dim3 grid = dim3((n + block - 1) / block);
-  phi::backends::gpu::LimitGridDim(ctx, &grid);
+  int vec_size = 4;
+  vec_size = std::min(phi::GetVectorizedSize(p_update), vec_size);
+  vec_size = std::min(phi::GetVectorizedSize(p_output), vec_size);
+  while (vec_size > 1 && slice_size % vec_size != 0) {
+    vec_size /= 2;
+  }
 
-  ScatterNdCUDAKernel<T, IndexT>
-      <<<grid, block, 0, ctx.stream()>>>(p_update,
-                                         p_index,
-                                         p_output,
-                                         g_output_dims,
-                                         remain_numel,
-                                         slice_size,
-                                         end_size);
+  constexpr int loop_count = 4;
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+      ctx, remain_numel * slice_size, vec_size * loop_count);
+
+  DispatchScatterNdKernel<T, IndexT, 4>(ctx,
+                                        p_update,
+                                        p_index,
+                                        p_output,
+                                        g_output_dims,
+                                        remain_numel,
+                                        slice_size,
+                                        end_size,
+                                        vec_size,
+                                        config);
 }
 
 }  // namespace funcs
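On the launch-side change: GPUScatterNdAdd now chooses the vector width before launching. It starts at 4, clamps by the alignment-derived width of both the update and output pointers (phi::GetVectorizedSize), then halves until the width divides slice_size so no vector straddles two slices; the 1-D launch config appears to budget vec_size * loop_count elements per thread. Below is a rough standalone sketch of that selection logic; AlignmentVecSize and ChooseVecSize are hypothetical helpers, and the alignment rule shown is a simplified stand-in for GetVectorizedSize.

#include <algorithm>
#include <cstdint>

// Simplified stand-in for phi::GetVectorizedSize: how many T elements the
// pointer's alignment allows to be moved as a single vector (max 4 here).
template <typename T>
int AlignmentVecSize(const T* p) {
  auto addr = reinterpret_cast<std::uintptr_t>(p);
  if (addr % (sizeof(T) * 4) == 0) return 4;
  if (addr % (sizeof(T) * 2) == 0) return 2;
  return 1;
}

// Mirrors the selection in GPUScatterNdAdd: start wide, respect both
// pointers' alignment, then shrink until whole slices are covered by an
// integral number of vectors.
template <typename T>
int ChooseVecSize(const T* update, const T* output, std::int64_t slice_size) {
  int vec_size = 4;
  vec_size = std::min(AlignmentVecSize(update), vec_size);
  vec_size = std::min(AlignmentVecSize(output), vec_size);
  while (vec_size > 1 && slice_size % vec_size != 0) {
    vec_size /= 2;  // e.g. slice_size = 6 with vec_size = 4 falls back to 2
  }
  return vec_size;
}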
