 #include "paddle/phi/kernels/funcs/for_range.h"

 namespace phi {

-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelREG(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
   }
 }

-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelMEM(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

+  const auto grad_type = grad.dtype();
+
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

   MPDType coeff_ = static_cast<MPDType>(coeff);
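
Note on the hunks above: the extra TG template parameter decouples the gradient's storage type from the parameter storage type T, while MT (MPTypeTrait<T>::Type, i.e. float for float16/bfloat16) remains the math/master precision. That is what lets a bfloat16 parameter be updated from a float32 gradient. The kernel bodies are untouched by this patch and elided from the hunks; the snippet below is only a minimal sketch of the per-element AdamW step with the three types (hypothetical helper name, lr_ratio and other details omitted), not the Paddle source.

// Sketch only: per-element AdamW step with decoupled storage types.
// T  = parameter storage type (e.g. bfloat16), TG = gradient storage type
// (e.g. float or bfloat16), MT = math/master type (float for 16-bit T).
template <typename T, typename TG, typename MT>
__device__ void AdamWUpdateSketch(MT beta1, MT beta2, MT epsilon, MT coeff,
                                  MT lr, MT beta1_pow, MT beta2_pow,
                                  const TG& g_in, T* p_io, MT* mom1, MT* mom2,
                                  MT* master_p) {
  MT g = static_cast<MT>(g_in);   // widen gradient to math precision
  MT p = master_p ? *master_p : static_cast<MT>(*p_io);
  p -= lr * coeff * p;            // decoupled weight decay
  MT m1 = beta1 * (*mom1) + (static_cast<MT>(1.0) - beta1) * g;
  MT m2 = beta2 * (*mom2) + (static_cast<MT>(1.0) - beta2) * g * g;
  MT denom = sqrt(m2) / sqrt(static_cast<MT>(1.0) - beta2_pow) + epsilon;
  p += (m1 / denom) * (-lr / (static_cast<MT>(1.0) - beta1_pow));
  *mom1 = m1;
  *mom2 = m2;
  *p_io = static_cast<T>(p);      // narrow back to storage precision
  if (master_p) *master_p = p;    // keep the FP32 master copy in sync
}
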
@@ -191,8 +193,10 @@ void AdamwDenseKernel(const Context& dev_ctx,
     phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
     phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
     phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
-    phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
-    phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    if (!use_global_beta_pow) {
+      phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
+      phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    }
     return;
   }

@@ -233,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,

   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        *beta1_pow.data<MPDType>(),
-        *beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+
+    else
+
+      AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          *beta1_pow.data<MPDType>(),
+          *beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Cpu update
       dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -260,28 +288,50 @@ void AdamwDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        beta1_pow.data<MPDType>(),
-        beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow.data<MPDType>(),
+          beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Update with gpu
-      UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+      UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
           beta1_,
           beta2_,
           beta1_pow.data<MPDType>(),
@@ -300,9 +350,21 @@ PD_REGISTER_KERNEL(adamw,
                    phi::AdamwDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+
+  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
+      kernel_key.dtype() == phi::DataType::BFLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+  }
+  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
 }
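
Two smaller changes ride along with the bfloat16 support. First, the UpdateAdamWBetaPow launch drops from <<<1, 32>>> to <<<1, 1>>>: the kernel only advances two scalar accumulators, so a single thread is enough and the other 31 threads were doing redundant writes. A minimal sketch of such a kernel (an illustration of the idea, not the Paddle source):

// Sketch only: advancing the two scalar beta-pow accumulators on the GPU.
// One thread suffices, hence the <<<1, 1>>> launch in the patch above.
template <typename MT>
__global__ void UpdateBetaPowSketch(MT beta1, MT beta2,
                                    const MT* beta1_pow, const MT* beta2_pow,
                                    MT* beta1_pow_out, MT* beta2_pow_out) {
  *beta1_pow_out = beta1 * beta1_pow[0];
  *beta2_pow_out = beta2 * beta2_pow[0];
}

Second, the registration hunk pins the optimizer-state outputs, OutputAt(1) through OutputAt(5), to FLOAT32 whenever the kernel key is FLOAT16 or BFLOAT16, and marks the beta-pow output backends as UNDEFINED, presumably so their placement follows the inputs (which may live on the CPU). The net effect is that moments, beta pows, and master weights stay in full precision while parameters and, optionally, gradients are stored in 16 bits.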