File tree 1 file changed +6
-8
lines changed 1 file changed +6
-8
lines changed Original file line number Diff line number Diff line change @@ -988,10 +988,9 @@ PD_REGISTER_KERNEL(batch_norm_grad,
988
988
double ,
989
989
phi::dtype::float16) {
990
990
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
991
- kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
992
- kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
993
- kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
994
- kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
991
+ kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
992
+ kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
993
+ kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
995
994
}
996
995
}
997
996
@@ -1003,10 +1002,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
1003
1002
double ,
1004
1003
phi::dtype::float16) {
1005
1004
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
1006
- kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
1007
- kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
1008
- kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
1009
- kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
1005
+ kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
1006
+ kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
1007
+ kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
1010
1008
}
1011
1009
}
1012
1010
You can’t perform that action at this time.
0 commit comments