10 changes: 10 additions & 0 deletions include/triton/Dialect/Gluon/Transforms/Passes.td
@@ -35,4 +35,14 @@ def GluonInline: Pass<"gluon-inline"> {
let dependentDialects = [];
}

def GluonSimplifyControlFlow: Pass<"gluon-simplify-control-flow"> {
let summary = "simplifications for control flow ops";

let description = [{
The `gluon-simplify-control-flow` pass applies a reduced set of simplification
and canonicalization patterns for control flow ops to the module.
}];
let dependentDialects = [];
}

#endif
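For context, a minimal sketch of wiring the new pass into a pipeline; it assumes only the generated `createGluonSimplifyControlFlow()` factory (in `mlir::triton::gluon`, as used in `Inline.cpp` below) and the standard MLIR pass-manager headers:

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/Gluon/Transforms/Passes.h"

// Hypothetical helper: append the reduced control-flow cleanup to a pipeline.
static void addGluonControlFlowCleanup(mlir::OpPassManager &pm) {
  pm.addPass(mlir::triton::gluon::createGluonSimplifyControlFlow());
}
```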
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -135,6 +135,11 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<unsigned> tilesPerWarp,
ArrayRef<unsigned> warpsPerCTA);

LinearLayout chooseScaledWmmaScaleLayout(
MLIRContext *ctx, int dotOperandIdx,
const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
ArrayRef<int64_t> dotOperandShape);

LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<int64_t> dotOperandShape,
ArrayRef<unsigned> tilesPerWarp,
3 changes: 1 addition & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1307,8 +1307,7 @@ Row |
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kDim, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

static SmallVector<unsigned, 3> getDefaultInstrShape() {
19 changes: 19 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -574,4 +574,23 @@ def TTG_WarpReturnOp : TTG_Op<"warp_return", [
let assemblyFormat = "attr-dict";
}

def TTG_LocalBarrierOp : TTG_Op<"local_barrier"> {
let summary = "Synchronizes execution and shared memory reads/writes for all threads in a CTA.";
let description = [{
The `local_barrier` op synchronizes execution and all operations between
shared memory and registers for all threads in a CTA.
It is used to coordinate communication between the threads of the CTA.

This operation waits until all threads in the CTA have reached a `local_barrier`,
and all operations between shared memory and registers made by these threads prior
to the op are visible to all threads in the CTA.

Data hazards between threads accessing the same memory can be avoided by
synchronizing the CTA with a `local_barrier` in between these accesses.

A `local_barrier` operation does not provide synchronization guarantees on global memory.
}];
let assemblyFormat = "attr-dict";
}

#endif // TRITONGPU_OPS
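As a usage sketch (assuming standard MLIR `OpBuilder` APIs and that the op is pulled in via the TritonGPU dialect headers), inserting the op programmatically looks like this; the Membar change below does exactly that:

```cpp
#include "mlir/IR/Builders.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Hypothetical helper: place a CTA-local barrier right after `op`.
static void insertLocalBarrierAfter(mlir::Operation *op) {
  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfter(op);
  // ttg.local_barrier fences register <-> shared memory traffic within the
  // CTA only; it gives no ordering guarantees on global memory.
  builder.create<mlir::triton::gpu::LocalBarrierOp>(op->getLoc());
}
```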
6 changes: 3 additions & 3 deletions lib/Analysis/Membar.cpp
@@ -159,20 +159,20 @@ void MembarOrFenceAnalysis::visitTerminator(

void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
OpBuilder::InsertionGuard g(*builder);
auto barrierOp = builder->create<gpu::BarrierOp>(op->getLoc());
auto barrierOp = builder->create<triton::gpu::LocalBarrierOp>(op->getLoc());
}

void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) {
if (isa<gpu::BarrierOp>(op)) {
if (isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op)) {
// If the current op is a barrier, we sync previous reads and writes
blockInfo->sync();
return;
}

if (isa<triton::gpu::AsyncWaitOp, triton::nvidia_gpu::TMAStoreWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode())) {
!isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
builder->setInsertionPointAfter(op);
21 changes: 21 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -1,5 +1,6 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
@@ -232,6 +233,25 @@ struct LocalStoreOpConversion
const TargetInfoBase &targetInfo;
};

class LocalBarrierOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::LocalBarrierOp> {
public:
LocalBarrierOpConversion(const LLVMTypeConverter &converter,
PatternBenefit benefit)
: ConvertOpToLLVMPattern<triton::gpu::LocalBarrierOp>(converter,
benefit) {}
using OpAdaptor = typename triton::gpu::LocalBarrierOp::Adaptor;

LogicalResult
matchAndRewrite(triton::gpu::LocalBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<mlir::gpu::BarrierOp>(op);

return success();
}
};

} // namespace

void mlir::triton::populateMemoryOpToLLVMPatterns(
@@ -243,4 +263,5 @@ void mlir::triton::populateMemoryOpToLLVMPatterns(
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalBarrierOpConversion>(typeConverter, benefit);
}
1 change: 1 addition & 0 deletions lib/Dialect/Gluon/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_triton_library(GluonTransforms
Canonicalize.cpp
Inline.cpp
ResolveAutoEncodings.cpp
SimplifyControlFlow.cpp

DEPENDS
GluonTransformsIncGen
2 changes: 1 addition & 1 deletion lib/Dialect/Gluon/Transforms/Inline.cpp
@@ -22,7 +22,7 @@ struct Inline : public gluon::impl::GluonInlineBase<Inline> {
void Inline::runOnOperation() {
mlir::PassManager pm(&getContext());
pm.addPass(createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) {
pm.addPass(gluon::createGluonCanonicalize());
pm.addPass(gluon::createGluonSimplifyControlFlow());
}));
if (failed(pm.run(getOperation())))
return signalPassFailure();
49 changes: 49 additions & 0 deletions lib/Dialect/Gluon/Transforms/SimplifyControlFlow.cpp
@@ -0,0 +1,49 @@
#include "mlir/IR/OperationSupport.h"
#include "triton/Dialect/Gluon/Transforms/Passes.h"

#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace triton;

namespace mlir::triton::gluon {
#define GEN_PASS_DEF_GLUONSIMPLIFYCONTROLFLOW
#include "triton/Dialect/Gluon/Transforms/Passes.h.inc"
} // namespace mlir::triton::gluon

namespace {
struct SimplifyControlFlow
: public gluon::impl::GluonSimplifyControlFlowBase<SimplifyControlFlow> {
void runOnOperation() override;
};
} // namespace

void SimplifyControlFlow::runOnOperation() {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(&getContext());

// Populate `scf` and `cf` canonicalizers.
ctx->getLoadedDialect<scf::SCFDialect>()->getCanonicalizationPatterns(
patterns);
ctx->getLoadedDialect<cf::ControlFlowDialect>()->getCanonicalizationPatterns(
patterns);
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
scf::SCFDialect::getDialectNamespace()))
op.getCanonicalizationPatterns(patterns, ctx);
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
cf::ControlFlowDialect::getDialectNamespace()))
op.getCanonicalizationPatterns(patterns, ctx);
populateForOpDeadArgumentElimination(patterns);

GreedyRewriteConfig config;
// This pass is intended to run before AutoLayouts are resolved; at that
// point, CSE-ing constants can lead to additional layout conflicts.
config.enableConstantCSE(false);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
9 changes: 4 additions & 5 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2300,12 +2300,11 @@ AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
}

SmallVector<int64_t>
AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int opIdx) const {
AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape, int kDim,
int opIdx) const {
auto mnkDim = getInstrShape();
auto operandTileShape = opIdx == 0
? SmallVector<int64_t>{mnkDim[0], mnkDim[2]}
: SmallVector<int64_t>{mnkDim[2], mnkDim[1]};
SmallVector<int64_t, 2> operandTileShape{opIdx == 0 ? mnkDim[0] : kDim,
opIdx == 0 ? kDim : mnkDim[1]};

assert(operandTileShape.size() == 2);
auto warpsPerCTA = getWarpsPerCTA();
60 changes: 60 additions & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -1448,6 +1448,66 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
}

LinearLayout chooseScaledWmmaScaleLayout(
MLIRContext *ctx, int dotOperandIdx,
const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
ArrayRef<int64_t> dotOperandShape) {
using basisT = std::vector<std::vector<int32_t>>;
unsigned rank = dotOperandShape.size();
auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true);
auto standardOutDims = standardOutDimNames(ctx, rank);
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
unsigned int scaleKWidth = dotOperandShape[1];
// Init register layout. Will be adjusted later
auto regs =
mlir::triton::identityStandardND(kRegister, {1, scaleKWidth}, order);
LinearLayout lanes = LinearLayout::empty();
// In scaled dot, the shapes of the operands (without the batch dimension)
// are, respectively:
// - A: [M, K]
// - B: [K, N]
// - aScale: [M, K / 32 or 16]
// - bScale: [N, K / 32 or 16]
//
// To correctly feed A/B and their scales into the instruction, we need to
// distribute aScale/bScale among warps in the same way as A/B. But bScale
// is not transposed like B, so we need to transpose the warp layout of
// bScale.
//
// The tricky part is that our desired outputs are [dim0, dim1], but at this
// point the layouts are transposed to [dim1, dim0]. So instead of reversing
// bScale's layout, we reverse aScale's. A transpose at the end corrects
// everything.
basisT warps = dotOperandWarpBasis;
if (dotOperandIdx == 0) {
for (auto &basis : warps) {
std::reverse(basis.begin(), basis.end());
}
}

lanes = LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}},
{kWarp, warps},
{kBlock, {}}},
{standardOutDims[order[0]], standardOutDims[order[1]]});
LinearLayout newLL = regs * lanes;

// Adjust the register-level layout to fill the shape; at this level, both
// aScale and bScale should align with the A operand.
SmallVector<int, 2> repOrder = {1, 0};
for (auto d : repOrder) {
auto outDim = standardOutDims[d];
auto dimSize = newLL.getOutDimSize(outDim);
newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister,
outDim);
}
newLL = newLL.transposeOuts(standardOutDims);

return newLL;
}

// Warp-level block scaling (sm_120, m16n8k32)
// Reference: NVIDIA PTX ISA "Warp-level block scaling"
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
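A hedged usage fragment for the new helper; the variable names are illustrative, it assumes an `MLIRContext *ctx` and the Triton namespaces already open in the surrounding code, and the `[32, 4]` shape matches `a_scale_linear_layout` in the gfx1250 frontend test further down:

```cpp
// Fragment only: `ctx` and the enclosing Triton namespaces are assumed.
std::vector<std::vector<int32_t>> aScaleWarpBasis = {{0, 0}, {16, 0}};
LinearLayout aScaleLayout = chooseScaledWmmaScaleLayout(
    ctx, /*dotOperandIdx=*/0, aScaleWarpBasis, /*dotOperandShape=*/{32, 4});
```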
25 changes: 25 additions & 0 deletions python/test/gluon/test_core.py
@@ -1119,3 +1119,28 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr):
out = torch.empty((B, B), dtype=torch.float32, device=device)
kernel[(1, )](a, b, c, out)
torch.testing.assert_close(out, torch.addmm(c, a, b), atol=1e-2, rtol=1e-2)


@gluon.jit
def kernel_auto_layout_constant(threads_per_warp: ttgl.constexpr):
BLOCK: ttgl.constexpr = 16
SIZE: ttgl.constexpr = 10

mask = ttgl.full(
(BLOCK, BLOCK),
True,
ttgl.int1,
ttgl.BlockedLayout(
size_per_thread=[1, 1],
threads_per_warp=[1, threads_per_warp],
warps_per_cta=[1, 4],
order=[1, 0],
),
)

mask &= (ttgl.arange(0, BLOCK, ttgl.AutoLayout()) < SIZE).expand_dims(0)
mask &= (ttgl.arange(0, BLOCK, ttgl.AutoLayout()) < SIZE).expand_dims(1)


def test_auto_layout_constant():
kernel_auto_layout_constant.warmup(THREADS_PER_WARP, grid=(1, ))
53 changes: 53 additions & 0 deletions python/test/gluon/test_frontend.py
@@ -30,6 +30,7 @@
HIP_TARGET_RDNA4 = GPUTarget("hip", "gfx1200", 32)
HIP_TARGET_CDNA3 = GPUTarget("hip", "gfx942", 64)
HIP_TARGET_CDNA4 = GPUTarget("hip", "gfx950", 64)
HIP_TARGET_GFX1250 = GPUTarget("hip", "gfx1250", 32)

ALL_TARGETS = [AMPERE_TARGET, HOPPER_TARGET, BLACKWELL_TARGET, HIP_TARGET_RDNA4]

@@ -2358,6 +2359,58 @@ def kernel():
""")


@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
def test_amd_wmma_scaled(target):

@gluon.jit
def kernel():
wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(version=3, transposed=True, warps_per_cta=[2, 2],
instr_shape=[16, 16, 128])
wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(version=3, transposed=True, warps_per_cta=[2, 2],
instr_shape=[16, 16, 64])
a_scale_linear_layout: ttgl.constexpr = ttgl.DistributedLinearLayout(
reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]],
warp_bases=[[0, 0], [16, 0]], block_bases=[], shape=[32, 4])
b_scale_linear_layout: ttgl.constexpr = ttgl.DistributedLinearLayout(
reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]],
warp_bases=[[16, 0], [0, 0]], block_bases=[], shape=[32, 4])

a = ttgl.full([32, 64], 0x11, ttgl.uint8,
ttgl.DotOperandLayout(operand_index=0, parent=wmma_layout_packed, k_width=16))
b = ttgl.full([64, 32], 0x22, ttgl.uint8,
ttgl.DotOperandLayout(operand_index=1, parent=wmma_layout_packed, k_width=16))
a_scale = ttgl.full([32, 4], 0x02, ttgl.uint8, a_scale_linear_layout)
b_scale = ttgl.full([32, 4], 0x01, ttgl.uint8, b_scale_linear_layout)
acc = ttgl.full([32, 32], 0, ttgl.float32, wmma_layout)
ttgl.amd.gfx1250.wmma_scaled(a, a_scale, 'e2m1', b, b_scale, 'e2m1', acc)

module = run_parser(kernel, *make_args(num_warps=4), target=target)
expecttest.assert_expected_inline(
anonymize_ir(module.str_nodebug()), """\
#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}>
#linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}>
#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 64]}>
#mma1 = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @kernel() attributes {noinline = false} {
%c17_i8 = arith.constant 17 : i8
%cst = arith.constant dense<17> : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
%c34_i8 = arith.constant 34 : i8
%cst_0 = arith.constant dense<34> : tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
%c2_i8 = arith.constant 2 : i8
%cst_1 = arith.constant dense<2> : tensor<32x4xi8, #linear>
%c1_i8 = arith.constant 1 : i8
%cst_2 = arith.constant dense<1> : tensor<32x4xi8, #linear1>
%cst_3 = arith.constant 0.000000e+00 : f32
%cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma1>
%cst_5 = arith.constant 0.000000e+00 : f32
%0 = tt.dot_scaled %cst scale %cst_1, %cst_0 scale %cst_2, %cst_4 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, #linear> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, #linear1> -> tensor<32x32xf32, #mma1>
tt.return
}
}
""")


@gluon.jit
def padded_shared_layout_kernel():
shape: ttgl.constexpr = [64, 64]
5 changes: 5 additions & 0 deletions python/triton/_internal_testing.py
@@ -84,6 +84,11 @@ def is_hip_gfx12():
return target is not None and target.backend == 'hip' and 'gfx12' in target.arch


def is_hip_gfx1250():
target = get_current_target()
return target is not None and target.backend == 'hip' and 'gfx1250' in target.arch


def is_hip_cdna():
return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
