@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
   }
 }
 
+template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+__device__ __forceinline__ void LoadData(
+    T *dst,
+    const _ptr_ T *src,
+    uint32_t block_offset,
+    const kps::details::BroadcastConfig<Rank> &config,
+    int numel,
+    int num,
+    int need_broadcast,
+    int read_lens) {
+  // numel: total number of output elements
+  // num: number of elements to be processed in this iteration
+  if (need_broadcast) {
+    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+        dst, src, block_offset, config, numel, read_lens);
+  } else {
+    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
+        dst, src + block_offset, num, read_lens);
+  }
+}
+
 template <typename InT,
           typename OutT,
           typename Functor,
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
     const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
     int num,
     int block_offset,
+    int read_lens,
     Functor func) {
-  InT args[Arity][VecSize];
-  ConditionalT<OutT, NumOuts> result[VecSize];
+  __simd__ InT args[Arity][VecSize];
+  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
-    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
     LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
                                              ins[i],
                                              block_offset,
                                              configs[i],
                                              numel,
                                              num,
-                                             use_broadcast[i]);
+                                             use_broadcast[i],
+                                             read_lens);
   }
   constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
                                      Functor,
                                      Arity,
                                      kCallElementwiseAny>()(
-      func, args, result);
-
-  phi::funcs::ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
-      outs, result, block_offset, num);
+      func, args, result, read_lens);
+  phi::funcs::
+      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
+          outs, result, block_offset, num, read_lens);
 }
 
 template <typename InT,
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
     phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
     int main_offset,
     int tail_tid,
+    int read_lens,
     Functor func) {
-  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
 
 #ifdef PADDLE_WITH_XPU_KP
   for (; block_offset < main_offset; block_offset += stride) {
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
                                          use_broadcast,
                                          numel,
                                          configs,
-                                         BLOCK_NUM_X * VecSize,
+                                         BLOCK_NUM_X * read_lens,
                                          block_offset,
+                                         read_lens,
                                          func);
   }
   int num = numel - block_offset;
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, num, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        num,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #else
   if (block_offset < main_offset) {
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
                                          configs,
                                          BLOCK_NUM_X * VecSize,
                                          block_offset,
+                                         read_lens,
                                          func);
   } else {
     VectorizedBroadcastKernelImpl<InT,
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        tail_tid,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #endif
 }
@@ -392,35 +432,70 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
     ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
+#ifdef PADDLE_WITH_XPU_KP
+    if (i == 0) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.dim_size);
+    } else if (i == 1) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.dim_size);
+    }
+#else
     if (use_broadcast[i]) {
       // get the broadcast config,
       // if data shape is [m, n], then you should set data_dim = {n, m}
       // eg: out's shape [3, 45, 1], then out_dims = {1, 45, 3}
       configs[i] = kps::details::BroadcastConfig<Rank>(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
+#endif
   }
 
 #ifdef PADDLE_WITH_XPU_KP
   const int threads = 64;
   const int blocks = 8;
-  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
-  int tail_tid = numel % (VecSize * threads);
+  int read_lens = configs[0].buf_len;
+  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
+  int tail_tid = numel % (read_lens * threads);
   auto stream = ctx.x_context()->xpu_stream;
-  VectorizedBroadcastKernel<InT,
-                            OutT,
-                            Functor,
-                            Arity,
-                            NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, stream>>>(ins_data,
-                                                               outs_data,
-                                                               use_broadcast,
-                                                               numel,
-                                                               configs,
-                                                               main_offset,
-                                                               tail_tid,
-                                                               func);
+  if (configs[0].cmp_type != kps::details::OptType::CanNotOptimize) {
+    main_offset = numel;
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              512,
+                              Rank><<<blocks, threads, stream>>>(ins_data,
+                                                                 outs_data,
+                                                                 use_broadcast,
+                                                                 numel,
+                                                                 configs,
+                                                                 main_offset,
+                                                                 tail_tid,
+                                                                 read_lens,
+                                                                 func);
+  } else {
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              256,
+                              Rank><<<blocks, threads, stream>>>(ins_data,
+                                                                 outs_data,
+                                                                 use_broadcast,
+                                                                 numel,
+                                                                 configs,
+                                                                 main_offset,
+                                                                 tail_tid,
+                                                                 read_lens,
+                                                                 func);
+  }
 #else
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                                                                configs,
                                                                main_offset,
                                                                tail_tid,
+                                                               VecSize,
                                                                func);
 #endif
 }
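
Note (not part of the patch): the recurring idea in this diff is that the kernel keeps a compile-time vector size as the per-thread buffer capacity (512 or 256 on the XPU KP path, `VecSize` on the CUDA path), while a runtime `read_lens` taken from `BroadcastConfig::buf_len` controls how many elements are actually consumed per step (`block_offset`, `stride`, and the load/store helpers all use it). The sketch below is a minimal host-side illustration of that capacity-vs-usage split; `ProcessChunk` and everything in it are hypothetical names for illustration only, not Paddle APIs.

```cpp
#include <cstdio>

// Compile-time MaxVecSize fixes the buffer capacity; runtime read_lens
// decides how many elements each pass actually touches, so one compiled
// instantiation can serve several tiling choices.
template <typename T, int MaxVecSize>
void ProcessChunk(const T *src, T *dst, int read_lens, int num) {
  T buf[MaxVecSize];  // capacity is compile-time, usage is runtime
  for (int offset = 0; offset < num; offset += read_lens) {
    int cur = (num - offset < read_lens) ? (num - offset) : read_lens;
    for (int i = 0; i < cur; ++i) buf[i] = src[offset + i];   // "ReadData"
    for (int i = 0; i < cur; ++i) dst[offset + i] = buf[i] * static_cast<T>(2);
  }
}

int main() {
  float in[10], out[10];
  for (int i = 0; i < 10; ++i) in[i] = static_cast<float>(i);
  // read_lens (4) may be anything up to the template capacity (8).
  ProcessChunk<float, 8>(in, out, /*read_lens=*/4, /*num=*/10);
  printf("%f %f\n", out[0], out[9]);  // 0.000000 18.000000
  return 0;
}
```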