
Commit c785512

broadcast_add kp performance optimization (#42097)
1 parent 81078a8 commit c785512

7 files changed: 880 additions, 43 deletions

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 107 additions & 31 deletions
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
   }
 }
 
+template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+__device__ __forceinline__ void LoadData(
+    T *dst,
+    const _ptr_ T *src,
+    uint32_t block_offset,
+    const kps::details::BroadcastConfig<Rank> &config,
+    int numel,
+    int num,
+    int need_broadcast,
+    int read_lens) {
+  // numel : whole num of output
+  // num: how many data will be deal with in this time
+  if (need_broadcast) {
+    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+        dst, src, block_offset, config, numel, read_lens);
+  } else {
+    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
+        dst, src + block_offset, num, read_lens);
+  }
+}
+
 template <typename InT,
           typename OutT,
           typename Functor,
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
     const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
     int num,
     int block_offset,
+    int read_lens,
     Functor func) {
-  InT args[Arity][VecSize];
-  ConditionalT<OutT, NumOuts> result[VecSize];
+  __simd__ InT args[Arity][VecSize];
+  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
-    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
     LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
                                              ins[i],
                                              block_offset,
                                              configs[i],
                                              numel,
                                              num,
-                                             use_broadcast[i]);
+                                             use_broadcast[i],
+                                             read_lens);
   }
   constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
                                          Functor,
                                          Arity,
                                          kCallElementwiseAny>()(
-      func, args, result);
-
-  phi::funcs::ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
-      outs, result, block_offset, num);
+      func, args, result, read_lens);
+  phi::funcs::
+      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
+          outs, result, block_offset, num, read_lens);
 }
 
 template <typename InT,
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
     phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
     int main_offset,
     int tail_tid,
+    int read_lens,
     Functor func) {
-  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
 
 #ifdef PADDLE_WITH_XPU_KP
   for (; block_offset < main_offset; block_offset += stride) {
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
                                          use_broadcast,
                                          numel,
                                          configs,
-                                         BLOCK_NUM_X * VecSize,
+                                         BLOCK_NUM_X * read_lens,
                                          block_offset,
+                                         read_lens,
                                          func);
   }
   int num = numel - block_offset;
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, num, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        num,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #else
   if (block_offset < main_offset) {
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
                                          configs,
                                          BLOCK_NUM_X * VecSize,
                                          block_offset,
+                                         read_lens,
                                          func);
   } else {
     VectorizedBroadcastKernelImpl<InT,
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        tail_tid,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #endif
 }
@@ -392,35 +432,70 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
     ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
+#ifdef PADDLE_WITH_XPU_KP
+    if (i == 0) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.dim_size);
+    } else if (i == 1) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.dim_size);
+    }
+#else
     if (use_broadcast[i]) {
       // get the broadcast config,
       // if data shape is[m, n], then you should set data_dim = {n, m}
       // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
       configs[i] = kps::details::BroadcastConfig<Rank>(
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
+#endif
   }
 
 #ifdef PADDLE_WITH_XPU_KP
   const int threads = 64;
   const int blocks = 8;
-  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
-  int tail_tid = numel % (VecSize * threads);
+  int read_lens = configs[0].buf_len;
+  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
+  int tail_tid = numel % (read_lens * threads);
   auto stream = ctx.x_context()->xpu_stream;
-  VectorizedBroadcastKernel<InT,
-                            OutT,
-                            Functor,
-                            Arity,
-                            NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, stream>>>(ins_data,
-                                                               outs_data,
-                                                               use_broadcast,
-                                                               numel,
-                                                               configs,
-                                                               main_offset,
-                                                               tail_tid,
-                                                               func);
+  if (configs[0].cmp_type != kps::details::OptType::CanNotOptimize) {
+    main_offset = numel;
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              512,
+                              Rank><<<blocks, threads, stream>>>(ins_data,
+                                                                 outs_data,
+                                                                 use_broadcast,
+                                                                 numel,
+                                                                 configs,
+                                                                 main_offset,
+                                                                 tail_tid,
+                                                                 read_lens,
+                                                                 func);
+  } else {
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              256,
+                              Rank><<<blocks, threads, stream>>>(ins_data,
+                                                                 outs_data,
+                                                                 use_broadcast,
+                                                                 numel,
+                                                                 configs,
+                                                                 main_offset,
+                                                                 tail_tid,
+                                                                 read_lens,
+                                                                 func);
+  }
 #else
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                                                                configs,
                                                                main_offset,
                                                                tail_tid,
+                                                               VecSize,
                                                                func);
 #endif
 }

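The read_lens value threaded through the kernels above replaces the compile-time VecSize on the XPU KP path: LaunchBroadcastKernel now takes it from configs[0].buf_len, uses it to split the output into full tiles plus a tail, and instantiates the kernel with a VecSize of 512 or 256 depending on whether the broadcast pattern can be optimized. The host-side sketch below is not part of the commit; ComputeTiling and the sample numbers are illustrative, and it only reproduces the main_offset / tail_tid arithmetic so the tiling is easy to check by hand.

// Standalone sketch, not part of the commit: the XPU KP tiling arithmetic
// from LaunchBroadcastKernel, with illustrative names and sample numbers.
#include <cassert>
#include <cstdio>

struct Tiling {
  int main_offset;  // elements covered by full (threads * read_lens) tiles
  int tail_tid;     // leftover elements taken by the IsBoundary = true path
};

// numel: total output elements; threads: 64 as in the diff;
// read_lens: configs[0].buf_len chosen by BroadcastConfig.
Tiling ComputeTiling(int numel, int threads, int read_lens) {
  Tiling t;
  t.main_offset = (numel / (read_lens * threads)) * read_lens * threads;
  t.tail_tid = numel % (read_lens * threads);
  return t;
}

int main() {
  Tiling t = ComputeTiling(/*numel=*/100000, /*threads=*/64, /*read_lens=*/512);
  // 100000 / (64 * 512) = 3 full tiles -> main_offset = 98304, tail_tid = 1696.
  assert(t.main_offset + t.tail_tid == 100000);
  std::printf("main_offset=%d tail_tid=%d\n", t.main_offset, t.tail_tid);
  return 0;
}

Every output element lands either in the fully tiled region processed by the main loop or in the tail processed with IsBoundary = true, so main_offset + tail_tid always equals numel.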
paddle/phi/kernels/funcs/elementwise_base.h

Lines changed: 49 additions & 7 deletions
@@ -577,14 +577,16 @@ template <typename InT,
 struct ElementwisePrimitiveCaller {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result);
+                                    OutT *result,
+                                    int read_lens);
 };
 
 template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
         result, args, func);
   }
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
   }
 };
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
         result, args[0], func);
   }
@@ -613,17 +617,19 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
-        result, args[0], args[1], func);
+        result, args[0], args[1], func, read_lens);
   }
 };
 
 template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
         result, args[0], args[1], args[2], func);
   }
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
   }
 };
 
+template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
+struct ElementwiseWriteDataCallerBc {
+  __device__ __forceinline__ void operator()(
+      phi::Array<_ptr_ OutT *, NumOuts> outs,
+      ConditionalT<OutT, NumOuts> src[VecSize],
+      int block_offset,
+      int num,
+      int read_lens) {
+    OutT dst[NumOuts][VecSize];
+#pragma unroll
+    for (int i = 0; i < read_lens; ++i) {
+#pragma unroll
+      for (int j = 0; j < NumOuts; ++j) {
+        dst[j][i] = (src[i])[j];
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < NumOuts; ++i) {
+      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+          outs[i] + block_offset, dst[i], num, read_lens);
+    }
+  }
+};
+
+template <typename OutT, int VecSize, bool IsBoundary>
+struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
+  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs,
+                                             OutT src[VecSize],
+                                             int block_offset,
+                                             int num,
+                                             int read_lens) {
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+        outs[0] + block_offset, src, num, read_lens);
+  }
+};
+
 template <typename OutT,
           typename Functor,
           int Arity,

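The new ElementwiseWriteDataCallerBc above mirrors ElementwiseWriteDataCaller, but it loops to the runtime read_lens and forwards it to kps::WriteData. For multiple outputs it first transposes the per-element result tuples into one contiguous buffer per output; the sketch below is not from the commit and uses plain int arrays in place of ConditionalT and phi::Array to show that transpose in isolation.

// Standalone sketch, not from the commit: the register-level transpose done
// by ElementwiseWriteDataCallerBc before writing. Sizes are illustrative.
#include <cstdio>

constexpr int kNumOuts = 2;  // number of output tensors
constexpr int kVecSize = 4;  // per-thread register buffer capacity

int main() {
  // src[i][j]: element i of the vectorized tile, component for output j
  // (tuple layout, standing in for ConditionalT<OutT, NumOuts> src[VecSize]).
  int src[kVecSize][kNumOuts] = {{0, 10}, {1, 11}, {2, 12}, {3, 13}};
  int dst[kNumOuts][kVecSize] = {};

  int read_lens = 3;  // runtime element count actually produced (<= kVecSize)
  for (int i = 0; i < read_lens; ++i) {
    for (int j = 0; j < kNumOuts; ++j) {
      dst[j][i] = src[i][j];  // gather output j into one contiguous row
    }
  }
  // dst[0] = {0, 1, 2, _} and dst[1] = {10, 11, 12, _}: each row is now a
  // single contiguous chunk, matching one kps::WriteData call per output.
  for (int j = 0; j < kNumOuts; ++j) {
    for (int i = 0; i < read_lens; ++i) std::printf("%d ", dst[j][i]);
    std::printf("\n");
  }
  return 0;
}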
paddle/phi/kernels/primitive/compute_primitives.h

Lines changed: 14 additions & 0 deletions
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
   }
 }
 
+template <typename InT,
+          typename OutT,
+          int NX,
+          int NY,
+          int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseBinary(
+    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; ++idx) {
+    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
+  }
+}
+
 /**
  * @brief Ternary calculation according to OpFunc. Shape of input and output
  * are the same.

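The CUDA-side ElementwiseBinary overload added above accepts read_lens only to keep the call signature uniform with the XPU KP version; its trip count stays the compile-time NX * NY. Below is a host-side sketch of that shape with the device qualifiers and kps:: plumbing removed so it compiles as ordinary C++ (ElementwiseBinarySketch and AddFunctor are illustrative names, not Paddle definitions).

// Standalone sketch, not from the commit: the shape of the read_lens-style
// binary primitive on the CUDA path, as plain host C++.
#include <cstdio>

struct AddFunctor {  // stand-in for an elementwise binary functor
  float operator()(float a, float b) const { return a + b; }
};

// On the CUDA build the loop bound is still NX * NY; read_lens is accepted
// but unused, mirroring compute_primitives.h above.
template <int NX, int NY, class OpFunc>
void ElementwiseBinarySketch(float* out,
                             const float* in1,
                             const float* in2,
                             OpFunc compute,
                             int read_lens) {
  (void)read_lens;
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = compute(in1[idx], in2[idx]);
  }
}

int main() {
  float a[4] = {1, 2, 3, 4};
  float b[4] = {10, 20, 30, 40};
  float c[4] = {};
  ElementwiseBinarySketch<4, 1>(c, a, b, AddFunctor{}, /*read_lens=*/4);
  for (float v : c) std::printf("%g ", v);  // 11 22 33 44
  std::printf("\n");
  return 0;
}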
paddle/phi/kernels/primitive/compute_primitives_xpu2.h

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "xpu/kernel/cluster_header.h"
 #include "xpu/kernel/debug.h"
 #include "xpu/kernel/math.h"
+#include "xpu/kernel/simd_header.h"
 
 namespace phi {
 namespace kps {
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
   }
 }
 
+template <typename InT,
+          typename OutT,
+          int NX,
+          int NY,
+          int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseBinary(
+    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
+  for (int idx = 0; idx < read_lens; ++idx) {
+    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
+  }
+}
+
 /**
  * @brief Ternary calculation according to OpFunc. Shape of input and output
  * are the same.

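The XPU KP overload differs from the CUDA one only in its trip count: it iterates to the runtime read_lens (derived from BroadcastConfig's buf_len) instead of the compile-time NX * NY. Because the two variants are plain overloads distinguished by the trailing argument, call sites such as ElementwisePrimitiveCaller select the new path simply by passing read_lens through. A minimal sketch of that overload selection, not from the commit and with illustrative names:

// Standalone sketch, not from the commit: two overloads that differ only in
// the trailing read_lens argument, mirroring the ElementwiseBinary pair.
#include <cstdio>

// Legacy form: compile-time trip count only.
template <int NX>
void Binary(float* out, const float* a, const float* b) {
  std::printf("legacy overload, NX = %d\n", NX);
  for (int i = 0; i < NX; ++i) out[i] = a[i] + b[i];
}

// New form: runtime trip count (expected to be <= the buffer capacity).
template <int NX>
void Binary(float* out, const float* a, const float* b, int read_lens) {
  std::printf("read_lens overload, read_lens = %d\n", read_lens);
  for (int i = 0; i < read_lens; ++i) out[i] = a[i] + b[i];
}

int main() {
  float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float b[8] = {1, 1, 1, 1, 1, 1, 1, 1};
  float c[8] = {};
  Binary<8>(c, a, b);     // resolves to the legacy overload
  Binary<8>(c, a, b, 3);  // the extra argument selects the read_lens overload
  return 0;
}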