Skip to content

Commit c826dc8

Browse files
authored
fix performance (PaddlePaddle#72095)
1 parent dbc4098 commit c826dc8

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

paddle/phi/kernels/gpu/batch_norm_kernel.cu

+1-2
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,7 @@ void BatchNormKernel(const Context &ctx,
646646
#elif CUDNN_VERSION_MIN(7, 0, 1)
647647
// CUDNN_BATCHNORM_SPATIAL_PERSISTENT will cause precision issue in NCHW
648648
// format.
649-
if (data_layout == DataLayout::kNHWC &&
650-
FLAGS_cudnn_batchnorm_spatial_persistent) {
649+
if (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent) {
651650
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
652651
} else if (H == 1 && W == 1) {
653652
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;

0 commit comments

Comments
 (0)