@@ -22,6 +22,7 @@
 #include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h"
 
 COMMON_DECLARE_bool(use_xqa_optim);
+COMMON_DECLARE_bool(use_fp32_qk_sum);
 
 #ifdef PADDLE_WITH_HIP
 #define GPU(str) hip##str
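For context, COMMON_DECLARE_bool(use_fp32_qk_sum) only declares the new flag in this header; the matching definition with its default value lives in the flag registry elsewhere. The sketch below is a minimal, hypothetical stand-in for that declare/define split in plain C++ (names and the environment-variable initializer are illustrative, not Paddle's actual flag machinery):

// Minimal sketch of a gflags-style declare/define split (illustrative only).
#include <cstdlib>
#include <cstring>

// "Declare" side: translation units that read the flag see only an extern.
extern bool FLAGS_use_fp32_qk_sum;

// "Define" side: exactly one translation unit owns the storage and default.
bool FLAGS_use_fp32_qk_sum = false;

// Hypothetical initializer reading an environment override, standing in for
// the framework's real flag parser.
inline void InitUseFp32QkSumFlag() {
  if (const char* v = std::getenv("FLAGS_use_fp32_qk_sum")) {
    FLAGS_use_fp32_qk_sum =
        (std::strcmp(v, "1") == 0) || (std::strcmp(v, "true") == 0);
  }
}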
@@ -98,6 +99,7 @@ struct Block_AttN_params {
 };
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int THREADS_PER_KEY,
@@ -146,6 +148,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(
 
   __shared__ float red_smem[WARPS_PER_BLOCK * 2];
   using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+  using Qk_sum_type = typename Qk_vec_<SUM_T, Dh_MAX>::Type;
   using Qk_vec_RoPE = typename Qk_vec_RoPE_<T, float, Dh_MAX>::Type;
   using QK_Packed_Int8_t = typename Packed_Int8_<Qk_vec, CACHE_TYPE>::Type;
 
@@ -322,7 +325,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel(
       }
     }
 
-    qk = dot<Qk_vec, Qk_vec>(q, k);
+    qk = dot<Qk_sum_type, Qk_vec>(q, k);
 
     if (QK_VECS_PER_WARP <= WARP_SIZE) {
 #pragma unroll
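The one functional change inside the kernel is the first template argument of dot: it selects the type in which the Q.K partial products are accumulated, so with SUM_T = float the fp16/bf16 products are summed in fp32. Below is a minimal sketch of that idea under assumed types; dot_acc and its scalar loop are illustrative, not the dot from mmha_util.cu.h:

#include <cuda_fp16.h>

// Accumulate the dot product of two small fragments in AccT rather than T.
// With T = __half and AccT = float, each product is widened before the add,
// so fp16 rounding error does not build up across the reduction.
template <typename AccT, typename T, int N>
__device__ inline AccT dot_acc(const T (&q)[N], const T (&k)[N]) {
  AccT sum = static_cast<AccT>(0.0f);
#pragma unroll
  for (int i = 0; i < N; ++i) {
    sum += static_cast<AccT>(q[i]) * static_cast<AccT>(k[i]);
  }
  return sum;
}

// Usage inside a kernel (illustrative):
//   float  qk   = dot_acc<float>(q_frag, k_frag);    // fp32 accumulation
//   __half qk_h = dot_acc<__half>(q_frag, k_frag);   // legacy fp16 accumulation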
@@ -1222,6 +1225,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 
 #ifdef PADDLE_WITH_HIP
 #define BLHAG_LAUNCH_KERNEL(T,            \
+                            SUM_T,        \
                             Dh,           \
                             Dh_MAX,       \
                             THDS_PER_KEY, \
@@ -1235,6 +1239,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
     size_t smem_sz =                                                       \
         smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
     constexpr auto kernel_fn = block_attention_kernel<T,                   \
+                                                       SUM_T,              \
                                                        Dh,                 \
                                                        Dh_MAX,             \
                                                        THDS_PER_KEY,       \
@@ -1295,6 +1300,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
         params, load_func, store_func);
 #else
 #define BLHAG_LAUNCH_KERNEL(T,            \
+                            SUM_T,        \
                             Dh,           \
                             Dh_MAX,       \
                             THDS_PER_KEY, \
@@ -1308,6 +1314,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
     size_t smem_sz =                                                       \
         smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
     constexpr auto kernel_fn = block_attention_kernel<T,                   \
+                                                       SUM_T,              \
                                                        Dh,                 \
                                                        Dh_MAX,             \
                                                        THDS_PER_KEY,       \
@@ -1367,6 +1374,7 @@ inline size_t gqa_smem_size_in_bytes(const Block_AttN_params<T> &params,
 #endif
 
 template <typename T,
+          typename SUM_T,
           int Dh,
           int Dh_MAX,
           int BlockSize,
@@ -1382,6 +1390,7 @@ void dispatch_blha_impl_kernel(const Block_AttN_params<T> &params,
                                StoreFunc store_func) {
   VLOG(1) << "group wise";
   BLHAG_LAUNCH_KERNEL(T,
+                      SUM_T,
                       Dh,
                       Dh_MAX,
                       THREADS_PER_KEY,
@@ -1409,15 +1418,21 @@ void dispatch_blha_gqa_kernel(const Block_AttN_params<T> &params,
                               LoadFunc load_func,
                               StoreFunc store_func) {
   if (params.gqa_num_per_partitions == 1 || !FLAGS_use_xqa_optim) {
-    dispatch_blha_impl_kernel<T,
-                              Dh,
-                              Dh_MAX,
-                              BlockSize,
-                              THREADS_PER_VALUE,
-                              THREADS_PER_KEY,
-                              THREADS_PER_BLOCK,
-                              CACHE_TYPE>(
-        params, stream, load_func, store_func);
+    auto dispatch_blha_kernel = [&](auto qk_sum_type) {
+      using SUM_T = decltype(qk_sum_type);
+      dispatch_blha_impl_kernel<T,
+                                SUM_T,
+                                Dh,
+                                Dh_MAX,
+                                BlockSize,
+                                THREADS_PER_VALUE,
+                                THREADS_PER_KEY,
+                                THREADS_PER_BLOCK,
+                                CACHE_TYPE>(
+          params, stream, load_func, store_func);
+    };
+    FLAGS_use_fp32_qk_sum ? dispatch_blha_kernel(float())
+                          : dispatch_blha_kernel(T());
   } else if (params.gqa_num_per_partitions == 2) {
     constexpr int THDS_PER_BLOCK = 1024;
     BLHA_LAUNCH_GQA_KERNEL(T,
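The dispatch added above uses a common runtime-to-compile-time bridge: a generic lambda is called with a value whose type encodes the chosen accumulator, decltype recovers it as SUM_T, and both instantiations exist at compile time while FLAGS_use_fp32_qk_sum picks one at runtime. A self-contained sketch of the pattern (function names are illustrative, not Paddle's):

#include <cstdio>

// Stand-in for dispatch_blha_impl_kernel<T, SUM_T, ...>(...).
template <typename T, typename SUM_T>
void launch_attention(int steps) {
  std::printf("launch: sizeof(T)=%zu sizeof(SUM_T)=%zu steps=%d\n",
              sizeof(T), sizeof(SUM_T), steps);
}

template <typename T>
void dispatch_on_qk_sum_type(bool use_fp32_qk_sum, int steps) {
  // The lambda parameter is used only as a type tag; decltype turns the
  // runtime choice of argument into a compile-time template argument.
  auto dispatch = [&](auto qk_sum_tag) {
    using SUM_T = decltype(qk_sum_tag);
    launch_attention<T, SUM_T>(steps);
  };
  // Both branches are instantiated; the flag selects one at runtime.
  use_fp32_qk_sum ? dispatch(float()) : dispatch(T());
}

int main() {
  dispatch_on_qk_sum_type<short>(true, 128);   // short stands in for a 16-bit T
  dispatch_on_qk_sum_type<short>(false, 128);
}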