
[inference] Fix FP16 precision BLHA accumulation overflow. #71919


Merged: 2 commits, Mar 27, 2025
14 changes: 14 additions & 0 deletions paddle/common/flags.cc
@@ -1879,6 +1879,20 @@ PHI_DEFINE_EXPORTED_bool(
false,
"Enable xqa optim in block_multihead_attention kernel (GQA).");

/**
 * Whether to use FP32 for accumulation of the QK output in the
 * block_multihead_attention kernel (FP16).
 * Name: blha_use_fp32_qk_sum
 * Since Version: 3.0.0
 * Value Range: bool, default=false
 * Example:
 * Note: If TRUE, FP32 is used to accumulate the QK output in the
 * block_multihead_attention kernel (FP16).
 */
PHI_DEFINE_EXPORTED_bool(blha_use_fp32_qk_sum,
                         false,
                         "Use FP32 for accumulation of the QK output in the "
                         "block_multihead_attention kernel (FP16).");

PHI_DEFINE_EXPORTED_bool(cuda_core_int8_gemm,
false,
"Enable speed up int8 gemm calculations when m<=4");
44 changes: 34 additions & 10 deletions paddle/phi/kernels/fusion/gpu/block_attn.h
@@ -22,6 +22,7 @@
#include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h"

COMMON_DECLARE_bool(use_xqa_optim);
COMMON_DECLARE_bool(blha_use_fp32_qk_sum);

#ifdef PADDLE_WITH_HIP
#define GPU(str) hip##str
@@ -98,6 +99,7 @@ struct Block_AttN_params {
};

template <typename T,
typename SUM_T,
int Dh,
int Dh_MAX,
int THREADS_PER_KEY,
@@ -146,6 +148,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(

__shared__ float red_smem[WARPS_PER_BLOCK * 2];
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
using Qk_sum_type = typename Qk_vec_<SUM_T, Dh_MAX>::Type;
using Qk_vec_RoPE = typename Qk_vec_RoPE_<T, float, Dh_MAX>::Type;
using QK_Packed_Int8_t = typename Packed_Int8_<Qk_vec, CACHE_TYPE>::Type;

@@ -322,7 +325,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(
}
}

qk = dot<Qk_vec, Qk_vec>(q, k);
qk = dot<Qk_sum_type, Qk_vec>(q, k);

if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
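
The one-line change above is where the flag takes effect inside the kernel: the first template argument of dot selects the accumulator vector type, so with Qk_sum_type resolved to a float-based vector the FP16 products are converted to FP32 before the reduction, while q and k themselves stay in FP16. A simplified, hypothetical sketch of that idiom (not Paddle's actual dot implementation):

// Sketch of the dot<AccVec, Vec> idiom: the caller picks the accumulator type,
// the input type stays unchanged.
#include <cstddef>

template <typename Acc, typename Vec, std::size_t N>
inline Acc dot_sketch(const Vec (&q)[N], const Vec (&k)[N]) {
  Acc acc = Acc(0);
  for (std::size_t i = 0; i < N; ++i) {
    // Convert to Acc before the multiply-add, so the running sum never has to
    // fit into the narrower input type.
    acc += static_cast<Acc>(q[i]) * static_cast<Acc>(k[i]);
  }
  return acc;
}

// dot_sketch<float>(q_fp16, k_fp16) -> FP32 accumulation (Qk_sum_type path)
// dot_sketch<half>(q_fp16, k_fp16)  -> FP16 accumulation (original behaviour)
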
@@ -1222,6 +1225,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,

#ifdef PADDLE_WITH_HIP
#define BLHAG_LAUNCH_KERNEL(T, \
SUM_T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
@@ -1235,6 +1239,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
constexpr auto kernel_fn = block_attention_kernel<T, \
SUM_T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
@@ -1295,6 +1300,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
params, load_func, store_func);
#else
#define BLHAG_LAUNCH_KERNEL(T, \
SUM_T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
@@ -1308,6 +1314,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
constexpr auto kernel_fn = block_attention_kernel<T, \
SUM_T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
@@ -1367,6 +1374,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
#endif

template <typename T,
typename SUM_T,
int Dh,
int Dh_MAX,
int BlockSize,
@@ -1382,6 +1390,7 @@ void dispatch_blha_impl_kernel(const Block_AttN_params<T> &params,
StoreFunc store_func) {
VLOG(1) << "group wise";
BLHAG_LAUNCH_KERNEL(T,
SUM_T,
Dh,
Dh_MAX,
THREADS_PER_KEY,
@@ -1409,15 +1418,30 @@ void dispatch_blha_gqa_kernel(const Block_AttN_params<T> &params,
LoadFunc load_func,
StoreFunc store_func) {
if (params.gqa_num_per_partitions == 1 || !FLAGS_use_xqa_optim) {
dispatch_blha_impl_kernel<T,
Dh,
Dh_MAX,
BlockSize,
THREADS_PER_VALUE,
THREADS_PER_KEY,
THREADS_PER_BLOCK,
CACHE_TYPE>(
params, stream, load_func, store_func);
auto dispatch_blha_kernel = [&](auto kernel_type, auto qk_sum_type) {
using Kernel_T = decltype(kernel_type);
using SUM_T = decltype(qk_sum_type);
dispatch_blha_impl_kernel<Kernel_T,
SUM_T,
Dh,
Dh_MAX,
BlockSize,
THREADS_PER_VALUE,
THREADS_PER_KEY,
THREADS_PER_BLOCK,
CACHE_TYPE>(
params, stream, load_func, store_func);
};
if (FLAGS_blha_use_fp32_qk_sum) {
if constexpr (std::is_same_v<T, float16>) {
dispatch_blha_kernel(float16{}, float{});
} else {
dispatch_blha_kernel(T{}, T{});
}
} else {
dispatch_blha_kernel(T{}, T{});
}

} else if (params.gqa_num_per_partitions == 2) {
constexpr int THDS_PER_BLOCK = 1024;
BLHA_LAUNCH_GQA_KERNEL(T,
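
The accumulator type has to be a template parameter, but the choice is made from a runtime flag, so dispatch_blha_gqa_kernel bridges the two with a generic lambda: default-constructed values act as type tags, decltype recovers their types, and the FP16-data/FP32-sum instantiation is emitted only when T really is float16 (guarded by if constexpr). A standalone sketch of that pattern with illustrative names (float16_tag and launch_kernel stand in for Paddle's dtype and launcher):

// tag_dispatch_sketch.cc -- illustrative names, not Paddle's (C++17)
#include <cstdio>
#include <type_traits>

struct float16_tag {};  // stands in for phi::dtype::float16

template <typename T, typename SumT>
void launch_kernel() {  // stands in for dispatch_blha_impl_kernel
  std::printf("data=%s sum=%s\n",
              std::is_same_v<T, float16_tag> ? "fp16" : "other",
              std::is_same_v<SumT, float> ? "fp32" : "same-as-data");
}

template <typename T>
void dispatch(bool use_fp32_qk_sum) {  // mirrors FLAGS_blha_use_fp32_qk_sum
  auto launch = [&](auto data_tag, auto sum_tag) {
    using DataT = decltype(data_tag);  // recover the compile-time types
    using SumT = decltype(sum_tag);
    launch_kernel<DataT, SumT>();
  };
  if (use_fp32_qk_sum) {
    if constexpr (std::is_same_v<T, float16_tag>) {
      launch(float16_tag{}, float{});  // FP16 data, FP32 QK accumulation
    } else {
      launch(T{}, T{});  // other dtypes keep accumulating in T
    }
  } else {
    launch(T{}, T{});  // flag off: original behaviour
  }
}

int main() {
  dispatch<float16_tag>(true);   // prints: data=fp16 sum=fp32
  dispatch<float16_tag>(false);  // prints: data=fp16 sum=same-as-data
  return 0;
}

The tag values carry no data; they exist only so the lambda can name their types, which keeps the branch at runtime while every kernel instantiation stays a compile-time decision.
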
47 changes: 47 additions & 0 deletions paddle/phi/kernels/fusion/gpu/mmha_util.cu.h
@@ -2005,6 +2005,53 @@ inline __device__ float4 mul(float4 a, float b) {
return res;
}

template <>
inline __device__ float mul(uint32_t a, uint32_t b) {
  float c;
  float2 fa = half2_to_float2(mul<uint32_t, uint32_t, uint32_t>(a, b));
  c = fa.x + fa.y;
  return c;
}

template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 c;
  c = half2_to_float2(mul<uint32_t, uint32_t, uint32_t>(a, b));
  return c;
}

template <>
inline __device__ float4 mul(uint2 a, uint2 b) {
  float4 c;
  uint32_t ua = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  uint32_t ub = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  float2 fa = half2_to_float2(ua);
  float2 fb = half2_to_float2(ub);
  c.x = fa.x;
  c.y = fa.y;
  c.z = fb.x;
  c.w = fb.y;
  return c;
}

template <>
inline __device__ float4 mul(uint4 a, uint4 b) {
  float4 c;
  uint32_t ua = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  uint32_t ub = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  uint32_t uc = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  uint32_t ud = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  float2 fa = half2_to_float2(ua);
  float2 fb = half2_to_float2(ub);
  float2 fc = half2_to_float2(uc);
  float2 fd = half2_to_float2(ud);
  c.x = fa.x + fa.y;
  c.y = fb.x + fb.y;
  c.z = fc.x + fc.y;
  c.w = fd.x + fd.y;
  return c;
}

#ifdef ENABLE_BF16
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
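
The new mul specializations above are the FP32-accumulation building blocks that dot relies on: each packed operand (uint32_t for two halves, uint2 for four, uint4 for eight) is still multiplied element-wise in FP16, but the products are converted with half2_to_float2 before the pairwise additions, so every partial sum is formed in FP32. Below is a runnable device-side sketch of what the two-lane case does, under the assumption that a uint32_t payload holds a packed __half2 (as in the surrounding helpers); it is not a copy of mmha_util's implementation.

// mul_fp32_sum_sketch.cu -- compile with, e.g., nvcc -arch=sm_70 mul_fp32_sum_sketch.cu
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <cstring>

__device__ inline float mul_fp32_sum_sketch(uint32_t a, uint32_t b) {
  __half2 ha, hb;
  memcpy(&ha, &a, sizeof(ha));  // reinterpret the packed FP16 lane pairs
  memcpy(&hb, &b, sizeof(hb));
  __half2 prod = __hmul2(ha, hb);         // element-wise multiply, still FP16 (needs sm_53+)
  float2 widened = __half22float2(prod);  // widen both lanes to FP32 ...
  return widened.x + widened.y;           // ... so the pairwise add cannot saturate at 65504
}

__global__ void demo(uint32_t a, uint32_t b, float *out) {
  *out = mul_fp32_sum_sketch(a, b);
}

int main() {
  // Pack two FP16 values of 200.0 into each operand. Each FP16 product (40000)
  // is still representable, but their FP16 sum (80000) would overflow; the
  // FP32 pairwise add keeps the exact 80000.
  __half2 h;
  h.x = __float2half(200.0f);
  h.y = __float2half(200.0f);
  uint32_t packed;
  memcpy(&packed, &h, sizeof(packed));
  float *d_out = nullptr;
  float h_out = 0.0f;
  cudaMalloc(&d_out, sizeof(float));
  demo<<<1, 1>>>(packed, packed, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("fp32 pair sum: %.1f\n", h_out);  // 80000.0
  cudaFree(d_out);
  return 0;
}

The element-wise multiplies and the packed FP16 data layout are untouched, so memory traffic stays the same as the pure-FP16 path; only the reduction arithmetic is widened.
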