
Commit 60ca005

fix blha fp16
1 parent 2ac03ae commit 60ca005

File tree

3 files changed: +73 -11 lines

  paddle/common/flags.cc
  paddle/phi/kernels/fusion/gpu/block_attn.h
  paddle/phi/kernels/fusion/gpu/mmha_util.cu.h

paddle/common/flags.cc  (+14)
@@ -1879,6 +1879,20 @@ PHI_DEFINE_EXPORTED_bool(
     false,
     "Enable xqa optim in block_multihead_attention kernel (GQA).");
 
+/**
+ * Whether to use FP32 for accumulation of QK output in
+ * block_multihead_attention kernel
+ * Name: use_fp32_qk_sum    Since Version: 3.0.0
+ * Value Range: bool, default=false
+ * Example:
+ * Note: If TRUE, FP32 will be used for accumulation of the QK output
+ * in block_multihead_attention kernel.
+ */
+PHI_DEFINE_EXPORTED_bool(use_fp32_qk_sum,
+                         false,
+                         "use FP32 for accumulation of QK output in "
+                         "block_multihead_attention kernel.");
+
 PHI_DEFINE_EXPORTED_bool(cuda_core_int8_gemm,
                          false,
                          "Enable speed up int8 gemm calculations when m<=4");

paddle/phi/kernels/fusion/gpu/block_attn.h  (+25 -10)
@@ -22,6 +22,7 @@
 #include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h"
 
 COMMON_DECLARE_bool(use_xqa_optim);
+COMMON_DECLARE_bool(use_fp32_qk_sum);
 
 #ifdef PADDLE_WITH_HIP
 #define GPU(str) hip##str
@@ -98,6 +99,7 @@ struct Block_AttN_params {
 };
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int THREADS_PER_KEY,
@@ -146,6 +148,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(
 
   __shared__ float red_smem[WARPS_PER_BLOCK * 2];
   using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+  using Qk_sum_type = typename Qk_vec_<SUM_T, Dh_MAX>::Type;
   using Qk_vec_RoPE = typename Qk_vec_RoPE_<T, float, Dh_MAX>::Type;
   using QK_Packed_Int8_t = typename Packed_Int8_<Qk_vec, CACHE_TYPE>::Type;
 
@@ -322,7 +325,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(
       }
     }
 
-    qk = dot<Qk_vec, Qk_vec>(q, k);
+    qk = dot<Qk_sum_type, Qk_vec>(q, k);
 
     if (QK_VECS_PER_WARP <= WARP_SIZE) {
 #pragma unroll
@@ -1222,6 +1225,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 
 #ifdef PADDLE_WITH_HIP
 #define BLHAG_LAUNCH_KERNEL(T,            \
+                            SUM_T,        \
                             Dh,           \
                             Dh_MAX,       \
                             THDS_PER_KEY, \
@@ -1235,6 +1239,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
   size_t smem_sz =                                                       \
       smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
   constexpr auto kernel_fn = block_attention_kernel<T,                   \
+                                                    SUM_T,               \
                                                     Dh,                  \
                                                     Dh_MAX,              \
                                                     THDS_PER_KEY,        \
@@ -1295,6 +1300,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
       params, load_func, store_func);
 #else
 #define BLHAG_LAUNCH_KERNEL(T,            \
+                            SUM_T,        \
                             Dh,           \
                             Dh_MAX,       \
                             THDS_PER_KEY, \
@@ -1308,6 +1314,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
   size_t smem_sz =                                                       \
       smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
   constexpr auto kernel_fn = block_attention_kernel<T,                   \
+                                                    SUM_T,               \
                                                     Dh,                  \
                                                     Dh_MAX,              \
                                                     THDS_PER_KEY,        \
@@ -1367,6 +1374,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 #endif
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int BlockSize,
@@ -1382,6 +1390,7 @@ void dispatch_blha_impl_kernel(const Block_AttN_params<T> &params,
                                StoreFunc store_func) {
   VLOG(1) << "group wise";
   BLHAG_LAUNCH_KERNEL(T,
+                      SUM_T,
                       Dh,
                       Dh_MAX,
                       THREADS_PER_KEY,
@@ -1409,15 +1418,21 @@ void dispatch_blha_gqa_kernel(const Block_AttN_params<T> &params,
                               LoadFunc load_func,
                               StoreFunc store_func) {
   if (params.gqa_num_per_partitions == 1 || !FLAGS_use_xqa_optim) {
-    dispatch_blha_impl_kernel<T,
-                              Dh,
-                              Dh_MAX,
-                              BlockSize,
-                              THREADS_PER_VALUE,
-                              THREADS_PER_KEY,
-                              THREADS_PER_BLOCK,
-                              CACHE_TYPE>(
-        params, stream, load_func, store_func);
+    auto dispatch_blha_kernel = [&](auto qk_sum_type) {
+      using SUM_T = decltype(qk_sum_type);
+      dispatch_blha_impl_kernel<T,
+                                SUM_T,
+                                Dh,
+                                Dh_MAX,
+                                BlockSize,
+                                THREADS_PER_VALUE,
+                                THREADS_PER_KEY,
+                                THREADS_PER_BLOCK,
+                                CACHE_TYPE>(
+          params, stream, load_func, store_func);
+    };
+    FLAGS_use_fp32_qk_sum ? dispatch_blha_kernel(float())
+                          : dispatch_blha_kernel(T());
   } else if (params.gqa_num_per_partitions == 2) {
     constexpr int THDS_PER_BLOCK = 1024;
     BLHA_LAUNCH_GQA_KERNEL(T,
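The interesting part of the dispatch change is how a runtime flag selects a compile-time template argument: the accumulator type is passed into a generic lambda as a value and recovered with decltype. A minimal, self-contained sketch of that pattern follows; names like run_kernel and dispatch are made up for illustration and are not Paddle APIs.

#include <cstdio>

// Stand-in for the templated kernel launcher: SUM_T is the QK accumulator type.
template <typename T, typename SUM_T>
void run_kernel() {
  std::printf("element: %zu bytes, accumulator: %zu bytes\n",
              sizeof(T), sizeof(SUM_T));
}

// A runtime bool picks the compile-time SUM_T by handing a value of the
// desired type to a generic lambda and reading its type back via decltype.
template <typename T>
void dispatch(bool use_fp32_sum) {
  auto launch = [&](auto sum_tag) {
    using SUM_T = decltype(sum_tag);
    run_kernel<T, SUM_T>();
  };
  use_fp32_sum ? launch(float()) : launch(T());
}

int main() {
  dispatch<short>(true);   // 2-byte element (stands in for FP16), FP32 accumulation
  dispatch<short>(false);  // accumulate in the element type itself
  return 0;
}

The ternary works because both calls return void; only one branch runs at runtime, but both instantiations are compiled, which is why block_attention_kernel and the launch macros needed the extra SUM_T parameter throughout.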

paddle/phi/kernels/fusion/gpu/mmha_util.cu.h  (+34 -1)
@@ -1884,7 +1884,8 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
 }
 
 template <>
-inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
+inline __device__ uint32_t mul<uint32_t, uint32_t, uint32_t>(uint32_t a,
+                                                             uint32_t b) {
   uint32_t c;
 #ifdef PADDLE_WITH_HIP
   asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
@@ -2005,6 +2006,38 @@ inline __device__ float4 mul(float4 a, float b) {
   return res;
 }
 
+template <>
+inline __device__ float4 mul<float4, uint2, uint2>(uint2 a, uint2 b) {
+  float4 c;
+  uint32_t ua = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  uint32_t ub = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  float2 fa = half2_to_float2(ua);
+  float2 fb = half2_to_float2(ub);
+  c.x = fa.x;
+  c.y = fa.y;
+  c.z = fb.x;
+  c.w = fb.y;
+  return c;
+}
+
+template <>
+inline __device__ float4 mul<float4, uint4, uint4>(uint4 a, uint4 b) {
+  float4 c;
+  uint32_t ua = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  uint32_t ub = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  uint32_t uc = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
+  uint32_t ud = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
+  float2 fa = half2_to_float2(ua);
+  float2 fb = half2_to_float2(ub);
+  float2 fc = half2_to_float2(uc);
+  float2 fd = half2_to_float2(ud);
+  c.x = fa.x + fa.y;
+  c.y = fb.x + fb.y;
+  c.z = fc.x + fc.y;
+  c.w = fd.x + fd.y;
+  return c;
+}
+
 #ifdef ENABLE_BF16
 template <>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
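The new mul specializations widen packed FP16 products to FP32 before any addition takes place: each uint32_t holds two halves, a uint2 holds four, and a uint4 holds eight that are pairwise-reduced into a float4. Below is a rough CUDA sketch of the same idea for a single packed pair, illustrative only and written with standard half2 intrinsics rather than the file's inline asm; the function name is made up.

#include <cuda_fp16.h>
#include <cstdint>

// Illustrative only: multiply two packed half2 values (stored as uint32_t,
// as in mmha_util.cu.h), then widen the products to FP32 and sum them there.
__device__ float sum_of_half2_products(uint32_t a, uint32_t b) {
  __half2 ha = *reinterpret_cast<__half2*>(&a);
  __half2 hb = *reinterpret_cast<__half2*>(&b);
  __half2 prod = __hmul2(ha, hb);   // packed FP16 multiply (v_pk_mul_f16 in the HIP path)
  float2 f = __half22float2(prod);  // convert both lanes to FP32
  return f.x + f.y;                 // the addition happens in FP32
}

mul<float4, uint4, uint4> applies exactly this per pair of lanes, which is what lets dot<Qk_sum_type, Qk_vec> keep the whole QK reduction in FP32 when FLAGS_use_fp32_qk_sum is enabled.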
