Skip to content

Commit 032fa6b

Browse files
committed
fix bug of batch_norm_grad kernel with fp16
1 parent b621a4f commit 032fa6b

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -988,10 +988,9 @@ PD_REGISTER_KERNEL(batch_norm_grad,
988988
double,
989989
phi::dtype::float16) {
990990
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
991-
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
992-
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
993-
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
994-
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
991+
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); //x_grad
992+
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); //scale_grad
993+
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); //bias_grad
995994
}
996995
}
997996

@@ -1003,10 +1002,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
10031002
double,
10041003
phi::dtype::float16) {
10051004
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
1006-
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
1007-
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
1008-
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
1009-
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
1005+
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); //x_grad
1006+
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); //scale_grad
1007+
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); //bias_grad
10101008
}
10111009
}
10121010

0 commit comments

Comments
 (0)