@@ -28,6 +28,7 @@ namespace cub = hipcub;
28
28
29
29
#include " paddle/phi/backends/gpu/gpu_context.h"
30
30
#include " paddle/phi/common/amp_type_traits.h"
31
+ #include " paddle/phi/common/bfloat16.h"
31
32
#include " paddle/phi/common/float16.h"
32
33
#include " paddle/phi/core/hostdevice.h"
33
34
#include " paddle/phi/core/kernel_registry.h"
@@ -217,7 +218,8 @@ __global__ void BlockScanKernel(T* d_out,
217
218
}
218
219
219
220
template <typename Context, typename T>
220
- typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value>::type
221
+ typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
222
+ !std::is_same<T, phi::dtype::bfloat16>::value>::type
221
223
ThrustCumsumKernel (const Context& dev_ctx,
222
224
const T* in_data,
223
225
T* out_data,
@@ -261,6 +263,15 @@ ThrustCumsumKernel(const Context& dev_ctx,
261
263
bool reverse,
262
264
bool exclusive) {}
263
265
266
+ template <typename Context, typename T>
267
+ typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
268
+ ThrustCumsumKernel (const Context& dev_ctx,
269
+ const phi::dtype::bfloat16* in_data,
270
+ phi::dtype::bfloat16* out_data,
271
+ int64_t size,
272
+ bool reverse,
273
+ bool exclusive) {}
274
+
264
275
template <typename T, typename Context, typename Op>
265
276
void ScanKernel (const Context& dev_ctx,
266
277
const DenseTensor& x,
@@ -301,6 +312,7 @@ void ScanKernel(const Context& dev_ctx,
301
312
// Use thrust for parallel acceleration when the input size is equal to the
302
313
// length of the ‘axis’ dimension.
303
314
if (!std::is_same<T, phi::dtype::float16>::value &&
315
+ !std::is_same<T, phi::dtype::bfloat16>::value &&
304
316
std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
305
317
ThrustCumsumKernel<Context, T>(
306
318
dev_ctx, in_data, out_data, size, reverse, exclusive);
@@ -440,7 +452,8 @@ PD_REGISTER_KERNEL(cumsum,
440
452
int16_t ,
441
453
int ,
442
454
int64_t ,
443
- phi::dtype::float16) {}
455
+ phi::dtype::float16,
456
+ phi::dtype::bfloat16) {}
444
457
445
458
PD_REGISTER_KERNEL (logcumsumexp,
446
459
GPU,
0 commit comments