
opt bwd skip load #53


Merged
merged 2 commits into from
Sep 23, 2024
Changes from all commits
42 changes: 32 additions & 10 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -1436,10 +1436,18 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;

+int m_skip_lt_block = 0;
+int m_skip_ut_block = 0;
if (true/*Is_flashmask*/ && enable_mask_bypass) {
if (!Is_causal && (!flashmask_ut_has_start || (flashmask_ut_has_start && flashmask_utstart_max <= 0))) {
m_block_min = max(m_block_min, flashmask_utend_min / kBlockM);
}
+if (!Is_causal && flashmask_ut_has_start && flashmask_utstart_max > 0) {
+m_skip_ut_block = flashmask_utend_min / kBlockM - flashmask_utstart_max / kBlockM - 1;
+}
+if (flashmask_lt_has_end && flashmask_ltend_min < binfo.actual_seqlen_q) {
+m_skip_lt_block = flashmask_ltend_min / kBlockM - flashmask_ltstart_max / kBlockM - 1;
+}
}

// We might need to exit early and write 0 to dK and dV.
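Note on the hunk above: m_skip_lt_block / m_skip_ut_block count the whole kBlockM row-blocks that lie strictly inside the fully-masked band, between the block holding the band's latest possible start (ltstart_max / utstart_max) and the block holding its earliest guaranteed end (ltend_min / utend_min). A standalone sketch of that arithmetic, in plain C++ with made-up values (none of the numbers come from the kernel):

// Standalone sketch (plain C++, illustrative values): counting the whole
// kBlockM row-blocks that sit strictly inside a fully-masked band.
#include <cstdio>

int main() {
    const int kBlockM = 64;
    // Assumed meanings: latest row at which the masked band can start,
    // and earliest row at which it is guaranteed to have ended.
    int flashmask_ltstart_max = 130;  // falls in row-block 2
    int flashmask_ltend_min = 390;    // falls in row-block 6
    int m_skip_lt_block =
        flashmask_ltend_min / kBlockM - flashmask_ltstart_max / kBlockM - 1;
    printf("skip %d full blocks (blocks %d..%d)\n", m_skip_lt_block,
           flashmask_ltstart_max / kBlockM + 1, flashmask_ltend_min / kBlockM - 1);
    return 0;
}

With kBlockM = 64 and rows 130..390 masked, blocks 3..5 are skipped outright; the boundary blocks 2 and 6 are only partially masked and still processed.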
@@ -1580,7 +1588,13 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
clear(acc_dv);
clear(acc_dk);

-for (; m_block >= m_block_min; --m_block) {
+int double_buffer_cnt = m_block;
+for (; m_block >= m_block_min; --m_block, --double_buffer_cnt) {
+int next_move_blocks = 1;
+if (m_skip_ut_block >= 0 && !Is_causal && m_block == flashmask_utend_min / kBlockM)
+next_move_blocks += m_skip_ut_block;
+if (m_skip_lt_block >= 0 && m_block == flashmask_ltend_min / kBlockM)
+next_move_blocks += m_skip_lt_block;
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_s);
cute::cp_async_wait<0>();
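Note on the new loop header: the sQ shared-memory double buffer flips on iteration parity, but once whole blocks are skipped m_block no longer decreases by exactly 1 per iteration, so its parity can repeat across consecutive iterations. double_buffer_cnt restores a counter that does decrease by exactly 1. A standalone parity sketch, plain C++ with illustrative numbers:

// Standalone sketch (plain C++, made-up numbers): m_block may jump, so buffer
// parity is tracked by a counter that decreases by exactly 1 per iteration.
#include <cstdio>

int main() {
    int m_block = 9, m_block_min = 0;
    int skip_trigger = 5, m_skip = 3;  // pretend blocks 4, 3, 2 are fully masked
    int double_buffer_cnt = m_block;
    for (; m_block >= m_block_min; --m_block, --double_buffer_cnt) {
        // With the jump below, m_block visits 9,8,7,6,5,1,0 (parity repeats at
        // 5 -> 1), while double_buffer_cnt % 2 keeps alternating 1,0,1,0,...
        printf("m_block=%d parity=%d\n", m_block, double_buffer_cnt % 2);
        if (m_block == skip_trigger) m_block -= m_skip;  // mirror of the bottom-of-loop jump
    }
    return 0;
}

next_move_blocks is the matching correction on the global-memory side: it stays 1 on a normal iteration and becomes 1 + skipped-blocks on the iteration that precedes a jump, so every gmem pointer below steps over the skipped rows in one move.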
@@ -1786,11 +1800,11 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params

if (Double_buffer && m_block > m_block_min) {
// Double buffer for sQ
-const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
+const int sQ_offset = double_buffer_cnt % 2 == 0 ? size(sQ) : -size(sQ);
tQsQ.data() = tQsQ.data() + sQ_offset;
tSsQ.data() = tSsQ.data() + sQ_offset;
// Advance gQ
-tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
+tQgQ.data() = tQgQ.data() + (-int(next_move_blocks * kBlockM * params.q_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
flash::cp_async_fence();
}
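The same next_move_blocks scaling is applied to every row-walking global pointer in this loop (tQgQ here, tdOgdO/tdOgO, gLSE/gdPsum and tdQgdQ below). A toy sketch of the pointer arithmetic, assuming a row-major layout and a hypothetical q_row_stride:

// Standalone sketch (plain C++): advancing a row-walking pointer past skipped
// blocks in one step. q_row_stride and the buffer size are stand-ins.
#include <cstdio>

int main() {
    const int kBlockM = 4, q_row_stride = 8;
    float gQ[32 * q_row_stride];           // pretend global Q, 32 rows
    float *tQgQ = gQ + 20 * q_row_stride;  // currently pointing at row 20
    int next_move_blocks = 3;              // one regular step plus two skipped blocks
    tQgQ = tQgQ + (-int(next_move_blocks * kBlockM * q_row_stride));
    printf("now at row %d\n", (int)((tQgQ - gQ) / q_row_stride));  // prints row 8
    return 0;
}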
@@ -1820,9 +1834,9 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params

if (m_block > m_block_min) {
// Advance gdO
-tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
+tdOgdO.data() = tdOgdO.data() + (-int(next_move_blocks * kBlockM * params.do_row_stride));
if (Is_first) {
-tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
+tdOgO.data() = tdOgO.data() + (-int(next_move_blocks * kBlockM * params.o_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
} else {
@@ -1838,10 +1852,10 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
// if (cute::thread0()) { print(acc_dq); }

if (m_block > m_block_min) {
-gLSE.data() = gLSE.data() + (-int(kBlockM));
+gLSE.data() = gLSE.data() + (-int(next_move_blocks * kBlockM));
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
-gdPsum.data() = gdPsum.data() + (-int(kBlockM));
+gdPsum.data() = gdPsum.data() + (-int(next_move_blocks * kBlockM));
// if (!Is_first && tidx < kBlockM / 2) {
// sdPsum(tidx) = recast<float2>(gdPsum)(tidx);
// if (!Is_first && tidx < kBlockM) {
@@ -1862,6 +1876,8 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
+// (GuoxiaWang): skip tdQgdQaccum after atomicAdd
+tdQgdQaccum.data() = tdQgdQaccum.data() + (-int((next_move_blocks - 1) * kBlockM * params.d_rounded));
}
} else {
#pragma unroll
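Note on the (GuoxiaWang) hunk: the dQaccum bump uses next_move_blocks - 1 rather than next_move_blocks, which makes sense if this pointer is already advanced by one kBlockM block per iteration elsewhere in the kernel, leaving only the extra skipped blocks to correct for; treat that as an assumption read off the diff, not a statement about the full file. A toy sketch of the bookkeeping under that assumption:

// Standalone sketch (plain C++): why the dQaccum bump uses
// (next_move_blocks - 1), assuming another code path already advances the
// pointer by one kBlockM block every iteration. Illustrative values only.
#include <cstdio>

int main() {
    const int kBlockM = 4, d_rounded = 8;
    int elems_per_block = kBlockM * d_rounded;
    int offset = 10 * elems_per_block;                        // pretend current position
    int next_move_blocks = 3;
    offset += -int((next_move_blocks - 1) * elems_per_block); // the new explicit skip
    offset += -int(elems_per_block);                          // assumed regular per-iteration advance
    printf("moved back %d blocks total\n", 10 - offset / elems_per_block);  // prints 3
    return 0;
}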
@@ -1878,12 +1894,12 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
}
// if (cute::thread0()) { print(acc_dk); }
if (Double_buffer) { // Double buffer for sQ
-tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
+tdKsQt.data() = tdKsQt.data() + (double_buffer_cnt % 2 == 0 ? size(sQ) : -size(sQ));
}
if (!Double_buffer && m_block > m_block_min) {
__syncthreads();
// Advance gQ
-tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
+tQgQ.data() = tQgQ.data() + (-int(next_move_blocks * kBlockM * params.q_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
flash::cp_async_fence();
}
@@ -1898,7 +1914,7 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
__syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
-tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
+tdQgdQ.data() = tdQgdQ.data() + (-int(next_move_blocks * kBlockM * params.dq_row_stride));
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
#pragma unroll
@@ -1909,6 +1925,12 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
}
}

+if (m_skip_lt_block >= 0 && m_block == flashmask_ltend_min / kBlockM)
+m_block -= m_skip_lt_block;
+
+if (m_skip_ut_block >= 0 && !Is_causal && m_block == flashmask_utend_min / kBlockM) {
+m_block -= m_skip_ut_block;
+}
}
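Design note on the hunk above: the jump is applied at the bottom of the iteration, using the same m_block == *end_min / kBlockM trigger that set next_move_blocks at the top, so the pointer advances and the index jump stay in lockstep; the loop's own --m_block then takes the final single step onto the first unmasked block below the band.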

// Epilogue