Skip to content

Commit 9b4a0a5

Browse files
authored
opt bwk skip load (#53)
* opt bwk skip load * add left tri
1 parent f2962ac commit 9b4a0a5

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

csrc/flash_attn/src/flash_bwd_kernel.h

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,10 +1436,18 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
14361436
int m_block = m_block_max - 1;
14371437
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
14381438

1439+
int m_skip_lt_block = 0;
1440+
int m_skip_ut_block = 0;
14391441
if (true/*Is_flashmask*/ && enable_mask_bypass) {
14401442
if (!Is_causal && (!flashmask_ut_has_start || (flashmask_ut_has_start && flashmask_utstart_max <= 0))) {
14411443
m_block_min = max(m_block_min, flashmask_utend_min / kBlockM);
14421444
}
1445+
if (!Is_causal && flashmask_ut_has_start && flashmask_utstart_max > 0) {
1446+
m_skip_ut_block = flashmask_utend_min / kBlockM - flashmask_utstart_max / kBlockM - 1;
1447+
}
1448+
if (flashmask_lt_has_end && flashmask_ltend_min < binfo.actual_seqlen_q ) {
1449+
m_skip_lt_block = flashmask_ltend_min / kBlockM - flashmask_ltstart_max / kBlockM - 1;
1450+
}
14431451
}
14441452

14451453
// We might need to exit early and write 0 to dK and dV.
@@ -1580,7 +1588,13 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
15801588
clear(acc_dv);
15811589
clear(acc_dk);
15821590

1583-
for (; m_block >= m_block_min; --m_block) {
1591+
int double_buffer_cnt = m_block;
1592+
for (; m_block >= m_block_min; --m_block, --double_buffer_cnt) {
1593+
int next_move_blocks = 1;
1594+
if (m_skip_ut_block >= 0 && !Is_causal && m_block == flashmask_utend_min / kBlockM)
1595+
next_move_blocks += m_skip_ut_block;
1596+
if (m_skip_lt_block >= 0 && m_block == flashmask_ltend_min / kBlockM)
1597+
next_move_blocks += m_skip_lt_block;
15841598
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
15851599
clear(acc_s);
15861600
cute::cp_async_wait<0>();
@@ -1786,11 +1800,11 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
17861800

17871801
if (Double_buffer && m_block > m_block_min) {
17881802
// Double buffer for sQ
1789-
const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
1803+
const int sQ_offset = double_buffer_cnt % 2 == 0 ? size(sQ) : -size(sQ);
17901804
tQsQ.data() = tQsQ.data() + sQ_offset;
17911805
tSsQ.data() = tSsQ.data() + sQ_offset;
17921806
// Advance gQ
1793-
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
1807+
tQgQ.data() = tQgQ.data() + (-int(next_move_blocks * kBlockM * params.q_row_stride));
17941808
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
17951809
flash::cp_async_fence();
17961810
}
@@ -1820,9 +1834,9 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
18201834

18211835
if (m_block > m_block_min) {
18221836
// Advance gdO
1823-
tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
1837+
tdOgdO.data() = tdOgdO.data() + (-int(next_move_blocks * kBlockM * params.do_row_stride));
18241838
if (Is_first) {
1825-
tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
1839+
tdOgO.data() = tdOgO.data() + (-int(next_move_blocks * kBlockM * params.o_row_stride));
18261840
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
18271841
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
18281842
} else {
@@ -1838,10 +1852,10 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
18381852
// if (cute::thread0()) { print(acc_dq); }
18391853

18401854
if (m_block > m_block_min) {
1841-
gLSE.data() = gLSE.data() + (-int(kBlockM));
1855+
gLSE.data() = gLSE.data() + (-int(next_move_blocks * kBlockM));
18421856
#pragma unroll
18431857
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
1844-
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
1858+
gdPsum.data() = gdPsum.data() + (-int(next_move_blocks * kBlockM));
18451859
// if (!Is_first && tidx < kBlockM / 2) {
18461860
// sdPsum(tidx) = recast<float2>(gdPsum)(tidx);
18471861
// if (!Is_first && tidx < kBlockM) {
@@ -1862,6 +1876,8 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
18621876
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
18631877
#pragma unroll
18641878
for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
1879+
// (GuoxiaWang): skip tdQgdQaccum after atomicAdd
1880+
tdQgdQaccum.data() = tdQgdQaccum.data() + (-int((next_move_blocks - 1) * kBlockM * params.d_rounded));
18651881
}
18661882
} else {
18671883
#pragma unroll
@@ -1878,12 +1894,12 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
18781894
}
18791895
// if (cute::thread0()) { print(acc_dk); }
18801896
if (Double_buffer) { // Double buffer for sQ
1881-
tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
1897+
tdKsQt.data() = tdKsQt.data() + (double_buffer_cnt % 2 == 0 ? size(sQ) : -size(sQ));
18821898
}
18831899
if (!Double_buffer && m_block > m_block_min) {
18841900
__syncthreads();
18851901
// Advance gQ
1886-
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
1902+
tQgQ.data() = tQgQ.data() + (-int(next_move_blocks * kBlockM * params.q_row_stride));
18871903
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
18881904
flash::cp_async_fence();
18891905
}
@@ -1898,7 +1914,7 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
18981914
__syncthreads();
18991915
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
19001916
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
1901-
tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
1917+
tdQgdQ.data() = tdQgdQ.data() + (-int(next_move_blocks * kBlockM * params.dq_row_stride));
19021918
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
19031919
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
19041920
#pragma unroll
@@ -1909,6 +1925,12 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
19091925
}
19101926
}
19111927

1928+
if (m_skip_lt_block >= 0 && m_block == flashmask_ltend_min / kBlockM)
1929+
m_block -= m_skip_lt_block;
1930+
1931+
if (m_skip_ut_block >= 0 && !Is_causal && m_block == flashmask_utend_min / kBlockM) {
1932+
m_block -= m_skip_ut_block;
1933+
}
19121934
}
19131935

19141936
// Epilogue

0 commit comments

Comments
 (0)