Skip to content

Commit dbd7923

Browse files
committed
Prepare for Cutlass 3.2
1 parent c5e87b1 commit dbd7923

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

csrc/flash_attn/src/kernel_traits.h

Lines changed: 29 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
9191
SmemLayoutAtomQ{},
9292
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
9393

94+
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
95+
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
96+
Stride<_1, Int<kBlockKSmem>>>;
9497
using SmemLayoutAtomVtransposed = decltype(
95-
composition(Swizzle<kSwizzle, 3, 3>{},
96-
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
97-
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
98-
Stride<_1, Int<kBlockKSmem>>>{}));
98+
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
9999
using SmemLayoutVtransposed = decltype(tile_to_shape(
100100
SmemLayoutAtomVtransposed{},
101101
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
102102
// Maybe SmemLayoutVtransposedNoSwizzle just needs to have the right shape
103103
// And the strides don't matter?
104-
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
104+
using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
105+
SmemLayoutAtomVtransposedNoSwizzle{},
106+
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
107+
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
105108

106109
using SmemLayoutAtomO = decltype(
107110
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -223,16 +226,19 @@ struct Flash_bwd_kernel_traits : public Base {
223226
SmemLayoutAtomKV{},
224227
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
225228

229+
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
230+
Stride<_1, Int<kBlockKSmem>>>;
226231
using SmemLayoutAtomKtransposed = decltype(
227-
composition(Swizzle<kSwizzle, 3, 3>{},
228-
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
229-
Stride<_1, Int<kBlockKSmem>>>{}));
232+
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
230233
using SmemLayoutKtransposed = decltype(tile_to_shape(
231234
SmemLayoutAtomKtransposed{},
232235
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
233236
// Maybe SmemLayoutKtransposedNoSwizzle just needs to have the right shape
234237
// And the strides don't matter?
235-
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
238+
using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
239+
SmemLayoutAtomKtransposedNoSwizzle{},
240+
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
241+
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
236242

237243
// TODO: generalize to other values of kBlockN
238244
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -250,24 +256,30 @@ struct Flash_bwd_kernel_traits : public Base {
250256
using SmemLayoutPdS = decltype(tile_to_shape(
251257
SmemLayoutAtomPdS{},
252258
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
259+
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
260+
Stride<_1, Int<kPBlockN>>>;
253261
using SmemLayoutAtomPdStransposed = decltype(
254-
composition(Swizzle<kSwizzlePdS, 3, 3>{},
255-
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
256-
Stride<_1, Int<kPBlockN>>>{}));
262+
composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
257263
using SmemLayoutPdStransposed = decltype(tile_to_shape(
258264
SmemLayoutAtomPdStransposed{},
259265
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
260-
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
266+
using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
267+
SmemLayoutAtomPdStransposedNoSwizzle{},
268+
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
269+
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
261270
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
262271

272+
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
273+
Stride<_1, Int<kBlockKSmem>>>;
263274
using SmemLayoutAtomQdOtransposed = decltype(
264-
composition(Swizzle<kSwizzle, 3, 3>{},
265-
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
266-
Stride<_1, Int<kBlockKSmem>>>{}));
275+
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
267276
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
268277
SmemLayoutAtomQdOtransposed{},
269278
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
270-
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
279+
using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
280+
SmemLayoutAtomQdOtransposedNoSwizzle{},
281+
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
282+
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
271283

272284
using SmemLayoutAtomdKV = decltype(
273285
composition(Swizzle<kSwizzle, 3, 3>{},

csrc/flash_attn/src/utils.h

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
228228
static_assert(decltype(size<0>(acc_layout))::value == 4);
229229
static_assert(decltype(rank(acc_layout))::value == 3);
230230
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
231-
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
231+
// TD [2023-08-13]: For reasons not yet understood, get<0, 1>(l) does not compile with Cutlass 3.2; it fails with
232+
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
233+
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
234+
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
232235
};
233236

234237
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -244,9 +247,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
244247
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
245248
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
246249
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
247-
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
248-
get<0, 1>(l),
249-
get<1, 1, 1>(l));
250+
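// TD [2023-08-13]: The multi-index get<I, J>(l) form hits the same "conversion to inaccessible base class" error on Cutlass 3.2 as in convert_layout_acc_rowcol, so nested get<> calls are used instead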
251+
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
252+
// get<0, 1>(l),
253+
// get<1, 1, 1>(l));
254+
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
255+
get<1>(get<0>(l)),
256+
get<1>(get<1>(get<1>(l))));
250257
};
251258

252259
////////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)