File tree 1 file changed +6
-7
lines changed 1 file changed +6
-7
lines changed Original file line number Diff line number Diff line change @@ -988,9 +988,9 @@ PD_REGISTER_KERNEL(batch_norm_grad,
988
988
double ,
989
989
phi::dtype::float16) {
990
990
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
991
- kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
992
- kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
993
- kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
991
+ kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
992
+ kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
993
+ kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
994
994
}
995
995
}
996
996
@@ -1002,9 +1002,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
1002
1002
double ,
1003
1003
phi::dtype::float16) {
1004
1004
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
1005
- kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
1006
- kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
1007
- kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
1005
+ kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
1006
+ kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
1007
+ kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
1008
1008
}
1009
1009
}
1010
1010
@@ -1017,7 +1017,6 @@ PD_REGISTER_KERNEL(batch_norm_grad_grad,
1017
1017
phi::BatchNormDoubleGradKernel,
1018
1018
float ,
1019
1019
double ) {}
1020
-
1021
1020
#else
1022
1021
PD_REGISTER_KERNEL (batch_norm_grad_grad,
1023
1022
GPU,
You can’t perform that action at this time.
0 commit comments