
Commit 1f1a25d

[inference]Fix FP16 precision BLHA accumulation overflow. (#71919)
* fix blha fp16
* update
1 parent e5b4340 commit 1f1a25d

3 files changed: +95 -10 lines

paddle/common/flags.cc (+14)

@@ -1846,6 +1846,20 @@ PHI_DEFINE_EXPORTED_bool(
     false,
     "Enable xqa optim in block_multihead_attention kernel (GQA).");
 
+/**
+ * Whether to use FP32 for accumulation of QK output in
+ * block_multihead_attention kernel(fp16)
+ * Name: blha_use_fp32_qk_sum Since Version: 3.0.0
+ * Value Range: bool, default=false
+ * Example:
+ * Note: If TRUE, FP32 will be used for accumulation of the QK output
+ *       in block_multihead_attention kernel(fp16) .
+ */
+PHI_DEFINE_EXPORTED_bool(blha_use_fp32_qk_sum,
+                         false,
+                         "use FP32 for accumulation of QK output in "
+                         "block_multihead_attention kernel(fp16).");
+
 PHI_DEFINE_EXPORTED_bool(cuda_core_int8_gemm,
                          false,
                          "Enable speed up int8 gemm calculations when m<=4");

paddle/phi/kernels/fusion/gpu/block_attn.h (+34 -10)

@@ -22,6 +22,7 @@
 #include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h"
 
 COMMON_DECLARE_bool(use_xqa_optim);
+COMMON_DECLARE_bool(blha_use_fp32_qk_sum);
 
 #ifdef PADDLE_WITH_HIP
 #define GPU(str) hip##str
@@ -98,6 +99,7 @@ struct Block_AttN_params {
 };
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int THREADS_PER_KEY,
@@ -146,6 +148,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(
 
   __shared__ float red_smem[WARPS_PER_BLOCK * 2];
   using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+  using Qk_sum_type = typename Qk_vec_<SUM_T, Dh_MAX>::Type;
   using Qk_vec_RoPE = typename Qk_vec_RoPE_<T, float, Dh_MAX>::Type;
   using QK_Packed_Int8_t = typename Packed_Int8_<Qk_vec, CACHE_TYPE>::Type;
 
@@ -322,7 +325,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(
       }
     }
 
-    qk = dot<Qk_vec, Qk_vec>(q, k);
+    qk = dot<Qk_sum_type, Qk_vec>(q, k);
 
     if (QK_VECS_PER_WARP <= WARP_SIZE) {
 #pragma unroll
@@ -1222,6 +1225,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 
 #ifdef PADDLE_WITH_HIP
 #define BLHAG_LAUNCH_KERNEL(T,            \
+                            SUM_T,        \
                             Dh,           \
                             Dh_MAX,       \
                             THDS_PER_KEY, \
@@ -1235,6 +1239,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
     size_t smem_sz =                                                        \
         smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK);  \
     constexpr auto kernel_fn = block_attention_kernel<T,                    \
+                                                      SUM_T,                \
                                                       Dh,                   \
                                                       Dh_MAX,               \
                                                       THDS_PER_KEY,         \
@@ -1295,6 +1300,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
         params, load_func, store_func);
 #else
 #define BLHAG_LAUNCH_KERNEL(T,            \
+                            SUM_T,        \
                             Dh,           \
                             Dh_MAX,       \
                             THDS_PER_KEY, \
@@ -1308,6 +1314,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
     size_t smem_sz =                                                        \
         smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK);  \
     constexpr auto kernel_fn = block_attention_kernel<T,                    \
+                                                      SUM_T,                \
                                                       Dh,                   \
                                                       Dh_MAX,               \
                                                       THDS_PER_KEY,         \
@@ -1367,6 +1374,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 #endif
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int BlockSize,
@@ -1382,6 +1390,7 @@ void dispatch_blha_impl_kernel(const Block_AttN_params<T> &params,
                                StoreFunc store_func) {
   VLOG(1) << "group wise";
   BLHAG_LAUNCH_KERNEL(T,
+                      SUM_T,
                       Dh,
                       Dh_MAX,
                       THREADS_PER_KEY,
@@ -1409,15 +1418,30 @@ void dispatch_blha_gqa_kernel(const Block_AttN_params<T> &params,
                               LoadFunc load_func,
                               StoreFunc store_func) {
   if (params.gqa_num_per_partitions == 1 || !FLAGS_use_xqa_optim) {
-    dispatch_blha_impl_kernel<T,
-                              Dh,
-                              Dh_MAX,
-                              BlockSize,
-                              THREADS_PER_VALUE,
-                              THREADS_PER_KEY,
-                              THREADS_PER_BLOCK,
-                              CACHE_TYPE>(
-        params, stream, load_func, store_func);
+    auto dispatch_blha_kernel = [&](auto kernel_type, auto qk_sum_type) {
+      using Kernel_T = decltype(kernel_type);
+      using SUM_T = decltype(qk_sum_type);
+      dispatch_blha_impl_kernel<Kernel_T,
+                                SUM_T,
+                                Dh,
+                                Dh_MAX,
+                                BlockSize,
+                                THREADS_PER_VALUE,
+                                THREADS_PER_KEY,
+                                THREADS_PER_BLOCK,
+                                CACHE_TYPE>(
+          params, stream, load_func, store_func);
+    };
+    if (FLAGS_blha_use_fp32_qk_sum) {
+      if constexpr (std::is_same_v<T, float16>) {
+        dispatch_blha_kernel(float16{}, float{});
+      } else {
+        dispatch_blha_kernel(T{}, T{});
+      }
+    } else {
+      dispatch_blha_kernel(T{}, T{});
+    }
+
   } else if (params.gqa_num_per_partitions == 2) {
     constexpr int THDS_PER_BLOCK = 1024;
     BLHA_LAUNCH_GQA_KERNEL(T,
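The dispatch change above keeps a single launch site by funnelling the data type and the QK accumulator type through a generic lambda. A simplified, standalone sketch of that pattern (the launch_kernel stub and the empty float16 struct are stand-ins, not Paddle types):

#include <type_traits>

struct float16 {};  // stand-in for phi::dtype::float16

template <typename T, typename SUM_T>
void launch_kernel() {
  // instantiate the attention kernel with data type T and accumulator SUM_T
}

template <typename T>
void dispatch(bool use_fp32_qk_sum) {
  // decltype on the tag arguments recovers compile-time types, so every
  // branch reuses the same call instead of repeating the template list.
  auto launch = [&](auto kernel_type, auto qk_sum_type) {
    using Kernel_T = decltype(kernel_type);
    using Sum_T = decltype(qk_sum_type);
    launch_kernel<Kernel_T, Sum_T>();
  };
  if (use_fp32_qk_sum) {
    if constexpr (std::is_same_v<T, float16>) {
      launch(float16{}, float{});  // FP16 data, FP32 QK accumulation
    } else {
      launch(T{}, T{});            // other dtypes keep their native sum type
    }
  } else {
    launch(T{}, T{});              // default path: accumulate in T
  }
}

Only the float16 instantiation ever pairs with a float accumulator, so no extra kernel variants are compiled for dtypes that never take the FP32 path.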

paddle/phi/kernels/fusion/gpu/mmha_util.cu.h (+47)

@@ -2005,6 +2005,53 @@ inline __device__ float4 mul(float4 a, float b) {
   return res;
 }
 
+template <>
+inline __device__ float mul(uint32_t a, uint32_t b) {
+  float c;
+  float2 fa = half2_to_float2(mul<uint32_t, uint32_t, uint32_t>(a, b));
+  c = fa.x + fa.y;
+  return c;
+}
+
+template <>
+inline __device__ float2 mul(uint32_t a, uint32_t b) {
+  float2 c;
+  c = half2_to_float2(mul<uint32_t, uint32_t, uint32_t>(a, b));
+  return c;
+}
+
+template <>
+inline __device__ float4 mul(uint2 a, uint2 b) {
+  float4 c;
+  uint32_t ua = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  uint32_t ub = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  float2 fa = half2_to_float2(ua);
+  float2 fb = half2_to_float2(ub);
+  c.x = fa.x;
+  c.y = fa.y;
+  c.z = fb.x;
+  c.w = fb.y;
+  return c;
+}
+
+template <>
+inline __device__ float4 mul(uint4 a, uint4 b) {
+  float4 c;
+  uint32_t ua = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  uint32_t ub = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  uint32_t uc = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
+  uint32_t ud = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
+  float2 fa = half2_to_float2(ua);
+  float2 fb = half2_to_float2(ub);
+  float2 fc = half2_to_float2(uc);
+  float2 fd = half2_to_float2(ud);
+  c.x = fa.x + fa.y;
+  c.y = fb.x + fb.y;
+  c.z = fc.x + fc.y;
+  c.w = fd.x + fd.y;
+  return c;
+}
+
 #ifdef ENABLE_BF16
 template <>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
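The specializations above widen each packed half2 product to FP32 before any summation, which is what lets dot<Qk_sum_type, Qk_vec> accumulate the QK result in float. A rough equivalent written with native __half2 instead of the uint32_t packing used in mmha_util.cu.h (illustrative, not the Paddle helpers):

#include <cuda_fp16.h>

// Multiply one packed pair in half precision, widen to float2, reduce in FP32.
__device__ float mul_h2_sum_f32(__half2 a, __half2 b) {
  float2 p = __half22float2(__hmul2(a, b));  // per-lane products as floats
  return p.x + p.y;                          // FP32 reduction
}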
