
Commit 598f152

optimize skip-block calculation in bwd
1 parent c9515ed commit 598f152

File tree

2 files changed: +62 -30 lines changed

csrc/flash_attn/src/flash_bwd_kernel.h

Lines changed: 62 additions & 26 deletions
@@ -75,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,

 template <int THREADS_PER_ROW, typename Engine0, typename Layout0,
           typename Engine1, typename Layout1, typename Engine2, typename Layout2>
-inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
+__forceinline__ __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                 Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
                                 const int gdP_col_stride, const float scale) {
     static_assert(Layout0::rank == 3, "Only support 3D Tensor");
@@ -425,7 +425,7 @@ inline __device__ void convert_dKV(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Is_attn_mask, bool Seq_parallel=false, typename Params>
-inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
+__forceinline__ __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

     const bool Is_sparse_attn_mask = params.flashmask_downstart_ptr != nullptr;
     int flashmask_startrow = 0;
@@ -488,9 +488,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const bool flashmask_has_end = params.flashmask_downend_ptr != nullptr;
     int flashmask_upendrow = params.seqlen_q;

+#define SPARSE_MASKED_DOWN \
+    (((m_block * kBlockM) >= flashmask_downstartmax) && (!flashmask_has_end || (m_block + 1) * kBlockM < flashmask_downendmin))
+
+#define SPARSE_MASKED_UP \
+    (!Is_causal && (m_block + 1) * kBlockM < flashmask_upendmin && (!flashmask_has_end || m_block * kBlockM >= flashmask_upstartmax))
+
+#define SPARSE_MASKED \
+    (SPARSE_MASKED_DOWN || SPARSE_MASKED_UP)
+
     const bool enable_mask_bypass = params.enable_mask_bypass;

-    if (Is_sparse_attn_mask && enable_mask_bypass) {
+    int flashmask_downstartmax = std::numeric_limits<int>::max();
+    int flashmask_downendmin = 0;
+    int flashmask_upendmin = 0;
+    int flashmask_upstartmax = std::numeric_limits<int>::max();
+
+    if(params.flashmask_downstart_nblockmax != nullptr)
+        flashmask_downstartmax = gSparseMaskDownMax[n_block];
+    if(params.flashmask_downend_nblockmin != nullptr)
+        flashmask_downendmin = gSparseMaskDownEndMin[n_block];
+    if(params.flashmask_upend_nblockmin != nullptr)
+        flashmask_upendmin = gSparseMaskUpMin[n_block];
+    if(params.flashmask_upstart_nblockmax != nullptr)
+        flashmask_upstartmax = gSparseMaskUpStartMax[n_block];
+
+    if (Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end) {
         m_block_max = min(m_block_max,
                           cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM));
         /*
@@ -744,7 +767,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     int m_block = m_block_max - 1;
     int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
-    if(Is_sparse_attn_mask && enable_mask_bypass){
+    if(Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end){
         if (!Is_causal) {
             m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM);
         }
@@ -922,8 +945,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
         // }
         // if (cute::thread0()) { print(tSrK); }
-        flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
-                    smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
+
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
+                        smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
+        }

         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
@@ -1005,7 +1031,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         }
        // if (cute::thread(32, 0)) { print(scores); }
        // Compute the exponential value.
-        flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
+        if (!SPARSE_MASKED) {
+            flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
+        }
         if (Is_dropout) {
             uint32_t warp_id = tidx / 32;
             uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
@@ -1048,21 +1076,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

        // if (cute::thread0()) { print(dP_sum); }

-        flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
-            acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
-            smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
-        );
-
-        // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor dS = make_tensor(acc_dp.data(), scores.layout());
-        auto pointwise_mult = [](float p, float dp, float d) {
-            return p * (!Is_dropout || p >= 0 ? dp - d : d);
-        };
-        #pragma unroll
-        for (int mi = 0; mi < size<0>(dS); ++mi) {
+        if (!SPARSE_MASKED) {
+            flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
+                acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
+                smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
+            );
+
+            // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
+            auto pointwise_mult = [](float p, float dp, float d) {
+                return p * (!Is_dropout || p >= 0 ? dp - d : d);
+            };
             #pragma unroll
-            for (int ni = 0; ni < size<1>(dS); ++ni) {
-                dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+            for (int mi = 0; mi < size<0>(dS); ++mi) {
+                #pragma unroll
+                for (int ni = 0; ni < size<1>(dS); ++ni) {
+                    dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+                }
             }
         }
         // if (cute::thread0()) { print(dS); }
@@ -1104,8 +1134,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
        // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
        // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
        // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
-        flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
-                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
+                        smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        }
        // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
        // if (cute::thread0()) { print(acc_dv); }

@@ -1124,8 +1156,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            }
         }

-        flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
-                    smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
+                        smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
+        }
        // if (cute::thread0()) { print(acc_dq); }

         if (m_block > m_block_min) {
@@ -1163,8 +1197,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
         }

-        flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
-                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        if (!SPARSE_MASKED) {
+            flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
+                        smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        }
        // if (cute::thread0()) { print(acc_dk); }
        if (Double_buffer) { // Double buffer for sQ
            tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));

csrc/flash_attn/src/flash_bwd_launch_template.h

Lines changed: 0 additions & 4 deletions
@@ -64,10 +64,6 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
    const bool is_attn_mask = params.attn_mask_ptr != nullptr;
    const bool is_deterministic = params.num_splits == 1;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
-    if (params.flashmask_downend_ptr != nullptr) {
-        // bypass is not supported for flashmask_downend
-        params.enable_mask_bypass = false;
-    }
    prepare_sparsemask<Kernel_traits>(params, stream);
    BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
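Note on the change: the guarded GEMMs above all test the same condition, and with flashmask_downend present the earlier m_block range bypass is now skipped (the added !flashmask_has_end guards) so fully masked blocks are instead skipped inside the loop. As a reading aid, here is a minimal host-side C++ sketch of that skip predicate. The expressions mirror the SPARSE_MASKED_DOWN / SPARSE_MASKED_UP macros from the diff, but the struct, function name, and example values are hypothetical; in the kernel the four bounds are loaded per n_block from the precomputed nblockmax/nblockmin buffers filled by prepare_sparsemask.

// Illustrative sketch only, not kernel code: the predicate below restates the
// SPARSE_MASKED_DOWN / SPARSE_MASKED_UP macros so the skip condition can be
// read (and tested) in isolation. Field names and values are hypothetical.
#include <cstdio>
#include <limits>

struct BlockMaskBounds {
    int down_start_max = std::numeric_limits<int>::max();  // per-block max of down-mask start rows
    int down_end_min   = 0;                                 // per-block min of down-mask end rows
    int up_end_min     = 0;                                 // per-block min of up-mask end rows
    int up_start_max   = std::numeric_limits<int>::max();   // per-block max of up-mask start rows
    bool has_end       = false;                             // whether an end bound (flashmask_downend) exists
    bool is_causal     = false;
};

// True when every query row in m_block, i.e. [m_block*kBlockM, (m_block+1)*kBlockM),
// is masked out for this KV block, so all GEMMs for that iteration can be skipped.
bool sparse_masked(int m_block, int kBlockM, const BlockMaskBounds &b) {
    const bool masked_down =
        (m_block * kBlockM >= b.down_start_max) &&
        (!b.has_end || (m_block + 1) * kBlockM < b.down_end_min);
    const bool masked_up =
        !b.is_causal &&
        (m_block + 1) * kBlockM < b.up_end_min &&
        (!b.has_end || m_block * kBlockM >= b.up_start_max);
    return masked_down || masked_up;
}

int main() {
    // Hypothetical example: kBlockM = 64, the down mask starts by row 128 in
    // every column of this KV block, and no end bound is present.
    BlockMaskBounds b;
    b.down_start_max = 128;
    for (int m_block = 0; m_block < 4; ++m_block) {
        std::printf("m_block %d: %s\n", m_block,
                    sparse_masked(m_block, /*kBlockM=*/64, b) ? "skip" : "compute");
    }
    return 0;  // expected: blocks 0 and 1 compute, blocks 2 and 3 skip
}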
