@@ -96,19 +96,26 @@ __global__ void ScatterCUDAKernel(const T* params,
   }
 }
 
-template <typename T, typename IndexT = int>
+template <typename T, typename IndexT, int VecSize>
 __global__ void ScatterNdCUDAKernel(const T* update,
                                     const IndexT* indices,
                                     T* output,
                                     const Dim<DDim::kMaxRank> output_dims,
                                     size_t remain_size,
                                     size_t slice_size,
                                     size_t end_size) {
-  CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
-    int64_t indices_i = i / slice_size;
-    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
+  int64_t total_size = remain_size * slice_size;
+  int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t stride = blockDim.x * gridDim.x * VecSize;
+
+#pragma unroll
+  for (; idx < total_size; idx += stride) {
+    int64_t indices_i = idx / slice_size;
+    int64_t slice_i = idx % slice_size;
     int64_t gather_i = 0;
     int64_t temp = slice_size;
+
+#pragma unroll
     for (int64_t j = end_size - 1; j >= 0; --j) {
       IndexT index_value = indices[indices_i * end_size + j];
       PADDLE_ENFORCE(
@@ -128,8 +135,17 @@ __global__ void ScatterNdCUDAKernel(const T* update,
       gather_i += (index_value * temp);
       temp *= output_dims[j];
     }
+
     int64_t output_i = gather_i + slice_i;
-    phi::CudaAtomicAdd(output + output_i, *(update + i));
+
+    using VecType = kps::details::VectorType<T, VecSize>;
+    const VecType* src = reinterpret_cast<const VecType*>(&update[idx]);
+    VecType* dst = reinterpret_cast<VecType*>(&output[output_i]);
+
+#pragma unroll
+    for (int k = 0; k < VecSize; ++k) {
+      phi::CudaAtomicAdd(&(dst->val[k]), src->val[k]);
+    }
   }
 }
 
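For orientation, the rewritten kernel's addressing can be reproduced on the host. The sketch below is not part of the patch (the function name and signature are illustrative); it mirrors how one flat element index is split into an index-tuple row `indices_i` and an in-slice offset `slice_i`, and how the tuple is folded into a row-major offset into `output`. In the vectorized kernel each thread starts from a `VecSize`-aligned `idx` and issues `VecSize` atomic adds on consecutive elements, which is why the launcher only picks widths that divide `slice_size`.

```cpp
// Host-side sketch of the kernel's index math for a single flat index `idx`.
// Assumes a row-major output and an indices tensor of shape
// [remain_size, end_size]; FlatOutputOffset is an illustrative name.
#include <cstdint>
#include <vector>

std::int64_t FlatOutputOffset(std::int64_t idx,
                              std::int64_t slice_size,
                              std::int64_t end_size,
                              const std::vector<std::int64_t>& output_dims,
                              const std::vector<std::int64_t>& indices) {
  std::int64_t indices_i = idx / slice_size;  // which index tuple
  std::int64_t slice_i = idx % slice_size;    // position inside that slice
  std::int64_t gather_i = 0;
  std::int64_t temp = slice_size;
  // Fold the index tuple into a row-major offset, innermost indexed
  // dimension first, exactly as the kernel's inner loop does.
  for (std::int64_t j = end_size - 1; j >= 0; --j) {
    gather_i += indices[indices_i * end_size + j] * temp;
    temp *= output_dims[j];
  }
  return gather_i + slice_i;  // element the atomic add targets
}
```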
@@ -254,6 +270,72 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
       p_index, p_output, dst_dims[0], index_size, slice_size);
 }
 
+template <typename T, typename IndexT, int VecSize>
+void DispatchScatterNdKernel(
+    const phi::GPUContext& ctx,
+    const T* p_update,
+    const IndexT* p_index,
+    T* p_output,
+    const Dim<DDim::kMaxRank>& g_output_dims,
+    int64_t remain_numel,
+    int64_t slice_size,
+    int64_t end_size,
+    int vec_size,
+    const phi::backends::gpu::GpuLaunchConfig& config) {
+  if (vec_size == VecSize) {
+    auto stream = ctx.stream();
+    ScatterNdCUDAKernel<T, IndexT, VecSize>
+        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+            p_update,
+            p_index,
+            p_output,
+            g_output_dims,
+            remain_numel,
+            slice_size,
+            end_size);
+  } else {
+    DispatchScatterNdKernel<T, IndexT, VecSize / 2>(ctx,
+                                                    p_update,
+                                                    p_index,
+                                                    p_output,
+                                                    g_output_dims,
+                                                    remain_numel,
+                                                    slice_size,
+                                                    end_size,
+                                                    vec_size,
+                                                    config);
+  }
+}
+
+template <typename T, typename IndexT>
+void DispatchScatterNdKernel(
+    const phi::GPUContext& ctx,
+    const T* p_update,
+    const IndexT* p_index,
+    T* p_output,
+    const Dim<DDim::kMaxRank>& g_output_dims,
+    int64_t remain_numel,
+    int64_t slice_size,
+    int64_t end_size,
+    int vec_size,
+    const phi::backends::gpu::GpuLaunchConfig& config) {
+  if (vec_size == 1) {
+    auto stream = ctx.stream();
+    ScatterNdCUDAKernel<T, IndexT, 1>
+        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+            p_update,
+            p_index,
+            p_output,
+            g_output_dims,
+            remain_numel,
+            slice_size,
+            end_size);
+  } else {
+    PADDLE_THROW(common::errors::Unimplemented(
+        "Unsupported vectorized size: %d", vec_size));
+  }
+}
+
 template <typename T, typename IndexT = int>
 void GPUScatterNdAdd(const phi::GPUContext& ctx,
                      const DenseTensor& update,
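The two `DispatchScatterNdKernel` overloads above translate the runtime `vec_size` into the compile-time `VecSize` template argument by recursing from the widest candidate down to 1, so the kernel is instantiated once per supported width. The standalone sketch below shows the same idiom, using an `if constexpr` base case instead of the patch's separate `vec_size == 1` overload; `Dispatch` and `LaunchVectorized` are illustrative names, not phi APIs.

```cpp
// Minimal sketch of runtime-to-compile-time dispatch over a vector width.
#include <cstdio>

template <int VecSize>
void LaunchVectorized(int n) {
  // Stand-in for the real kernel launch; VecSize is a compile-time constant.
  std::printf("VecSize=%d, n=%d\n", VecSize, n);
}

template <int VecSize = 4>
void Dispatch(int vec_size, int n) {
  if constexpr (VecSize == 1) {
    LaunchVectorized<1>(n);  // base case: scalar path
  } else {
    if (vec_size == VecSize) {
      LaunchVectorized<VecSize>(n);  // runtime width matches this candidate
    } else {
      Dispatch<VecSize / 2>(vec_size, n);  // try the next narrower width
    }
  }
}

int main() {
  Dispatch(4, 1024);  // resolves to LaunchVectorized<4>
  Dispatch(2, 1024);  // resolves to LaunchVectorized<2>
  Dispatch(1, 7);     // resolves to LaunchVectorized<1>
}
```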
@@ -286,19 +368,27 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
     g_output_dims[i] = output_dims[i];
   }
 
-  int block = 512;
-  int64_t n = slice_size * remain_numel;
-  dim3 grid = dim3((n + block - 1) / block);
-  phi::backends::gpu::LimitGridDim(ctx, &grid);
+  int vec_size = 4;
+  vec_size = std::min(phi::GetVectorizedSize(p_update), vec_size);
+  vec_size = std::min(phi::GetVectorizedSize(p_output), vec_size);
+  while (vec_size > 1 && slice_size % vec_size != 0) {
+    vec_size /= 2;
+  }
 
-  ScatterNdCUDAKernel<T, IndexT>
-      <<<grid, block, 0, ctx.stream()>>>(p_update,
-                                         p_index,
-                                         p_output,
-                                         g_output_dims,
-                                         remain_numel,
-                                         slice_size,
-                                         end_size);
+  constexpr int loop_count = 4;
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+      ctx, remain_numel * slice_size, vec_size * loop_count);
+
+  DispatchScatterNdKernel<T, IndexT, 4>(ctx,
+                                        p_update,
+                                        p_index,
+                                        p_output,
+                                        g_output_dims,
+                                        remain_numel,
+                                        slice_size,
+                                        end_size,
+                                        vec_size,
+                                        config);
 }
 
 }  // namespace funcs
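The updated launcher starts from the widest width (4 elements per thread) and narrows it to whatever the pointers and the slice length allow before handing off to the dispatcher. The sketch below is one way to read that selection, assuming `phi::GetVectorizedSize` amounts to a pointer-alignment check; `PickVecSize` and `AlignmentVecSize` are illustrative helpers, not part of the phi API.

```cpp
// Sketch of the host-side vector-width selection, under the assumption that
// the alignment of the data pointers is what bounds the usable width.
#include <algorithm>
#include <cstdint>

template <typename T>
int AlignmentVecSize(const T* p) {
  auto addr = reinterpret_cast<std::uintptr_t>(p);
  if (addr % (sizeof(T) * 4) == 0) return 4;  // 4-element loads/stores OK
  if (addr % (sizeof(T) * 2) == 0) return 2;  // 2-element loads/stores OK
  return 1;                                   // fall back to scalar access
}

template <typename T>
int PickVecSize(const T* update, const T* output, std::int64_t slice_size) {
  int vec_size = 4;
  vec_size = std::min(AlignmentVecSize(update), vec_size);
  vec_size = std::min(AlignmentVecSize(output), vec_size);
  // Each thread covers vec_size consecutive elements of one slice, so the
  // slice length must be a multiple of the width; halve until it is.
  while (vec_size > 1 && slice_size % vec_size != 0) {
    vec_size /= 2;
  }
  return vec_size;
}
```

The launch configuration is then sized for `remain_numel * slice_size` elements at `vec_size * loop_count` elements per thread, so each thread is expected to cover several vectors in its grid-stride loop rather than one.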