Skip to content

Commit bf03ff5

Browse files
committed
format code
1 parent 032fa6b commit bf03ff5

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -988,9 +988,9 @@ PD_REGISTER_KERNEL(batch_norm_grad,
988988
double,
989989
phi::dtype::float16) {
990990
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
991-
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); //x_grad
992-
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); //scale_grad
993-
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); //bias_grad
991+
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
992+
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
993+
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
994994
}
995995
}
996996

@@ -1002,9 +1002,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
10021002
double,
10031003
phi::dtype::float16) {
10041004
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
1005-
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); //x_grad
1006-
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); //scale_grad
1007-
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); //bias_grad
1005+
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
1006+
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
1007+
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
10081008
}
10091009
}
10101010

@@ -1017,7 +1017,6 @@ PD_REGISTER_KERNEL(batch_norm_grad_grad,
10171017
phi::BatchNormDoubleGradKernel,
10181018
float,
10191019
double) {}
1020-
10211020
#else
10221021
PD_REGISTER_KERNEL(batch_norm_grad_grad,
10231022
GPU,

0 commit comments

Comments
 (0)