
Commit ba9a22d

Xreki and shaojiewang authored
Add bfloat16 support for several operators and apis. (#52696)
* Cherry-pick the register of bfloat16 for amp_kernel, pull request #45541.
* Cherry-pick the master_grad support of adamw, pull request #51141.
* add bf16 for some ops in static mode (#51582)
* Add bfloat16 support for some api in static mode.
* Fix codestyle.
* Revert the change of layer_function_generator.py.

Co-authored-by: Shaojie WANG <wsjmessi@163.com>
1 parent 95c3d61 commit ba9a22d

9 files changed, +781 −441 lines changed


paddle/phi/kernels/gpu/adamw_kernel.cu (+108 −46)
@@ -29,7 +29,7 @@
 #include "paddle/phi/kernels/funcs/for_range.h"
 
 namespace phi {
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelREG(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
   }
 }
 
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelMEM(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
 
+  const auto grad_type = grad.dtype();
+
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
 
   MPDType coeff_ = static_cast<MPDType>(coeff);
@@ -191,8 +193,10 @@ void AdamwDenseKernel(const Context& dev_ctx,
     phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
     phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
     phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
-    phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
-    phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    if (!use_global_beta_pow) {
+      phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
+      phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    }
     return;
   }
 
@@ -233,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,
 
   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        *beta1_pow.data<MPDType>(),
-        *beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+
+    else
+
+      AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          *beta1_pow.data<MPDType>(),
+          *beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Cpu update
       dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -260,28 +288,50 @@ void AdamwDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        beta1_pow.data<MPDType>(),
-        beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow.data<MPDType>(),
+          beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Update with gpu
-      UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+      UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
          beta1_,
          beta2_,
          beta1_pow.data<MPDType>(),
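
The new TG template parameter is the core of the master-grad change: the gradient tensor no longer has to share the parameter dtype, so float32 gradients can drive bfloat16/float16 parameters, and AdamwDenseKernel picks the instantiation at runtime from grad.dtype(). Below is a minimal sketch of the element-wise update these kernels perform with that extra template argument; the kernel name, parameter list, and exact arithmetic are illustrative (and assume T/TG/MT convert like the phi::dtype types do), not the actual Paddle kernel body.

#include <cstdint>

// T  = parameter dtype (e.g. bfloat16), TG = gradient dtype (T or float),
// MT = fp32 "master" compute type.
template <typename T, typename TG, typename MT>
__global__ void AdamWUpdateSketch(MT beta1, MT beta2, MT epsilon,
                                  MT coeff, MT lr_ratio,
                                  MT beta1_pow, MT beta2_pow,
                                  const MT* moment1, MT* moment1_out,
                                  const MT* moment2, MT* moment2_out,
                                  const MT* lr_, const TG* grad,
                                  const T* param, T* param_out,
                                  const MT* master_param, MT* master_param_out,
                                  int64_t num) {
  MT lr = *lr_ * lr_ratio;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += gridDim.x * blockDim.x) {
    // Gradients may arrive in float32 even when the params are bf16/fp16.
    MT g = static_cast<MT>(grad[i]);
    // Read the fp32 master copy when it exists, otherwise the low-precision param.
    MT p = master_param ? master_param[i] : static_cast<MT>(param[i]);
    p -= lr * coeff * p;  // decoupled weight decay
    MT m1 = beta1 * moment1[i] + (static_cast<MT>(1.0) - beta1) * g;
    MT m2 = beta2 * moment2[i] + (static_cast<MT>(1.0) - beta2) * g * g;
    MT denom = sqrt(m2) / sqrt(static_cast<MT>(1.0) - beta2_pow) + epsilon;
    p -= lr * (m1 / (static_cast<MT>(1.0) - beta1_pow)) / denom;
    moment1_out[i] = m1;
    moment2_out[i] = m2;
    param_out[i] = static_cast<T>(p);                // cast back to bf16/fp16
    if (master_param_out) master_param_out[i] = p;   // keep the fp32 master copy
  }
}

Moments and beta-pow values stay in MT throughout, which is why only the parameter itself is ever narrowed back to T.
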
@@ -300,9 +350,21 @@ PD_REGISTER_KERNEL(adamw,
                    phi::AdamwDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+
+  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
+      kernel_key.dtype() == phi::DataType::BFLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+  }
+  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
 }
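
The registration change mirrors MPTypeTrait, which maps reduced-precision parameter types to float: that is why, for FLOAT16/BFLOAT16 kernel keys, the moment, beta-pow, and master-param outputs are pinned to FLOAT32. A hedged sketch of that trait idea, using CUDA's __half and __nv_bfloat16 as stand-ins for phi::dtype::float16/bfloat16 (illustrative only, not phi's actual definition):

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

template <typename T>
struct MPTypeTraitSketch {
  using Type = T;        // float/double keep their own type as the compute type
};
template <>
struct MPTypeTraitSketch<__half> {
  using Type = float;    // fp16 params -> fp32 moments and master weights
};
template <>
struct MPTypeTraitSketch<__nv_bfloat16> {
  using Type = float;    // bf16 params -> fp32 moments and master weights
};

static_assert(std::is_same<MPTypeTraitSketch<float>::Type, float>::value, "");
static_assert(std::is_same<MPTypeTraitSketch<__nv_bfloat16>::Type, float>::value, "");
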

paddle/phi/kernels/gpu/amp_kernel.cu (+4 −2)
@@ -357,14 +357,16 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
                    phi::CheckFiniteAndUnscaleKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(update_loss_scaling,
                    GPU,
                    ALL_LAYOUT,
                    phi::UpdateLossScalingKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
 }
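
These two registrations only add phi::dtype::bfloat16 to the dtype lists; the kernels themselves are already templated over the element type. As a rough illustration of the unscale-and-check pattern that check_finite_and_unscale applies per gradient (a schematic sketch using CUDA's native __nv_bfloat16 in place of phi::dtype::bfloat16, not the actual CheckFiniteAndUnscaleKernel body):

#include <cstdint>
#include <cuda_bf16.h>

__global__ void UnscaleAndCheckSketch(const __nv_bfloat16* grad_in,
                                      const float* inverse_scale,
                                      __nv_bfloat16* grad_out,
                                      bool* found_inf,
                                      int64_t num) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += gridDim.x * blockDim.x) {
    // Promote to float, apply 1/loss_scale, then write back in bf16.
    float g = __bfloat162float(grad_in[i]) * (*inverse_scale);
    if (!isfinite(g)) {
      *found_inf = true;  // any inf/nan marks the whole tensor (benign race: all writes store true)
    }
    grad_out[i] = __float2bfloat16(g);
  }
}
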

paddle/phi/kernels/gpu/matmul_grad_kernel.cu (+1)
@@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(matmul_with_flatten_grad,
                    phi::MatmulWithFlattenGradKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
 
 PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,

paddle/phi/kernels/gpu/matmul_kernel.cu (+1)
@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten,
                    phi::MatmulWithFlattenKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
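
As with the grad kernel, the forward matmul_with_flatten registration simply gains bfloat16. A brief aside on why bfloat16 slots in so easily next to float and float16: bf16 keeps float32's 8-bit exponent and truncates the mantissa to 7 bits, so conversion from float amounts to keeping the top 16 bits. A small standalone host-side sketch (illustrative, not Paddle's phi::dtype::bfloat16 implementation):

#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t FloatToBf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // truncate (no rounding) for brevity
}

static float Bf16ToFloat(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float x = 3.1415926f;
  // Round trip loses mantissa bits but keeps the float32 range: prints ~3.140625.
  std::printf("%f -> bf16 -> %f\n", x, Bf16ToFloat(FloatToBf16(x)));
  return 0;
}
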
