@@ -231,29 +231,23 @@ void AdamWKernel(const Context& dev_ctx,
   VLOG(3) << "Skip update" << skip_update_ << ", With decay: " << with_decay;
 
   if (!skip_update_ && with_decay) {
-    if (master_param.is_initialized()) {
-      PADDLE_THROW(
-          phi::errors::Unimplemented("Master Param is not supported on MLU"));
-    } else {
-      // update param with decay coeff: mul(-1 * lr, coeff * param) + param
-      MLUCnnlTensorDesc lr_desc(learning_rate);
-      MLUCnnlTensorDesc param_desc(param);
-      MLUCnnlOpTensorDesc mul_op_desc(
-          CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
+    MLUCnnlTensorDesc lr_desc(learning_rate);
+    MLUCnnlTensorDesc param_desc(param);
+    MLUCnnlOpTensorDesc mul_op_desc(
+        CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
 
-      MLUCnnl::OpTensor(dev_ctx,
-                        mul_op_desc.get(),
-                        lr_desc.get(),
-                        GetBasePtr(&learning_rate),
-                        param_desc.get(),
-                        GetBasePtr(&param),
-                        param_desc.get(),
-                        const_cast<void*>(GetBasePtr(&param)),
-                        ToCnnlDataType<T>(),
-                        /*alpha1*/ -1.f,
-                        /*alpha2*/ coeff,
-                        /*beta*/ 1.f);
-    }
+    MLUCnnl::OpTensor(dev_ctx,
+                      mul_op_desc.get(),
+                      lr_desc.get(),
+                      GetBasePtr(&learning_rate),
+                      param_desc.get(),
+                      GetBasePtr(&param),
+                      param_desc.get(),
+                      const_cast<void*>(GetBasePtr(&param)),
+                      ToCnnlDataType<T>(),
+                      /*alpha1*/ -1.f,
+                      /*alpha2*/ coeff,
+                      /*beta*/ 1.f);
   }
 
   custom_kernel::AdamKernel<T, Context>(dev_ctx,