@@ -57,7 +57,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
%65 = tt.splat %64 : i32 -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%66 = arith.cmpi slt, %38, %65 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%67 = tt.broadcast %66 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<128x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
- // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 2
+ // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
%68 = tt.load %60, %67, %cst_3 {ttig.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%74 = tt.addptr %60, %cst_0 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%76 = arith.addi %58, %c1_i32 : i32
@@ -69,72 +69,6 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
// -----
- #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
- module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
- tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} {
- %c63_i32 = arith.constant 63 : i32
- %c255_i32 = arith.constant 255 : i32
- %c127_i32 = arith.constant 127 : i32
- %c1_i32 = arith.constant 1 : i32
- %c0_i32 = arith.constant 0 : i32
- %c64_i32 = arith.constant 64 : i32
- %c8_i32 = arith.constant 8 : i32
- %c128_i32 = arith.constant 128 : i32
- %c256_i32 = arith.constant 256 : i32
- %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
- %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %0 = tt.get_program_id x : i32
- %1 = arith.addi %arg3, %c127_i32 : i32
- %2 = arith.divsi %1, %c128_i32 : i32
- %3 = arith.addi %arg4, %c255_i32 : i32
- %4 = arith.divsi %3, %c256_i32 : i32
- %5 = arith.muli %4, %c8_i32 : i32
- %6 = arith.divsi %0, %5 : i32
- %7 = arith.muli %6, %c8_i32 : i32
- %8 = arith.subi %2, %7 : i32
- %9 = arith.minsi %8, %c8_i32 : i32
- %12 = arith.remsi %0, %5 : i32
- %13 = arith.divsi %12, %9 : i32
- %15 = arith.muli %13, %c256_i32 : i32
- %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
- %24 = tt.splat %15 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
- %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
- %31 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
- %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
- %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %cst_2 = arith.constant dense<512> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %47 = arith.muli %45, %cst_2 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %48 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %49 = tt.broadcast %47 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %50 = tt.broadcast %48 : tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %51 = arith.addi %49, %50 : tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %52 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %53 = tt.addptr %52, %51 : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %54 = arith.addi %arg5, %c63_i32 : i32
- %55 = arith.divsi %54, %c64_i32 : i32
- %56 = arith.muli %arg7, %c64_i32 : i32
- %57 = tt.splat %56 : i32 -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- cf.br ^bb1(%c0_i32, %53 : i32, tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>)
- ^bb1(%58: i32, %61: tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>):  // 2 preds: ^bb0, ^bb2
- %62 = arith.cmpi slt, %58, %55 : i32
- cf.cond_br %62, ^bb2, ^bb3
- ^bb2:  // pred: ^bb1
- %63 = arith.muli %58, %c64_i32 : i32
- %64 = arith.subi %arg5, %63 : i32
- %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
- %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
- %76 = arith.addi %58, %c1_i32 : i32
- cf.br ^bb1(%76, %75 : i32, tensor<64x256x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>)
- ^bb3:  // pred: ^bb1
- tt.return
- }
- }
-
- // -----
-
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
#mma_1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1]}>
#mma_2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [4, 2]}>
@@ -259,6 +193,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
// CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK-COUNT-2: llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8
@@ -267,22 +202,25 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
// CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
// CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+ // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
+ // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
// CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]])
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
// CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
// CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+ // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
+ // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
// CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]])
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
// CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
// CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+ // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8