From 0c2607e669280374780c4cfdb742d9a8618ef60f Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Fri, 9 Dec 2022 14:11:49 +0800
Subject: [PATCH 1/3] cherry-pick #48563 and resolve conflict

---
 .../phi/kernels/gpu/batch_norm_grad_kernel.cu | 24 +++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
index 5acccdfcea389..48b0cf50aac1c 100644
--- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
@@ -933,6 +933,21 @@ void BatchNormGradRawKernel(const Context &ctx,
               flag_ptr);
         }
         // 2. reduce_sum(x, dy, mean) => dscale, dbias
+        BatchNormParamType<T> *dscale = nullptr;
+        BatchNormParamType<T> *dbias = nullptr;
+        bool with_scale = false;
+        if (d_scale && d_bias) {
+          dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+          dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
+        } else {
+          DenseTensor dscale_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          DenseTensor dbias_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          dscale = dscale_mem.data<BatchNormParamType<T>>();
+          dbias = dbias_mem.data<BatchNormParamType<T>>();
+        }
+
         BNBackward2DChannelLastStage2<T, block_size>
             <<<grid, block, 0, ctx.stream()>>>(
                 transformed_d_y.template data<T>(),
@@ -944,8 +959,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 H * W * D,
                 epsilon,
                 block_data_ptr,
-                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-                ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+                dscale,
+                dbias,
                 flag_ptr);

         // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
@@ -954,8 +969,8 @@
                 transformed_d_y.template data<T>(),
                 transformed_x.template data<T>(),
                 scale.template data<BatchNormParamType<T>>(),
-                d_scale->data<BatchNormParamType<T>>(),
-                d_bias->data<BatchNormParamType<T>>(),
+                dscale,
+                dbias,
                 mean_ptr,
                 variance_ptr,
                 C,
@@ -1165,6 +1180,7 @@ void BatchNormGradRawKernel(const Context &ctx,
         paddle::platform::dynload::cudnnDestroyTensorDescriptor(
             bn_param_desc_));
 #endif
+
     } else {
       const auto *running_mean = mean.get_ptr();
       const auto *running_var = variance.get_ptr();

From 34f3cbcafb503ddaeb76107f6c08d984fb5c6104 Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Wed, 8 Feb 2023 15:10:28 +0800
Subject: [PATCH 2/3] Fix bn performance degradation (#50287)

* fix bn performance degradation
---
 paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
index 48b0cf50aac1c..658b42e616521 100644
--- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx,
     }
     // CUDNN only support small batch size
     bool use_native_nhwc =
-        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC &&
+               H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
             : false;
     const bool use_native_kernel =
         ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
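Note on [PATCH 2/3] above: the native NHWC backward path is now taken only when the per-sample spatial size H * W reaches CUDNN_SPATIAL_THRESHOLD_EVAL, so smaller feature maps stay on the cuDNN path, which is presumably what the performance fix targets. The snippet below is a minimal, standalone sketch of that threshold-gated selection; the enum, function name, and threshold value are illustrative stand-ins, not the actual Paddle code.

    #include <cstdio>

    // Illustrative threshold; the real CUDNN_SPATIAL_THRESHOLD_EVAL value is
    // defined in the Paddle sources and may differ.
    constexpr int kSpatialThresholdEval = 880;

    enum class BNBackwardPath { kCudnn, kNativeNHWC };

    // Core of the condition in the patch (the d_x check is omitted here):
    // take the hand-written NHWC kernel only for 4-D NHWC inputs whose
    // spatial size is large enough; everything else stays on cuDNN.
    BNBackwardPath ChooseBackwardPath(int rank, bool is_nhwc, int h, int w) {
      if (rank == 4 && is_nhwc && h * w >= kSpatialThresholdEval) {
        return BNBackwardPath::kNativeNHWC;
      }
      return BNBackwardPath::kCudnn;
    }

    int main() {
      auto name = [](BNBackwardPath p) {
        return p == BNBackwardPath::kCudnn ? "cudnn" : "native nhwc";
      };
      // Small feature map stays on cuDNN; large one goes native.
      std::printf("7x7   -> %s\n", name(ChooseBackwardPath(4, true, 7, 7)));
      std::printf("56x56 -> %s\n", name(ChooseBackwardPath(4, true, 56, 56)));
      return 0;
    }
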
From 83f1ba7a5ca5337be10c67902abad400cd6eb928 Mon Sep 17 00:00:00 2001
From: umiswing
Date: Fri, 10 Feb 2023 19:17:01 +0800
Subject: [PATCH 3/3] remove if constexpr(), which is not supported on gcc54
 (#50395)

---
 paddle/phi/kernels/sparse/gpu/conv_kernel.cu  | 66 ++++-----------------
 .../kernels/sparse/gpu/gather_gemm_scatter.h  | 61 +++++++++++++++++
 2 files changed, 73 insertions(+), 54 deletions(-)

diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
index e6f3ca3364918..f575b903895f9 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
         const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
         const IntT* scatter_indices =
             rulebook_ptr + rulebook_len + h_offsets_ptr[i];
-
-        if constexpr (std::is_same<T, phi::dtype::float16>::value &&
-                      std::is_same<IntT, int32_t>::value) {
-          fp16_gather_gemm_scatter gather_gemm_scatter =
-              getBestFp16Kernel(M, N, K);
-          gather_gemm_scatter(
-              dev_ctx,
-              reinterpret_cast<const cutlass::half_t*>(
-                  x.non_zero_elements().data<T>()),
-              reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
-              reinterpret_cast<cutlass::half_t*>(out_values_ptr),
-              reinterpret_cast<cutlass::half_t*>(out_values_ptr),
-              M,
-              N,
-              K,
-              static_cast<const int32_t*>(gather_indices),
-              static_cast<const int32_t*>(scatter_indices),
-              static_cast<cutlass::half_t>(1),
-              static_cast<cutlass::half_t>(1));
-        }
-        if constexpr (std::is_same<T, float>::value &&
-                      std::is_same<IntT, int32_t>::value) {
-          fp32_gather_gemm_scatter gather_gemm_scatter =
-              getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
-          gather_gemm_scatter(dev_ctx,
-                              x.non_zero_elements().data<T>(),
-                              tmp_kernel_ptr,
-                              out_values_ptr,
-                              out_values_ptr,
-                              M,
-                              N,
-                              K,
-                              gather_indices,
-                              scatter_indices,
-                              static_cast<T>(1),
-                              static_cast<T>(1));
-        }
-        if constexpr (std::is_same<T, double>::value &&
-                      std::is_same<IntT, int32_t>::value) {
-          fp64_gather_gemm_scatter gather_gemm_scatter =
-              getBestFp64Kernel(M, N, K);
-          gather_gemm_scatter(dev_ctx,
-                              x.non_zero_elements().data<T>(),
-                              tmp_kernel_ptr,
-                              out_values_ptr,
-                              out_values_ptr,
-                              M,
-                              N,
-                              K,
-                              gather_indices,
-                              scatter_indices,
-                              static_cast<T>(1),
-                              static_cast<T>(1));
-        }
+        dispatchKernel(dev_ctx,
+                       x.non_zero_elements().data<T>(),
+                       tmp_kernel_ptr,
+                       out_values_ptr,
+                       out_values_ptr,
+                       M,
+                       N,
+                       K,
+                       gather_indices,
+                       scatter_indices,
+                       cutlass,
+                       x.dtype());
       }
     } else {
 #endif
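Context for the change above: `if constexpr` is a C++17 construct, so the compile-time type dispatch that was removed does not build with gcc 5.4. The replacement routes everything through `dispatchKernel` (added in the header below), which branches on a runtime phi::DataType tag and passes the operands as `void*`. A minimal sketch of that runtime-dispatch pattern follows; the enum and the two gemm stand-ins are hypothetical, not Paddle APIs.

    #include <cstdio>

    // Hypothetical stand-ins for illustration only; not Paddle APIs.
    enum class DataType { FLOAT32, FLOAT64 };

    void GemmFp32(const float* a, int n) { std::printf("fp32 gemm, n=%d\n", n); }
    void GemmFp64(const double* a, int n) { std::printf("fp64 gemm, n=%d\n", n); }

    // Runtime dispatch on a type tag. Every branch is compiled for every
    // caller, so the operand is passed as void* and cast inside the branch
    // that knows its real type. This builds on compilers without C++17
    // `if constexpr`, such as gcc 5.4.
    void Dispatch(const void* a, int n, DataType type) {
      if (type == DataType::FLOAT32) {
        GemmFp32(static_cast<const float*>(a), n);
      } else if (type == DataType::FLOAT64) {
        GemmFp64(static_cast<const double*>(a), n);
      }
    }

    int main() {
      float xs[4] = {1.f, 2.f, 3.f, 4.f};
      double yd[4] = {1.0, 2.0, 3.0, 4.0};
      Dispatch(xs, 4, DataType::FLOAT32);  // prints "fp32 gemm, n=4"
      Dispatch(yd, 4, DataType::FLOAT64);  // prints "fp64 gemm, n=4"
      return 0;
    }

The trade-off: with `if constexpr`, non-matching branches are discarded before instantiation and may therefore use type-specific calls, while a plain runtime `if` requires every branch to compile for every element type, which is why the interface erases the type behind `void*` and casts inside each branch.
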
diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
index b596ff545383f..dab35ed47737a 100644
--- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
@@ -23,6 +23,7 @@
 #include "cutlass/util/device_memory.h"
 #include "examples/common/helper.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 namespace phi {
 namespace sparse {
 typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
   CUTLASS_CHECK(status);
   gemm_op(dev_ctx.stream());
 }
+static void dispatchKernel(const GPUContext& dev_ctx,
+                           const void* const a,
+                           const void* const b,
+                           const void* const c,
+                           void* const d,
+                           const int m,
+                           const int n,
+                           const int k,
+                           const void* a_indices,
+                           const void* c_d_indices,
+                           const bool cutlass,
+                           const phi::DataType type) {
+  if (!cutlass) return;
+
+  if (type == phi::DataType::FLOAT16) {
+    fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k);
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const cutlass::half_t*>(a),
+                        static_cast<const cutlass::half_t*>(b),
+                        static_cast<const cutlass::half_t*>(c),
+                        static_cast<cutlass::half_t*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<cutlass::half_t>(1),
+                        static_cast<cutlass::half_t>(1));
+  } else if (type == phi::DataType::FLOAT32) {
+    fp32_gather_gemm_scatter gather_gemm_scatter =
+        getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const float*>(a),
+                        static_cast<const float*>(b),
+                        static_cast<const float*>(c),
+                        static_cast<float*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<float>(1),
+                        static_cast<float>(1));
+  } else if (type == phi::DataType::FLOAT64) {
+    fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k);
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const double*>(a),
+                        static_cast<const double*>(b),
+                        static_cast<const double*>(c),
+                        static_cast<double*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<double>(1),
+                        static_cast<double>(1));
+  }
+}
+
 struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
   using Gemm = cutlass::gemm::device::GemmUniversal<
       cutlass::half_t,