@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
91
91
SmemLayoutAtomQ{},
92
92
Shape<Int<kBlockN >, Int<kHeadDim >>{}));
93
93
94
+ // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
95
+ using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem >, Int<kBlockN >>,
96
+ Stride<_1, Int<kBlockKSmem >>>;
94
97
using SmemLayoutAtomVtransposed = decltype(
95
- composition (Swizzle<kSwizzle , 3 , 3 >{},
96
- // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
97
- Layout<Shape<Int<kBlockKSmem >, Int<kBlockN >>,
98
- Stride<_1, Int<kBlockKSmem >>>{}));
98
+ composition (Swizzle<kSwizzle , 3 , 3 >{}, SmemLayoutAtomVtransposedNoSwizzle{}));
99
99
using SmemLayoutVtransposed = decltype(tile_to_shape(
100
100
SmemLayoutAtomVtransposed{},
101
101
Shape<Int<kHeadDim >, Int<kBlockN >>{}));
102
102
// Maybe SmemLayoutVtransposedNoSwizzle just needs to have the right shape
103
103
// And the strides don't matter?
104
- using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
104
+ using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
105
+ SmemLayoutAtomVtransposedNoSwizzle{},
106
+ Shape<Int<kHeadDim >, Int<kBlockN >>{}));
107
+ // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
105
108
106
109
using SmemLayoutAtomO = decltype(
107
110
composition (Swizzle<kSwizzle , 3 , 3 >{},
@@ -223,16 +226,19 @@ struct Flash_bwd_kernel_traits : public Base {
223
226
SmemLayoutAtomKV{},
224
227
make_shape (Int<kBlockN >{}, Int<kHeadDim >{})));
225
228
229
+ using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem >, Int<kBlockN >>,
230
+ Stride<_1, Int<kBlockKSmem >>>;
226
231
using SmemLayoutAtomKtransposed = decltype(
227
- composition (Swizzle<kSwizzle , 3 , 3 >{},
228
- Layout<Shape<Int<kBlockKSmem >, Int<kBlockN >>,
229
- Stride<_1, Int<kBlockKSmem >>>{}));
232
+ composition (Swizzle<kSwizzle , 3 , 3 >{}, SmemLayoutAtomKtransposedNoSwizzle{}));
230
233
using SmemLayoutKtransposed = decltype(tile_to_shape(
231
234
SmemLayoutAtomKtransposed{},
232
235
make_shape (Int<kHeadDim >{}, Int<kBlockN >{})));
233
236
// Maybe SmemLayoutKtransposedNoSwizzle just needs to have the right shape
234
237
// And the strides don't matter?
235
- using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
238
+ using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
239
+ SmemLayoutAtomKtransposedNoSwizzle{},
240
+ make_shape (Int<kHeadDim >{}, Int<kBlockN >{})));
241
+ // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
236
242
237
243
// TODO: generalize to other values of kBlockN
238
244
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -250,24 +256,30 @@ struct Flash_bwd_kernel_traits : public Base {
250
256
using SmemLayoutPdS = decltype(tile_to_shape(
251
257
SmemLayoutAtomPdS{},
252
258
make_shape (Int<kBlockM >{}, Int<kBlockN >{})));
259
+ using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN >, Int<kBlockM >>,
260
+ Stride<_1, Int<kPBlockN >>>;
253
261
using SmemLayoutAtomPdStransposed = decltype(
254
- composition (Swizzle<kSwizzlePdS , 3 , 3 >{},
255
- Layout<Shape<Int<kPBlockN >, Int<kBlockM >>,
256
- Stride<_1, Int<kPBlockN >>>{}));
262
+ composition (Swizzle<kSwizzlePdS , 3 , 3 >{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
257
263
using SmemLayoutPdStransposed = decltype(tile_to_shape(
258
264
SmemLayoutAtomPdStransposed{},
259
265
make_shape (Int<kBlockN >{}, Int<kBlockM >{})));
260
- using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
266
+ using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
267
+ SmemLayoutAtomPdStransposedNoSwizzle{},
268
+ make_shape (Int<kBlockN >{}, Int<kBlockM >{})));
269
+ // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
261
270
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
262
271
272
+ using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem >, Int<kBlockM >>,
273
+ Stride<_1, Int<kBlockKSmem >>>;
263
274
using SmemLayoutAtomQdOtransposed = decltype(
264
- composition (Swizzle<kSwizzle , 3 , 3 >{},
265
- Layout<Shape<Int<kBlockKSmem >, Int<kBlockM >>,
266
- Stride<_1, Int<kBlockKSmem >>>{}));
275
+ composition (Swizzle<kSwizzle , 3 , 3 >{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
267
276
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
268
277
SmemLayoutAtomQdOtransposed{},
269
278
make_shape (Int<kHeadDim >{}, Int<kBlockM >{})));
270
- using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
279
+ using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
280
+ SmemLayoutAtomQdOtransposedNoSwizzle{},
281
+ make_shape (Int<kHeadDim >{}, Int<kBlockM >{})));
282
+ // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
271
283
272
284
using SmemLayoutAtomdKV = decltype(
273
285
composition (Swizzle<kSwizzle , 3 , 3 >{},
0 commit comments