@@ -1,4 +1,5 @@
 import torch
+import math
 import pytest
 import re
 from itertools import product
@@ -126,9 +127,9 @@ def test_async_copy_mbarrier(device):
 
 
 @gluon.jit
-def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr,
-                         block_layout: ttgl.constexpr, mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr,
-                         shared_layout_b: ttgl.constexpr, acc_dtype: ttgl.constexpr, ASYNC: ttgl.constexpr):
+def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, block_layout: ttgl.constexpr,
+               mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr, shared_layout_b: ttgl.constexpr,
+               acc_dtype: ttgl.constexpr, ASYNC: ttgl.constexpr, USE_TCGEN05: ttgl.constexpr):
     a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None]
     a_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :]
     b_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None]
@@ -143,14 +144,37 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
 
     smem_a = ttgl.allocate_shared_memory(operand_dtype, [M, K], shared_layout_a, a_tile)
     smem_b = ttgl.allocate_shared_memory(operand_dtype, [K, N], shared_layout_b, b_tile)
-
     fence_async_shared()
 
-    acc = ttgl.zeros([M, N], dtype=acc_dtype, layout=mma_layout)
-    acc = hopper.warpgroup_mma(smem_a, smem_b, acc, is_async=ASYNC)
+    if USE_TCGEN05:
+        tmem_layout: ttgl.constexpr = TensorMemoryLayout((M, N), col_stride=32 // acc_dtype.primitive_bitwidth)
+
+        num_warps: ttgl.constexpr = ttgl.num_warps()
+        tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(
+            M=M,
+            N=N,
+            shape=[M, N],
+            num_warps=num_warps,
+        )
+
+        mma_barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+        mbarrier.init(mma_barrier, count=1)
+
+        acc_zero = ttgl.zeros([M, N], dtype=acc_dtype, layout=tmem_reg_layout)
+        acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], tmem_layout, acc_zero)
 
-    if ASYNC:
-        acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])
+        tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=False)
+        tcgen05_commit(mma_barrier)
+        mbarrier.wait(mma_barrier, phase=0)
+        mbarrier.invalidate(mma_barrier)
+        acc = acc_tmem.load(tmem_reg_layout)
+        acc = ttgl.convert_layout(acc, layout=mma_layout)
+    else:
+        acc = ttgl.zeros([M, N], dtype=acc_dtype, layout=mma_layout)
+        acc = hopper.warpgroup_mma(smem_a, smem_b, acc, is_async=ASYNC)
+
+        if ASYNC:
+            acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])
 
     ttgl.store(out + out_offs_m * N + out_offs_n, acc)
 
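A side note on the `col_stride` passed to the tensor-memory layout above (a plain-arithmetic sketch, not part of the diff; the helper name is made up for illustration): `32 // acc_dtype.primitive_bitwidth` is 1 for a float32 accumulator and 2 for a float16 one, i.e. it equals the number of accumulator elements that fit in one 32-bit tensor-memory column.

    # Hypothetical helper mirroring the kernel's expression; only the arithmetic
    # is assumed here, not any particular tensor-memory API.
    def tmem_col_stride(primitive_bitwidth: int) -> int:
        return 32 // primitive_bitwidth

    assert tmem_col_stride(32) == 1  # float32 accumulator
    assert tmem_col_stride(16) == 2  # float16 accumulator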
@@ -168,7 +192,7 @@ def test_warpgroup_mma(ASYNC):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16)
     out = torch.zeros((M, N), device="cuda", dtype=torch.float16)
-    warpgroup_mma_kernel[(1, )](
+    mma_kernel[(1, )](
         a,
         b,
         out,
@@ -181,6 +205,7 @@ def test_warpgroup_mma(ASYNC):
         shared_layout_b,
         ttgl.float16,
         ASYNC,
+        False,
         num_warps=warps[0] * warps[1],
     )
 
@@ -189,19 +214,24 @@ def test_warpgroup_mma(ASYNC):
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
 
 
-@pytest.mark.xfail(not is_hopper(), reason="Requires Hopper", run=False)
+@pytest.mark.xfail(not (is_hopper() or is_blackwell()), reason="Requires Hopper or Blackwell", run=False)
 @pytest.mark.parametrize("bitwidth, transpose_a, transpose_b, acc_dtype",
                          [(bitwidth, transpose_a, transpose_b, acc_dtype)
                           for bitwidth in [8, 16, 32]
                           for (transpose_a, transpose_b) in product([False, True], repeat=2)
                           for acc_dtype in [torch.float16, torch.float32]
                           if bitwidth == 16 or (acc_dtype == torch.float32 and not transpose_a and transpose_b)])
 @pytest.mark.parametrize("warps", ([8, 1], [4, 2], [4, 1]))
-# Swizzling 0 does not map to a valid memory descriptor lol
-@pytest.mark.parametrize("swizzling_a, swizzling_b", product([32, 64, 128], repeat=2))
+@pytest.mark.parametrize("swizzling_a, swizzling_b", product([0, 32, 64, 128], repeat=2))
 @pytest.mark.parametrize("shape_m, shape_n, shape_k", [(1, 1, 1), (2, 4, 1), (2, 2, 4)])
-def test_warpgroup_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dtype, warps, swizzling_a, swizzling_b,
-                                     shape_m, shape_n, shape_k):
+def test_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dtype, warps, swizzling_a, swizzling_b, shape_m,
+                           shape_n, shape_k, fresh_knobs):
+
+    # FIXME: Workaround for a bug in PTXAS when the shared layout is transposed and the swizzling is 0
+    if bitwidth == 16 and ((transpose_a and swizzling_a == 0 and shape_m > 1) or
+                           (not transpose_b and swizzling_b == 0 and shape_n > 1)):
+        fresh_knobs.nvidia.disable_ptxas_opt = True
+    use_tcgen05 = is_blackwell()
 
     torch_dtype_map = {
         8: torch.float8_e4m3fn,
@@ -214,8 +244,7 @@ def test_warpgroup_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dty
     }
 
     # We'll choose a larger instr shape along N, but sure
-    instr_shape_k_map = {8: 32, 16: 16, 32: 8}
-    instr_shape = [16, 32, instr_shape_k_map[bitwidth]]
+    instr_shape = [16, 32, 256 // bitwidth]
     M = instr_shape[0] * warps[0]
     N = instr_shape[1] * warps[1]
     K = instr_shape[2]
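As a quick cross-check (plain Python, not part of the diff): the closed-form `256 // bitwidth` reproduces exactly the K extents that the removed `instr_shape_k_map` table hard-coded.

    # The old lookup table and the new expression agree for every supported bitwidth.
    instr_shape_k_map = {8: 32, 16: 16, 32: 8}
    for bitwidth, k in instr_shape_k_map.items():
        assert 256 // bitwidth == k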
@@ -239,7 +268,27 @@ def min_shape(swizzling, dim0, dim1, trans):
     K *= shape_k
     instr_shape[1] *= shape_n
 
-    shared_mem_accum = M * K * bitwidth // 8 + K * N * bitwidth // 8
+    if use_tcgen05:
+        M = 128
+
+    def get_shared_swizzling_zero(M, K, transpose):
+        # K-contig
+        if transpose:
+            K, M = M, K
+        bases = []
+        for i in range(int(math.log2(128 // bitwidth))):
+            bases.append([0, 1 << i])
+        for i in range(int(math.log2(M))):
+            bases.append([1 << i, 0])
+        for i in range(int(math.log2(K // (128 // bitwidth)))):
+            offset = int(math.log2(128 // bitwidth)) + i
+            bases.append([0, 1 << offset])
+        if transpose:
+            for i in range(len(bases)):
+                bases[i] = [bases[i][1], bases[i][0]]
+        return ttgl.SharedLinearLayout(bases)
+
+    shared_mem_accum = (M + N) * K * bitwidth // 8
     if triton.runtime.driver.active.utils.get_device_properties(
             triton.runtime.driver.active.get_current_device())["max_shared_mem"] < shared_mem_accum:
         pytest.skip("Skipped due to insufficient shared memory on this GPU.")
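For reference, here is a standalone trace of the base construction in `get_shared_swizzling_zero` above (my own sketch, not part of the diff, and it only re-runs the same arithmetic): with bitwidth = 16, M = K = 16 and transpose = False, the helper describes an unswizzled, K-contiguous 16x16 tile.

    import math

    bitwidth, M, K = 16, 16, 16
    bases = []
    for i in range(int(math.log2(128 // bitwidth))):         # contiguous K chunk: [0,1], [0,2], [0,4]
        bases.append([0, 1 << i])
    for i in range(int(math.log2(M))):                        # M dimension: [1,0] ... [8,0]
        bases.append([1 << i, 0])
    for i in range(int(math.log2(K // (128 // bitwidth)))):   # remaining K: [0,8]
        bases.append([0, 1 << (int(math.log2(128 // bitwidth)) + i)])
    print(bases)
    # [[0, 1], [0, 2], [0, 4], [1, 0], [2, 0], [4, 0], [8, 0], [0, 8]]
    # 8 bases -> 2**8 = 256 offsets, exactly one per element of the 16x16 tile.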
@@ -248,11 +297,17 @@ def min_shape(swizzling, dim0, dim1, trans):
     gl_acc_dtype = acc_dtype_map[acc_dtype]
     out_dtype = torch.float32
 
-    block_layout = ttgl.BlockedLayout([1, 1], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0])
-    shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2,
-                                             transposed=transpose_a)
-    shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2,
-                                             transposed=transpose_b)
+    block_layout = ttgl.BlockedLayout([1, 8], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0])
+    if swizzling_a == 0:
+        shared_layout_a = get_shared_swizzling_zero(M, K, transpose_a)
+    else:
+        shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2,
+                                                 transposed=transpose_a)
+    if swizzling_b == 0:
+        shared_layout_b = get_shared_swizzling_zero(K, N, transpose_b)
+    else:
+        shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2,
+                                                 transposed=transpose_b)
     mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=warps, instr_shape=instr_shape)
 
     torch.manual_seed(0)
@@ -271,7 +326,7 @@ def cast(x, dtype):
     b = cast(torch.randn((K, N), device="cuda", dtype=torch.float32), torch_dtype)
     out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
 
-    warpgroup_mma_kernel[(1, )](
+    mma_kernel[(1, )](
         a,
         b,
         out,
@@ -284,6 +339,7 @@ def cast(x, dtype):
         shared_layout_b,
         gl_acc_dtype,
         False,
+        use_tcgen05,
         num_warps=warps[0] * warps[1],
     )
 
@@ -298,9 +354,9 @@ def cast(x, dtype):
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow_fp16_red
 
     if bitwidth == 8:
-        atol, rtol = 0.5, 0.5
+        atol, rtol = 5e-2, 5e-1
     elif bitwidth == 16:
-        atol, rtol = 3e-2, 1e-1
+        atol, rtol = 5e-2, 5e-1
     else:
         atol, rtol = 5e-4, 5e-3
     torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)