
Commit 2bd200f

Merge commit 'e87fdd9ac41f136619a76ee22cc3d4e926e319de'
2 parents: e52429a + e87fdd9

File tree: 20 files changed, +366 −204 lines changed


lib/Analysis/AxisInfo.cpp

Lines changed: 1 addition & 4 deletions
@@ -1109,10 +1109,7 @@ void AxisInfoAnalysis::visitForOpInductionVar(
   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);
   AxisInfo::DimVectorT knownConstancy(1, 1);
-  auto lbDivisibility = lb.getDivisibility();
-  auto stepDivisibility = step.getDivisibility();
-  if (!lbDivisibility.empty() && !stepDivisibility.empty())
-    knownDivisibility[0] = gcd(lbDivisibility[0], stepDivisibility[0]);
+  knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0));
   auto inductionVar =
       AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
   (void)argLattices[0]->join(inductionVar);

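Editor's note: a minimal plain-Python sketch (not part of the commit) of the divisibility rule behind the simplified line. If the loop lower bound is divisible by d_lb and the step by d_step, then every value lb + i * step is divisible by gcd(d_lb, d_step), which is what the new code records for the induction variable's divisibility; the concrete numbers below are illustrative only.

from math import gcd

# Assume the lower bound is known to be divisible by 16 and the step by 24.
lb_div, step_div = 16, 24
iv_div = gcd(lb_div, step_div)  # 8 -- the value stored in knownDivisibility[0]
# Any concrete lb/step with those divisors keeps the induction variable divisible by iv_div.
lb, step = 48, 72
assert all((lb + i * step) % iv_div == 0 for i in range(1000))
print(iv_div)  # 8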

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 54 deletions
@@ -1167,63 +1167,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
   return multiRootTopologicalSort(slice);
 }
 
-namespace {
-// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
-// interacts with constant propagation, but SparseConstantPropagation
-// doesn't seem to be sufficient.
-class ConstantAnalysis : public DataFlowAnalysis {
-public:
-  using DataFlowAnalysis::DataFlowAnalysis;
-
-  LogicalResult initialize(Operation *top) override {
-    WalkResult result = top->walk([&](Operation *op) {
-      ProgramPoint programPoint(op);
-      if (failed(visit(&programPoint)))
-        return WalkResult::interrupt();
-      return WalkResult::advance();
-    });
-    return success(!result.wasInterrupted());
-  }
-
-  LogicalResult visit(ProgramPoint *point) override {
-    Operation *op = point->getOperation();
-    Attribute value;
-    if (matchPattern(op, m_Constant(&value))) {
-      auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
-          op->getResult(0));
-      propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
-                                       value, op->getDialect())));
-      return success();
-    }
-    // Dead code analysis requires every operands has initialized ConstantValue
-    // state before it is visited.
-    // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322
-    // That's why we need to set all operands to unknown constants.
-    setAllToUnknownConstants(op->getResults());
-    for (Region &region : op->getRegions()) {
-      for (Block &block : region.getBlocks())
-        setAllToUnknownConstants(block.getArguments());
-    }
-    return success();
-  }
-
-private:
-  /// Set all given values as not constants.
-  void setAllToUnknownConstants(ValueRange values) {
-    dataflow::ConstantValue unknownConstant(nullptr, nullptr);
-    for (Value value : values) {
-      auto *constant =
-          getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
-      propagateIfChanged(constant, constant->join(unknownConstant));
-    }
-  }
-};
-} // namespace
-
 std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
   auto solver = std::make_unique<DataFlowSolver>();
   solver->load<dataflow::DeadCodeAnalysis>();
-  solver->load<ConstantAnalysis>();
+  solver->load<dataflow::SparseConstantPropagation>();
   return solver;
 }
 
lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 1 deletion
@@ -2754,7 +2754,8 @@ struct TritonGPUInferLayoutInterface
     auto mmaRetEncoding = mlir::dyn_cast<NvidiaMmaEncodingAttr>(retEncoding);
     if (mmaRetEncoding && mmaRetEncoding.isHopper()) {
       auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(operandEncoding);
-      if (!mlir::isa<NVMMASharedEncodingAttr>(operandEncoding) &&
+      if (!mlir::isa<NVMMASharedEncodingAttr, SharedLinearEncodingAttr>(
+              operandEncoding) &&
           !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 &&
             mlir::isa<NvidiaMmaEncodingAttr>(dotOpEnc.getParent()))) {
         return emitOptionalError(

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 5 additions & 4 deletions
@@ -50,7 +50,7 @@ LogicalResult WarpGroupDotOp::inferReturnTypes(
 
   // verify encodings
   auto aEnc = cast<TensorOrMemDesc>(operands[0].getType()).getEncoding();
-  auto bEnc = cast<TensorOrMemDesc>(operands[1].getType()).getEncoding();
+  auto bEnc = cast<MemDescType>(operands[1].getType()).getEncoding();
   auto retEnc = accTy.getEncoding();
   if (aEnc) {
     assert(bEnc);
@@ -70,10 +70,11 @@ LogicalResult WarpGroupDotOp::verify() {
   if (!nvmmaEnc || !nvmmaEnc.isHopper())
     return emitOpError("WGMMA result layout must be Hopper NVMMA");
 
-  if (!isa<NVMMASharedEncodingAttr, DotOperandEncodingAttr>(
-          getA().getType().getEncoding()))
+  if (!isa<NVMMASharedEncodingAttr, DotOperandEncodingAttr,
+           SharedLinearEncodingAttr>(getA().getType().getEncoding()))
     return emitOpError("WGMMA A operand must have NVMMA shared or dot layout");
-  if (!isa<NVMMASharedEncodingAttr>(getB().getType().getEncoding()))
+  if (!isa<NVMMASharedEncodingAttr, SharedLinearEncodingAttr>(
+          getB().getType().getEncoding()))
     return emitOpError("WGMMA B operand must have NVMMA shared layout");
 
   auto numWarps = gpu::lookupNumWarps(getOperation());

python/src/gluon_ir.cc

Lines changed: 12 additions & 0 deletions
@@ -669,6 +669,18 @@ void init_gluon_ir(py::module &&m) {
                      pred, two_ctas, mbarriers,
                      mbarrier_preds);
            })
+      .def("create_tcgen05_mma_scaled",
+           [](GluonOpBuilder &self, Value a, Value b, Value acc, Value aScale,
+              Value bScale, tt::ScaleDotElemType aType,
+              tt::ScaleDotElemType bType, Value useAcc, Value pred,
+              std::vector<Value> &mbarriers,
+              std::vector<Value> &mbarrier_preds) {
+             Value accDep;
+             auto tokType = self.getBuilder().getType<ttg::AsyncTokenType>();
+             self.create<ttng::TCGen5MMAScaledOp>(
+                 tokType, a, b, acc, accDep, aScale, bScale, aType, bType,
+                 useAcc, pred, mbarriers, mbarrier_preds);
+           })
       .def("create_tcgen05_commit",
            [](GluonOpBuilder &self, Value &barrier) {
              self.create<ttng::TCGen5CommitOp>(barrier);

python/src/ir.cc

Lines changed: 2 additions & 0 deletions
@@ -455,6 +455,8 @@ void init_triton_ir(py::module &&m) {
             auto loc = UnknownLoc::get(ty.getContext());
             self.addArgument(ty, loc);
           })
+      .def("add_argument_at", [](Block &self, Type ty,
+                                 Location loc) { self.addArgument(ty, loc); })
       .def("get_num_arguments", &Block::getNumArguments)
       .def("get_argument", &Block::getArgument)
       .def("dump", &Block::dump)

python/test/gluon/test_core.py

Lines changed: 81 additions & 25 deletions
@@ -1,4 +1,5 @@
 import torch
+import math
 import pytest
 import re
 from itertools import product
@@ -126,9 +127,9 @@ def test_async_copy_mbarrier(device):
 
 
 @gluon.jit
-def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr,
-                         block_layout: ttgl.constexpr, mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr,
-                         shared_layout_b: ttgl.constexpr, acc_dtype: ttgl.constexpr, ASYNC: ttgl.constexpr):
+def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, block_layout: ttgl.constexpr,
+               mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr, shared_layout_b: ttgl.constexpr,
+               acc_dtype: ttgl.constexpr, ASYNC: ttgl.constexpr, USE_TCGEN05: ttgl.constexpr):
     a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None]
     a_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :]
     b_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None]
@@ -143,14 +144,37 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
 
     smem_a = ttgl.allocate_shared_memory(operand_dtype, [M, K], shared_layout_a, a_tile)
     smem_b = ttgl.allocate_shared_memory(operand_dtype, [K, N], shared_layout_b, b_tile)
-
     fence_async_shared()
 
-    acc = ttgl.zeros([M, N], dtype=acc_dtype, layout=mma_layout)
-    acc = hopper.warpgroup_mma(smem_a, smem_b, acc, is_async=ASYNC)
+    if USE_TCGEN05:
+        tmem_layout: ttgl.constexpr = TensorMemoryLayout((M, N), col_stride=32 // acc_dtype.primitive_bitwidth)
+
+        num_warps: ttgl.constexpr = ttgl.num_warps()
+        tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(
+            M=M,
+            N=N,
+            shape=[M, N],
+            num_warps=num_warps,
+        )
+
+        mma_barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+        mbarrier.init(mma_barrier, count=1)
+
+        acc_zero = ttgl.zeros([M, N], dtype=acc_dtype, layout=tmem_reg_layout)
+        acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], tmem_layout, acc_zero)
 
-    if ASYNC:
-        acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])
+        tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=False)
+        tcgen05_commit(mma_barrier)
+        mbarrier.wait(mma_barrier, phase=0)
+        mbarrier.invalidate(mma_barrier)
+        acc = acc_tmem.load(tmem_reg_layout)
+        acc = ttgl.convert_layout(acc, layout=mma_layout)
+    else:
+        acc = ttgl.zeros([M, N], dtype=acc_dtype, layout=mma_layout)
+        acc = hopper.warpgroup_mma(smem_a, smem_b, acc, is_async=ASYNC)
+
+        if ASYNC:
+            acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])
 
     ttgl.store(out + out_offs_m * N + out_offs_n, acc)
 
@@ -168,7 +192,7 @@ def test_warpgroup_mma(ASYNC):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16)
     out = torch.zeros((M, N), device="cuda", dtype=torch.float16)
-    warpgroup_mma_kernel[(1, )](
+    mma_kernel[(1, )](
         a,
         b,
         out,
@@ -181,6 +205,7 @@ def test_warpgroup_mma(ASYNC):
         shared_layout_b,
         ttgl.float16,
         ASYNC,
+        False,
         num_warps=warps[0] * warps[1],
     )
 
@@ -189,19 +214,24 @@ def test_warpgroup_mma(ASYNC):
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
 
 
-@pytest.mark.xfail(not is_hopper(), reason="Requires Hopper", run=False)
+@pytest.mark.xfail(not (is_hopper() or is_blackwell()), reason="Requires Hopper or Blackwell", run=False)
 @pytest.mark.parametrize("bitwidth, transpose_a, transpose_b, acc_dtype",
                          [(bitwidth, transpose_a, transpose_b, acc_dtype)
                           for bitwidth in [8, 16, 32]
                           for (transpose_a, transpose_b) in product([False, True], repeat=2)
                           for acc_dtype in [torch.float16, torch.float32]
                           if bitwidth == 16 or (acc_dtype == torch.float32 and not transpose_a and transpose_b)])
 @pytest.mark.parametrize("warps", ([8, 1], [4, 2], [4, 1]))
-# Swizzling 0 does not map to a valid memory descriptor lol
-@pytest.mark.parametrize("swizzling_a, swizzling_b", product([32, 64, 128], repeat=2))
+@pytest.mark.parametrize("swizzling_a, swizzling_b", product([0, 32, 64, 128], repeat=2))
 @pytest.mark.parametrize("shape_m, shape_n, shape_k", [(1, 1, 1), (2, 4, 1), (2, 2, 4)])
-def test_warpgroup_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dtype, warps, swizzling_a, swizzling_b,
-                                     shape_m, shape_n, shape_k):
+def test_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dtype, warps, swizzling_a, swizzling_b, shape_m,
+                           shape_n, shape_k, fresh_knobs):
+
+    # FIXME: Workaround for a bug in PTXAS when the shared layout is transposed and the swizzling is 0
+    if bitwidth == 16 and ((transpose_a and swizzling_a == 0 and shape_m > 1) or
+                           (not transpose_b and swizzling_b == 0 and shape_n > 1)):
+        fresh_knobs.nvidia.disable_ptxas_opt = True
+    use_tcgen05 = is_blackwell()
 
     torch_dtype_map = {
         8: torch.float8_e4m3fn,
@@ -214,8 +244,7 @@ def test_warpgroup_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dty
     }
 
     # We'll choose a larger instr shape along N, but sure
-    instr_shape_k_map = {8: 32, 16: 16, 32: 8}
-    instr_shape = [16, 32, instr_shape_k_map[bitwidth]]
+    instr_shape = [16, 32, 256 // bitwidth]
    M = instr_shape[0] * warps[0]
    N = instr_shape[1] * warps[1]
    K = instr_shape[2]
@@ -239,7 +268,27 @@ def min_shape(swizzling, dim0, dim1, trans):
     K *= shape_k
     instr_shape[1] *= shape_n
 
-    shared_mem_accum = M * K * bitwidth // 8 + K * N * bitwidth // 8
+    if use_tcgen05:
+        M = 128
+
+    def get_shared_swizzling_zero(M, K, transpose):
+        # K-contig
+        if transpose:
+            K, M = M, K
+        bases = []
+        for i in range(int(math.log2(128 // bitwidth))):
+            bases.append([0, 1 << i])
+        for i in range(int(math.log2(M))):
+            bases.append([1 << i, 0])
+        for i in range(int(math.log2(K // (128 // bitwidth)))):
+            offset = int(math.log2(128 // bitwidth)) + i
+            bases.append([0, 1 << offset])
+        if transpose:
+            for i in range(len(bases)):
+                bases[i] = [bases[i][1], bases[i][0]]
+        return ttgl.SharedLinearLayout(bases)
+
+    shared_mem_accum = (M + N) * K * bitwidth // 8
     if triton.runtime.driver.active.utils.get_device_properties(
             triton.runtime.driver.active.get_current_device())["max_shared_mem"] < shared_mem_accum:
         pytest.skip("Skipped due to insufficient shared memory on this GPU.")

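Editor's note: the swizzle-0 helper above builds a SharedLinearLayout whose bases first cover one 128-bit contiguous chunk along K, then the M dimension, then the remaining K bits (with the coordinates mirrored when transposed). Below is a standalone re-derivation of those bases for one concrete, purely illustrative case (bitwidth=16, M=16, K=32, no transpose); it only repeats the helper's arithmetic so the result can be inspected without ttgl. As an aside, the new shared_mem_accum expression (M + N) * K * bitwidth // 8 is identical to the previous two-term sum for the 8/16/32-bit cases exercised here.

import math

bitwidth, M, K = 16, 16, 32  # illustrative sizes only
bases = []
for i in range(int(math.log2(128 // bitwidth))):        # 128 bits contiguous along K
    bases.append([0, 1 << i])
for i in range(int(math.log2(M))):                       # then the M dimension
    bases.append([1 << i, 0])
for i in range(int(math.log2(K // (128 // bitwidth)))):  # remaining K
    offset = int(math.log2(128 // bitwidth)) + i
    bases.append([0, 1 << offset])
print(bases)
# [[0, 1], [0, 2], [0, 4], [1, 0], [2, 0], [4, 0], [8, 0], [0, 8], [0, 16]]
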
@@ -248,11 +297,17 @@ def min_shape(swizzling, dim0, dim1, trans):
     gl_acc_dtype = acc_dtype_map[acc_dtype]
     out_dtype = torch.float32
 
-    block_layout = ttgl.BlockedLayout([1, 1], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0])
-    shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2,
-                                             transposed=transpose_a)
-    shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2,
-                                             transposed=transpose_b)
+    block_layout = ttgl.BlockedLayout([1, 8], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0])
+    if swizzling_a == 0:
+        shared_layout_a = get_shared_swizzling_zero(M, K, transpose_a)
+    else:
+        shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2,
+                                                 transposed=transpose_a)
+    if swizzling_b == 0:
+        shared_layout_b = get_shared_swizzling_zero(K, N, transpose_b)
+    else:
+        shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2,
+                                                 transposed=transpose_b)
     mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=warps, instr_shape=instr_shape)
 
     torch.manual_seed(0)
@@ -271,7 +326,7 @@ def cast(x, dtype):
     b = cast(torch.randn((K, N), device="cuda", dtype=torch.float32), torch_dtype)
     out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
 
-    warpgroup_mma_kernel[(1, )](
+    mma_kernel[(1, )](
         a,
         b,
         out,
@@ -284,6 +339,7 @@ def cast(x, dtype):
         shared_layout_b,
         gl_acc_dtype,
         False,
+        use_tcgen05,
         num_warps=warps[0] * warps[1],
     )
 
@@ -298,9 +354,9 @@ def cast(x, dtype):
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow_fp16_red
 
     if bitwidth == 8:
-        atol, rtol = 0.5, 0.5
+        atol, rtol = 5e-2, 5e-1
     elif bitwidth == 16:
-        atol, rtol = 3e-2, 1e-1
+        atol, rtol = 5e-2, 5e-1
     else:
         atol, rtol = 5e-4, 5e-3
     torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)
