diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc
index 6ba106e8164df3..d17ea34847dce8 100644
--- a/paddle/common/flags.cc
+++ b/paddle/common/flags.cc
@@ -1879,6 +1879,20 @@ PHI_DEFINE_EXPORTED_bool(
     false,
     "Enable xqa optim in block_multihead_attention kernel (GQA).");
 
+/**
+ * Whether to use FP32 for accumulation of QK output in
+ * block_multihead_attention kernel(fp16)
+ * Name: blha_use_fp32_qk_sum Since Version: 3.0.0
+ * Value Range: bool, default=false
+ * Example:
+ * Note: If TRUE, FP32 will be used for accumulation of the QK output
+ *       in block_multihead_attention kernel(fp16).
+ */
+PHI_DEFINE_EXPORTED_bool(blha_use_fp32_qk_sum,
+                         false,
+                         "Use FP32 for accumulation of QK output in "
+                         "block_multihead_attention kernel(fp16).");
+
 PHI_DEFINE_EXPORTED_bool(cuda_core_int8_gemm,
                          false,
                          "Enable speed up int8 gemm calculations when m<=4");
diff --git a/paddle/phi/kernels/fusion/gpu/block_attn.h b/paddle/phi/kernels/fusion/gpu/block_attn.h
index a3307eaaee0b0e..7a1a2b26269e71 100644
--- a/paddle/phi/kernels/fusion/gpu/block_attn.h
+++ b/paddle/phi/kernels/fusion/gpu/block_attn.h
@@ -22,6 +22,7 @@
 #include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h"
 
 COMMON_DECLARE_bool(use_xqa_optim);
+COMMON_DECLARE_bool(blha_use_fp32_qk_sum);
 
 #ifdef PADDLE_WITH_HIP
 #define GPU(str) hip##str
@@ -98,6 +99,7 @@ struct Block_AttN_params {
 };
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int THREADS_PER_KEY,
@@ ... @@
   using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+  using Qk_sum_type = typename Qk_vec_<SUM_T, Dh_MAX>::Type;
   using Qk_vec_RoPE = typename Qk_vec_RoPE_<T, float, Dh_MAX>::Type;
   using QK_Packed_Int8_t = typename Packed_Int8_<Qk_vec, CACHE_TYPE>::Type;
@@ -322,7 +325,7 @@
       }
     }
 
-    qk = dot<Qk_vec, Qk_vec>(q, k);
+    qk = dot<Qk_sum_type, Qk_vec>(q, k);
 
     if (QK_VECS_PER_WARP <= WARP_SIZE) {
 #pragma unroll
@@ -1222,6 +1225,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 
 #ifdef PADDLE_WITH_HIP
 #define BLHAG_LAUNCH_KERNEL(T,                                            \
+                            SUM_T,                                        \
                             Dh,                                           \
                             Dh_MAX,                                       \
                             THDS_PER_KEY,                                 \
@@ -1235,6 +1239,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
     size_t smem_sz =                                                      \
         smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK);   \
     constexpr auto kernel_fn = block_attention_kernel<T,                  \
+                                                      SUM_T,              \
                                                       Dh,                 \
                                                       Dh_MAX,             \
                                                       THDS_PER_KEY,       \
@@ ... @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
         params, load_func, store_func);
 #else
 #define BLHAG_LAUNCH_KERNEL(T,                                            \
+                            SUM_T,                                        \
                             Dh,                                           \
                             Dh_MAX,                                       \
                             THDS_PER_KEY,                                 \
@@ -1308,6 +1314,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
     size_t smem_sz =                                                      \
         smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK);   \
     constexpr auto kernel_fn = block_attention_kernel<T,                  \
+                                                      SUM_T,              \
                                                       Dh,                 \
                                                       Dh_MAX,             \
                                                       THDS_PER_KEY,       \
@@ ... @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 #endif
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int THREADS_PER_KEY,
@@ ... @@ void dispatch_blha_impl_kernel(const Block_AttN_params<T> &params,
                                StoreFunc store_func) {
   VLOG(1) << "group wise";
   BLHAG_LAUNCH_KERNEL(T,
+                      SUM_T,
                       Dh,
                       Dh_MAX,
                       THREADS_PER_KEY,
@@ -1409,15 +1418,30 @@ void dispatch_blha_gqa_kernel(const Block_AttN_params<T> &params,
                               LoadFunc load_func,
                               StoreFunc store_func) {
   if (params.gqa_num_per_partitions == 1 || !FLAGS_use_xqa_optim) {
-    dispatch_blha_impl_kernel<T,
-                              Dh,
-                              Dh_MAX,
-                              THREADS_PER_KEY,
-                              THREADS_PER_VALUE,
-                              THREADS_PER_BLOCK,
-                              BlockSize,
-                              CACHE_KV_INT8>(
-        params, stream, load_func, store_func);
+    auto dispatch_blha_kernel = [&](auto kernel_type, auto qk_sum_type) {
+      using Kernel_T = decltype(kernel_type);
+      using SUM_T = decltype(qk_sum_type);
+      dispatch_blha_impl_kernel<Kernel_T,
+                                SUM_T,
+                                Dh,
+                                Dh_MAX,
+                                THREADS_PER_KEY,
+                                THREADS_PER_VALUE,
+                                THREADS_PER_BLOCK,
+                                BlockSize,
+                                CACHE_KV_INT8>(
+          params, stream, load_func, store_func);
+    };
+    if (FLAGS_blha_use_fp32_qk_sum) {
+      if constexpr (std::is_same_v<T, float16>) {
+        dispatch_blha_kernel(float16{}, float{});
+      } else {
+        dispatch_blha_kernel(T{}, T{});
+      }
+    } else {
+      dispatch_blha_kernel(T{}, T{});
+    }
+
   } else if (params.gqa_num_per_partitions == 2) {
     constexpr int THDS_PER_BLOCK = 1024;
     BLHA_LAUNCH_GQA_KERNEL(T,
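For readers unfamiliar with the tag-based dispatch used in dispatch_blha_gqa_kernel above, here is a minimal stand-alone sketch of the same pattern (launch_block_attention and the local float16 struct are placeholders, not the Paddle sources): the generic lambda receives default-constructed tag values, recovers their types with decltype, and the runtime flag check only chooses which tags to pass, so the fp16-data / fp32-accumulation combination is instantiated without duplicating the launch code.

```cpp
#include <iostream>
#include <type_traits>
#include <typeinfo>

struct float16 {};  // stand-in for the fp16 data type used by the kernels

// Stand-in for dispatch_blha_impl_kernel: T is the element type of Q/K/V,
// SUM_T is the type used to accumulate the QK dot product.
template <typename T, typename SUM_T>
void launch_block_attention() {
  std::cout << "data=" << typeid(T).name()
            << " qk_sum=" << typeid(SUM_T).name() << '\n';
}

template <typename T>
void dispatch(bool use_fp32_qk_sum) {
  // Generic lambda: the arguments are only type tags, decltype recovers them.
  auto dispatch_blha_kernel = [&](auto kernel_type, auto qk_sum_type) {
    using Kernel_T = decltype(kernel_type);
    using SUM_T = decltype(qk_sum_type);
    launch_block_attention<Kernel_T, SUM_T>();
  };
  if (use_fp32_qk_sum) {
    if constexpr (std::is_same_v<T, float16>) {
      dispatch_blha_kernel(float16{}, float{});  // fp16 data, fp32 QK sum
    } else {
      dispatch_blha_kernel(T{}, T{});
    }
  } else {
    dispatch_blha_kernel(T{}, T{});  // accumulate in the data type itself
  }
}

int main() {
  dispatch<float16>(/*use_fp32_qk_sum=*/true);
  dispatch<float16>(/*use_fp32_qk_sum=*/false);
}
```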
diff --git a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h
index fa8c60f809883b..94aa383356be6e 100644
--- a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h
+++ b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h
@@ -2005,6 +2005,53 @@ inline __device__ float4 mul(float4 a, float b) {
   return res;
 }
 
+template <>
+inline __device__ float mul(uint32_t a, uint32_t b) {
+  float c;
+  float2 fa = half2_to_float2(mul(a, b));
+  c = fa.x + fa.y;
+  return c;
+}
+
+template <>
+inline __device__ float2 mul(uint32_t a, uint32_t b) {
+  float2 c;
+  c = half2_to_float2(mul(a, b));
+  return c;
+}
+
+template <>
+inline __device__ float4 mul(uint2 a, uint2 b) {
+  float4 c;
+  uint32_t ua = mul(a.x, b.x);
+  uint32_t ub = mul(a.y, b.y);
+  float2 fa = half2_to_float2(ua);
+  float2 fb = half2_to_float2(ub);
+  c.x = fa.x;
+  c.y = fa.y;
+  c.z = fb.x;
+  c.w = fb.y;
+  return c;
+}
+
+template <>
+inline __device__ float4 mul(uint4 a, uint4 b) {
+  float4 c;
+  uint32_t ua = mul(a.x, b.x);
+  uint32_t ub = mul(a.y, b.y);
+  uint32_t uc = mul(a.z, b.z);
+  uint32_t ud = mul(a.w, b.w);
+  float2 fa = half2_to_float2(ua);
+  float2 fb = half2_to_float2(ub);
+  float2 fc = half2_to_float2(uc);
+  float2 fd = half2_to_float2(ud);
+  c.x = fa.x + fa.y;
+  c.y = fb.x + fb.y;
+  c.z = fc.x + fc.y;
+  c.w = fd.x + fd.y;
+  return c;
+}
+
 #ifdef ENABLE_BF16
 template <>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
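The mul specializations added above all widen the packed fp16 products to fp32 before any addition takes place; that is the effect of FLAGS_blha_use_fp32_qk_sum on the QK reduction. A minimal CUDA sketch of the same idea, written against the standard cuda_fp16.h intrinsics rather than the helpers in mmha_util.cu.h (fp32_qk_sum is an illustrative name, not part of Paddle):

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// Multiply two packed half2 values (stored as uint32_t, as in the kernels
// above) and accumulate the two lanes in fp32. The per-lane product is still
// fp16, but the reduction happens in fp32, so low-order bits are not lost
// when many partial products are summed into the QK score.
__device__ float fp32_qk_sum(uint32_t a, uint32_t b) {
  __half2 ha = *reinterpret_cast<__half2 *>(&a);
  __half2 hb = *reinterpret_cast<__half2 *>(&b);
  __half2 prod = __hmul2(ha, hb);    // element-wise fp16 multiply
  float2 f = __half22float2(prod);   // widen to fp32 before adding
  return f.x + f.y;                  // fp32 accumulation of the two lanes
}
```

Since the flag is declared with PHI_DEFINE_EXPORTED_bool, it should be controllable in the same way as other exported Paddle flags (for example through the FLAGS_blha_use_fp32_qk_sum environment variable); with the default of false the existing fp16 accumulation path is unchanged.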