
Commit 01eeba5

[AMP OP&Test] Support fp16/bf16 for cumsum (#51694)
* add fp16 unittest
* support bf16 and add unittest
* fix according to review
1 parent 9c238d2 · commit 01eeba5

3 files changed: +248 −221 lines changed

paddle/phi/kernels/gpu/cum_grad_kernel.cu (+3 −1)

@@ -29,6 +29,7 @@ namespace cub = hipcub;
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -82,5 +83,6 @@ PD_REGISTER_KERNEL(cumsum_grad,
                    int16_t,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #endif
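
For orientation, the math the grad kernel implements does not change in this patch; only its dtype list grows. The backward of an inclusive cumsum is itself a cumulative sum of the upstream gradient taken in reverse order. A minimal CPU sketch of that identity in plain C++ (CumsumGradRef is a hypothetical reference helper, not part of this commit; the real kernel also handles axis, reverse, and exclusive):

#include <cstdio>
#include <vector>

// Reference semantics for the 1-D case: if y_i = sum_{j <= i} x_j, then
// dL/dx_i = sum_{j >= i} dL/dy_j, i.e. a cumsum of the upstream gradient
// taken in reverse order.
std::vector<float> CumsumGradRef(const std::vector<float>& dy) {
  std::vector<float> dx(dy.size());
  float acc = 0.0f;
  for (size_t k = dy.size(); k-- > 0;) {
    acc += dy[k];  // running suffix sum
    dx[k] = acc;
  }
  return dx;
}

int main() {
  std::vector<float> dy = {1.0f, 1.0f, 1.0f, 1.0f};
  for (float v : CumsumGradRef(dy)) std::printf("%g ", v);  // prints: 4 3 2 1
  std::printf("\n");
  return 0;
}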

paddle/phi/kernels/gpu/cum_kernel.cu (+15 −2)

@@ -28,6 +28,7 @@ namespace cub = hipcub;
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -217,7 +218,8 @@ __global__ void BlockScanKernel(T* d_out,
 }
 
 template <typename Context, typename T>
-typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value>::type
+typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
+                        !std::is_same<T, phi::dtype::bfloat16>::value>::type
 ThrustCumsumKernel(const Context& dev_ctx,
                    const T* in_data,
                    T* out_data,
@@ -261,6 +263,15 @@ ThrustCumsumKernel(const Context& dev_ctx,
                    bool reverse,
                    bool exclusive) {}
 
+template <typename Context, typename T>
+typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
+ThrustCumsumKernel(const Context& dev_ctx,
+                   const phi::dtype::bfloat16* in_data,
+                   phi::dtype::bfloat16* out_data,
+                   int64_t size,
+                   bool reverse,
+                   bool exclusive) {}
+
 template <typename T, typename Context, typename Op>
 void ScanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
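
The enable_if edits above split ThrustCumsumKernel into mutually exclusive overloads: the real Thrust path for types it handles natively, plus empty stubs for phi::dtype::float16 and phi::dtype::bfloat16 so that every instantiation still compiles. A stripped-down sketch of the same dispatch pattern, with hypothetical names (half16, FastScan, Scan) and no phi or Thrust dependencies:

#include <cstdio>
#include <type_traits>

struct half16 {};  // stand-in for phi::dtype::float16 / phi::dtype::bfloat16

// Enabled for every T except half16: the "fast path" overload.
template <typename T>
typename std::enable_if<!std::is_same<T, half16>::value>::type
FastScan(const T* in, T* out, int n) {
  T acc = T(0);
  for (int i = 0; i < n; ++i) out[i] = acc += in[i];
}

// Enabled only for half16: an empty stub. It exists so the call in
// Scan() below still type-checks when T = half16, even though that
// branch is never taken at runtime.
template <typename T>
typename std::enable_if<std::is_same<T, half16>::value>::type
FastScan(const T*, T*, int) {}

template <typename T>
void Scan(const T* in, T* out, int n) {
  if (!std::is_same<T, half16>::value) {
    FastScan<T>(in, out, n);  // fast path; dead code when T = half16
  }
  // else: fall back to a generic scan (elided)
}

int main() {
  float in[3] = {1.0f, 2.0f, 3.0f}, out[3];
  Scan(in, out, 3);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 1 3 6
  return 0;
}

Before C++17's if constexpr, both arms of the runtime if are compiled for every T, so the stub overload is what keeps the float16/bfloat16 instantiations well-formed; the guard added to ScanKernel in the next hunk plays the same role in the actual kernel.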
@@ -301,6 +312,7 @@ void ScanKernel(const Context& dev_ctx,
   // Use thrust for parallel acceleration when the input size is equal to the
   // length of the ‘axis’ dimension.
   if (!std::is_same<T, phi::dtype::float16>::value &&
+      !std::is_same<T, phi::dtype::bfloat16>::value &&
       std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
     ThrustCumsumKernel<Context, T>(
         dev_ctx, in_data, out_data, size, reverse, exclusive);
@@ -440,7 +452,8 @@ PD_REGISTER_KERNEL(cumsum,
                    int16_t,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(logcumsumexp,
                    GPU,
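
One numerical note: low-precision scans usually stay accurate by accumulating in a wider type, and the amp_type_traits.h include added in both files suggests that pattern is in play here, though the diff does not show the accumulator itself. A toy sketch of why the wide accumulator matters, using a truncating stand-in for bfloat16 (the real phi::dtype::bfloat16 rounds to nearest; CumsumBf16, ToBf16, and ToFloat are all hypothetical names):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Toy bfloat16: the top 16 bits of an IEEE float (truncation only;
// phi's bfloat16 rounds to nearest-even).
struct bf16 { uint16_t bits; };

static bf16 ToBf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return bf16{static_cast<uint16_t>(u >> 16)};
}

static float ToFloat(bf16 h) {
  uint32_t u = static_cast<uint32_t>(h.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Inclusive scan that reads/writes bf16 but accumulates in float, so
// rounding error is applied once per output instead of compounding
// step by step across the whole sequence.
void CumsumBf16(const bf16* in, bf16* out, int n) {
  float acc = 0.0f;  // wide accumulator: the AMP "multi-precision type" idea
  for (int i = 0; i < n; ++i) {
    acc += ToFloat(in[i]);
    out[i] = ToBf16(acc);
  }
}

int main() {
  bf16 in[256], out[256];
  for (int i = 0; i < 256; ++i) in[i] = ToBf16(0.1f);
  CumsumBf16(in, out, 256);
  std::printf("cumsum[255] = %g\n", ToFloat(out[255]));
  return 0;
}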
