@@ -1436,10 +1436,18 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
     int m_block = m_block_max - 1;
     int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
+    int m_skip_lt_block = 0;
+    int m_skip_ut_block = 0;
     if (true /*Is_flashmask*/ && enable_mask_bypass) {
         if (!Is_causal && (!flashmask_ut_has_start || (flashmask_ut_has_start && flashmask_utstart_max <= 0))) {
             m_block_min = max(m_block_min, flashmask_utend_min / kBlockM);
         }
+        if (!Is_causal && flashmask_ut_has_start && flashmask_utstart_max > 0) {
+            m_skip_ut_block = flashmask_utend_min / kBlockM - flashmask_utstart_max / kBlockM - 1;
+        }
+        if (flashmask_lt_has_end && flashmask_ltend_min < binfo.actual_seqlen_q) {
+            m_skip_lt_block = flashmask_ltend_min / kBlockM - flashmask_ltstart_max / kBlockM - 1;
+        }
     }
 
     // We might need to exit early and write 0 to dK and dV.
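Editorial note on the skip counts above: the flashmask_*start_max / flashmask_*end_min values appear to be the conservative extremes of the mask band over the current key/value column block, so end / kBlockM - start / kBlockM - 1 counts the query-row blocks that lie strictly between the block containing the band start and the block containing the band end, i.e. row blocks that are fully masked for every column in this block. A minimal standalone sketch of that arithmetic (a hypothetical host-side helper, not part of the kernel):

    // Hypothetical helper mirroring the m_skip_* computation; start_row/end_row
    // are the (conservative) band boundaries in query-row coordinates.
    inline int fully_masked_row_blocks(int start_row, int end_row, int kBlockM) {
        // Row blocks strictly between the block holding start_row and the block
        // holding end_row lie entirely inside the masked band; 0 or a negative
        // result means there is nothing to bypass.
        return end_row / kBlockM - start_row / kBlockM - 1;
    }
    // Example: kBlockM = 64, band covering rows [70, 400):
    //   400 / 64 - 70 / 64 - 1 = 6 - 1 - 1 = 4 row blocks can be skipped.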
@@ -1580,7 +1588,13 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
     clear(acc_dv);
     clear(acc_dk);
 
-    for (; m_block >= m_block_min; --m_block) {
+    int double_buffer_cnt = m_block;
+    for (; m_block >= m_block_min; --m_block, --double_buffer_cnt) {
+        int next_move_blocks = 1;
+        if (m_skip_ut_block >= 0 && !Is_causal && m_block == flashmask_utend_min / kBlockM)
+            next_move_blocks += m_skip_ut_block;
+        if (m_skip_lt_block >= 0 && m_block == flashmask_ltend_min / kBlockM)
+            next_move_blocks += m_skip_lt_block;
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)
         clear(acc_s);
         cute::cp_async_wait<0>();
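Editorial note: next_move_blocks is how many query-row blocks every Q-indexed global pointer is rewound by at the end of this iteration, and the last hunk of this diff subtracts the same skip from m_block itself, so the next iteration resumes just below the fully masked band. A toy host-side walk of the loop, under assumed values (16 row blocks, the skip triggering at block 10, blocks 6 through 9 fully masked), not the kernel itself:

    #include <cstdio>

    // Toy simulation of the m_block walk: block 10 plays the role of
    // flashmask_*end_min / kBlockM and blocks 9..6 are fully masked.
    int main() {
        const int m_block_max = 16, m_block_min = 0;
        const int trigger_block = 10;
        const int skip_blocks = 4;          // plays the role of m_skip_*_block

        int double_buffer_cnt = m_block_max - 1;
        for (int m_block = m_block_max - 1; m_block >= m_block_min;
             --m_block, --double_buffer_cnt) {
            int next_move_blocks = 1;
            if (m_block == trigger_block) next_move_blocks += skip_blocks;

            std::printf("visit m_block=%2d, rewind pointers by %d block(s)\n",
                        m_block, next_move_blocks);

            // Mirror of the adjustment at the end of the real loop body.
            if (m_block == trigger_block) m_block -= skip_blocks;
        }
        return 0;
    }
    // Visits blocks 15..10, then jumps straight to 5..0; the iteration that
    // visits block 10 rewinds the pointers by 5 blocks instead of 1.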
@@ -1786,11 +1800,11 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
 
         if (Double_buffer && m_block > m_block_min) {
             // Double buffer for sQ
-            const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
+            const int sQ_offset = double_buffer_cnt % 2 == 0 ? size(sQ) : -size(sQ);
             tQsQ.data() = tQsQ.data() + sQ_offset;
             tSsQ.data() = tSsQ.data() + sQ_offset;
             // Advance gQ
-            tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
+            tQgQ.data() = tQgQ.data() + (-int(next_move_blocks * kBlockM * params.q_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
             flash::cp_async_fence();
         }
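Editorial note on switching the parity source from m_block to double_buffer_cnt here (and in the tdKsQt hunk further down): the sQ double buffer must flip between its two halves on every iteration, but once m_block can jump by more than one block its parity no longer alternates per iteration whenever the skip is odd. double_buffer_cnt decrements by exactly one each iteration, so its parity keeps the ping-pong intact. A small standalone illustration with an assumed odd skip of 3 after block 10 (not the kernel):

    #include <cstdio>

    // Toy comparison of the two parity sources when the loop jumps over
    // 3 blocks after visiting block 10.
    int main() {
        int double_buffer_cnt = 15;
        for (int m_block = 15; m_block >= 0; --m_block, --double_buffer_cnt) {
            std::printf("m_block=%2d  m_block%%2=%d  double_buffer_cnt%%2=%d\n",
                        m_block, m_block % 2, double_buffer_cnt % 2);
            if (m_block == 10) m_block -= 3;   // skip blocks 9, 8, 7
        }
        return 0;
    }
    // m_block parity goes ... 1, 0, 0, 1 ... across the jump (block 10 is even
    // and the next visited block, 6, is also even), which would reuse the same
    // sQ half twice in a row; double_buffer_cnt parity keeps alternating.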
@@ -1820,9 +1834,9 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
 
         if (m_block > m_block_min) {
             // Advance gdO
-            tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
+            tdOgdO.data() = tdOgdO.data() + (-int(next_move_blocks * kBlockM * params.do_row_stride));
             if (Is_first) {
-                tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
+                tdOgO.data() = tdOgO.data() + (-int(next_move_blocks * kBlockM * params.o_row_stride));
                 flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
                 flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
             } else {
@@ -1838,10 +1852,10 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
         // if (cute::thread0()) { print(acc_dq); }
 
         if (m_block > m_block_min) {
-            gLSE.data() = gLSE.data() + (-int(kBlockM));
+            gLSE.data() = gLSE.data() + (-int(next_move_blocks * kBlockM));
             #pragma unroll
             for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
-            gdPsum.data() = gdPsum.data() + (-int(kBlockM));
+            gdPsum.data() = gdPsum.data() + (-int(next_move_blocks * kBlockM));
             // if (!Is_first && tidx < kBlockM / 2) {
             //     sdPsum(tidx) = recast<float2>(gdPsum)(tidx);
             // if (!Is_first && tidx < kBlockM) {
@@ -1862,6 +1876,8 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
                 CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
                 #pragma unroll
                 for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
+                // (GuoxiaWang): skip tdQgdQaccum after atomicAdd
+                tdQgdQaccum.data() = tdQgdQaccum.data() + (-int((next_move_blocks - 1) * kBlockM * params.d_rounded));
             }
         } else {
             #pragma unroll
@@ -1878,12 +1894,12 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
         }
         // if (cute::thread0()) { print(acc_dk); }
         if (Double_buffer) {  // Double buffer for sQ
-            tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
+            tdKsQt.data() = tdKsQt.data() + (double_buffer_cnt % 2 == 0 ? size(sQ) : -size(sQ));
         }
         if (!Double_buffer && m_block > m_block_min) {
             __syncthreads();
             // Advance gQ
-            tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
+            tQgQ.data() = tQgQ.data() + (-int(next_move_blocks * kBlockM * params.q_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
             flash::cp_async_fence();
         }
@@ -1898,7 +1914,7 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
             __syncthreads();
             Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
             cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
-            tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
+            tdQgdQ.data() = tdQgdQ.data() + (-int(next_move_blocks * kBlockM * params.dq_row_stride));
             Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
             Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
             #pragma unroll
@@ -1909,6 +1925,12 @@ inline __device__ void compute_dq_dk_dv_1colblock_flashmask(const Params &params
             }
         }
 
+        if (m_skip_lt_block >= 0 && m_block == flashmask_ltend_min / kBlockM)
+            m_block -= m_skip_lt_block;
+
+        if (m_skip_ut_block >= 0 && !Is_causal && m_block == flashmask_utend_min / kBlockM) {
+            m_block -= m_skip_ut_block;
+        }
     }
 
     // Epilogue