@@ -75,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 
 template <int THREADS_PER_ROW, typename Engine0, typename Layout0,
           typename Engine1, typename Layout1, typename Engine2, typename Layout2>
-inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
+__forceinline__ __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                 Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
                                 const int gdP_col_stride, const float scale) {
     static_assert(Layout0::rank == 3, "Only support 3D Tensor");
@@ -425,7 +425,7 @@ inline __device__ void convert_dKV(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Is_attn_mask, bool Seq_parallel=false, typename Params>
-inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
+__forceinline__ __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     const bool Is_sparse_attn_mask = params.flashmask_downstart_ptr != nullptr;
     int flashmask_startrow = 0;
@@ -488,9 +488,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const bool flashmask_has_end = params.flashmask_downend_ptr != nullptr;
     int flashmask_upendrow = params.seqlen_q;
 
+#define SPARSE_MASKED_DOWN \
+    (((m_block * kBlockM) >= flashmask_downstartmax) && (!flashmask_has_end || (m_block + 1) * kBlockM < flashmask_downendmin))
+
+#define SPARSE_MASKED_UP \
+    (!Is_causal && (m_block + 1) * kBlockM < flashmask_upendmin && (!flashmask_has_end || m_block * kBlockM >= flashmask_upstartmax))
+
+#define SPARSE_MASKED \
+    (SPARSE_MASKED_DOWN || SPARSE_MASKED_UP)
+
     const bool enable_mask_bypass = params.enable_mask_bypass;
 
-    if (Is_sparse_attn_mask && enable_mask_bypass) {
+    int flashmask_downstartmax = std::numeric_limits<int>::max();
+    int flashmask_downendmin = 0;
+    int flashmask_upendmin = 0;
+    int flashmask_upstartmax = std::numeric_limits<int>::max();
+
+    if (params.flashmask_downstart_nblockmax != nullptr)
+        flashmask_downstartmax = gSparseMaskDownMax[n_block];
+    if (params.flashmask_downend_nblockmin != nullptr)
+        flashmask_downendmin = gSparseMaskDownEndMin[n_block];
+    if (params.flashmask_upend_nblockmin != nullptr)
+        flashmask_upendmin = gSparseMaskUpMin[n_block];
+    if (params.flashmask_upstart_nblockmax != nullptr)
+        flashmask_upstartmax = gSparseMaskUpStartMax[n_block];
+
+    if (Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end) {
        m_block_max = min(m_block_max,
                          cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM));
        /*
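// [Note] The SPARSE_MASKED_* macros above reduce to an interval test per
// (m_block, n_block) tile against the per-column-block bounds gathered just
// after them. A minimal standalone sketch of the same predicate, with
// illustrative function and parameter names that are not part of the patch:

bool tile_fully_masked(int m_block, int kBlockM, bool is_causal, bool has_end,
                       int downstartmax, int downendmin,
                       int upendmin, int upstartmax) {
    // "down" band: the whole row block starts at or past the latest mask start
    // and, when end rows exist, ends strictly before the earliest mask end.
    const bool down = m_block * kBlockM >= downstartmax &&
                      (!has_end || (m_block + 1) * kBlockM < downendmin);
    // "up" band: only meaningful for non-causal masks.
    const bool up = !is_causal && (m_block + 1) * kBlockM < upendmin &&
                    (!has_end || m_block * kBlockM >= upstartmax);
    return down || up;  // such a tile contributes nothing to dQ/dK/dV
}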
@@ -744,7 +767,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
     int m_block = m_block_max - 1;
     int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
-    if (Is_sparse_attn_mask && enable_mask_bypass) {
+    if (Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end) {
         if (!Is_causal) {
             m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM);
         }
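// [Note] The `!flashmask_has_end` guard added to both bypasses is what keeps
// the outright range clipping safe: without end rows, the masked region is a
// pure suffix ("down") or prefix ("up") of the m dimension, so whole blocks
// can be dropped from the loop bounds. With end rows, an interior stripe can
// become unmasked again inside the clipped range, which is presumably why the
// per-tile SPARSE_MASKED test takes over in that case. A rough sketch of the
// clipping, assuming per-n_block row bounds as in gSparseMaskDownMax /
// gSparseMaskUpMin (helper names here are illustrative only):

#include <algorithm>

int clip_m_block_max(int m_block_max, int down_start_max_row, int kBlockM) {
    // ceil_div: the first block lying entirely past the mask start is excluded
    return std::min(m_block_max, (down_start_max_row + kBlockM - 1) / kBlockM);
}

int clip_m_block_min(int m_block_min, int up_end_min_row, int kBlockM) {
    return std::max(m_block_min, up_end_min_row / kBlockM);
}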
@@ -922,8 +945,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         //     cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
         // }
         // if (cute::thread0()) { print(tSrK); }
-        flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
-                    smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
+
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
+                        smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
+        }
 
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
@@ -1005,7 +1031,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         }
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
-        flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
+        if (!SPARSE_MASKED) {
+            flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
+        }
         if (Is_dropout) {
             uint32_t warp_id = tidx / 32;
             uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
@@ -1048,21 +1076,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
         // if (cute::thread0()) { print(dP_sum); }
 
-        flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
-            acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
-            smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
-        );
-
-        // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor dS = make_tensor(acc_dp.data(), scores.layout());
-        auto pointwise_mult = [](float p, float dp, float d) {
-            return p * (!Is_dropout || p >= 0 ? dp - d : d);
-        };
-        #pragma unroll
-        for (int mi = 0; mi < size<0>(dS); ++mi) {
+        if (!SPARSE_MASKED) {
+            flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
+                acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
+                smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
+            );
+
+            // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
+            auto pointwise_mult = [](float p, float dp, float d) {
+                return p * (!Is_dropout || p >= 0 ? dp - d : d);
+            };
             #pragma unroll
-            for (int ni = 0; ni < size<1>(dS); ++ni) {
-                dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+            for (int mi = 0; mi < size<0>(dS); ++mi) {
+                #pragma unroll
+                for (int ni = 0; ni < size<1>(dS); ++ni) {
+                    dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+                }
             }
         }
         // if (cute::thread0()) { print(dS); }
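// [Note] Guarding the dP GEMM and the elementwise product with !SPARSE_MASKED
// skips work on tiles whose probabilities are all zero anyway. The lambda
// itself is unchanged; it appears to follow upstream FlashAttention in
// encoding the dropout mask in the sign of p, so the two cases it computes
// are, as a sketch:

float pointwise_mult_sketch(bool is_dropout, float p, float dp, float d) {
    // kept entry (p >= 0):  dS = p * (dp - d)
    // dropped entry (p < 0): dS = p * d  (the negative sign carries the mask)
    return p * (!is_dropout || p >= 0 ? dp - d : d);
}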
@@ -1104,8 +1134,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
         // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
         // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
-        flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
-                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
+                        smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        }
         // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
         // if (cute::thread0()) { print(acc_dv); }
 
@@ -1124,8 +1156,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             }
         }
 
-        flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
-                    smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
+                        smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
+        }
         // if (cute::thread0()) { print(acc_dq); }
 
         if (m_block > m_block_min) {
@@ -1163,8 +1197,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
         }
 
-        flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
-                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
+                        smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        }
         // if (cute::thread0()) { print(acc_dk); }
         if (Double_buffer) {  // Double buffer for sQ
             tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
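// [Note] Net effect of the patch, as far as these hunks show: the
// exponentiation, the elementwise dS product, and all five tile GEMMs
// (acc_s, acc_dp, acc_dv, acc_dq, acc_dk) are now guarded by !SPARSE_MASKED,
// so a fully masked (m_block, n_block) tile costs only the per-block bound
// loads and the predicate test, while the !flashmask_has_end guards reserve
// the coarser loop-range bypass for masks without end rows.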