Merged
Changes from 13 commits (of 46 total)
18293ce
Short-circuit zero-dimensional splats (#8331)
saagarjha Oct 1, 2025
32e0baf
[AMD][Gluon] Refactor buffer_atomic_rmw API (#8325)
zwu-2025 Oct 1, 2025
bc4ec6a
[Build] Fix crt header download location for CUDA >= 13 (#8336)
saagarjha Oct 1, 2025
7b84e26
[BACKEND] Implement shmem matrix descriptors generically (#8321)
lezcano Oct 1, 2025
532fd37
[AMD] Enhance scaled wmma gluon runtime unit tests (#8339)
PMylon Oct 1, 2025
2d4f16d
[AMD] Use linear layout to infer and emit ds_read_tr (#8235)
nzaghen Oct 1, 2025
872e102
[Build] Fix deprecation warning from TarFile.extractall (#8337)
peterbell10 Oct 1, 2025
75cd616
[WS] assign stage-phase only to partitions that needs it (#8329)
3gx Oct 1, 2025
bc22e6e
[AMD] Add initial support for TDM on gfx1250 (#8333)
borontion Oct 1, 2025
eb7cdba
[KERNELS] Fix and enable batched matmul with split-k. (#8327)
yongjik Oct 1, 2025
210c7b5
[ConSan] ConSan env var should be cache invalidating (#8332)
pawelszczerbuk Oct 1, 2025
48ff763
[Proton] Intra kernel profiling tutorial and examples (#8334)
fywkevin Oct 2, 2025
3e464e9
[Build] Remove Python 3.9 compatibility code for `sysconfig.get_defau…
anmyachev Oct 2, 2025
7fc1d56
Do not use C++20 designed initializers in `TritonNVIDIAGPUToLLVM/Tens…
anmyachev Oct 2, 2025
d9215b9
Bump actions/setup-python from 5 to 6 (#8347)
anmyachev Oct 2, 2025
9273fb3
[AMD][NFC] Move LowerLoops into TritonAMDGPUPipeline (#8341)
knwng Oct 2, 2025
aafec41
[triton_kernels] fused matmul_ogs + comms (#8340)
wuweil-openai Oct 2, 2025
6e4647e
[BACKEND] Lower `tcgen05.cp` via the generic matrix descriptor loweri…
lezcano Oct 2, 2025
1d74879
[ConSan] Make sure kernel is recompiled when consan state changes (#8…
pawelszczerbuk Oct 2, 2025
43dbdd1
[PROTON] Add a flag to disable proton in order to use other profilers…
Jokeren Oct 3, 2025
ec800b5
[PROTON] Simplify proton runtime instrumentation using Triton knobs (…
Jokeren Oct 3, 2025
537dfc8
Fold layout conversion for TMEM Store to fix perf drop for flex attn …
pchen7e2 Oct 3, 2025
be6a688
[Tests] Remove subprocess usage from `test_triton_debuginfo_on` (#8350)
anmyachev Oct 3, 2025
7e042c6
[Build] Remove unused `find_library(TERMINFO_LIBRARY tinfo)` (#8362)
anmyachev Oct 3, 2025
c5d1e01
[TESTS] Remove fresh_knobs from matmul.py::test_op (#8364)
lezcano Oct 3, 2025
73b5dc1
[AMD][GLUON] Add layout in make tensor descriptor (#8355)
borontion Oct 3, 2025
4c388af
[PROTON] Fix TestScopeIdAllocation.cpp random build failure (#8363)
lijinpei Oct 3, 2025
5201154
[AMD] Use lowerLdSt for local_load to ds_read_tr path (#8344)
nzaghen Oct 3, 2025
d5156d7
[KERNELS] Change routing code to avoid storage(). (#8357)
yongjik Oct 3, 2025
88b8a5c
[AMD] disable pointer-canonicalization for large-tensor (#8359)
yangshuxin Oct 3, 2025
1888f81
[KERNELS] remove unwanted device_print =_= (#8367)
yongjik Oct 3, 2025
59aeb6b
[Gluon] Require warp_specialize default_args and worker_args be tuple…
peterbell10 Oct 3, 2025
5d84a91
[mxfp] fix x_scale OOB (#8369)
jongsoo-openai Oct 3, 2025
3910f27
[mxfp] handle values close to max correctly w/o overflow (#8356)
jongsoo-openai Oct 4, 2025
0f91265
Get MLIRContext from `newOp`, not the deleted `load` (#8373)
alexbaden Oct 5, 2025
60605d8
[mxfp] remove col-major assert for mx weight (#8249)
jongsoo-openai Oct 5, 2025
483f9ea
[AMD] Disable flaky atomic cas test on CDNA2 (#8376)
antiagainst Oct 6, 2025
6edcd49
[AMD] Limit vec size for ds_read_tr + padded layouts by min interval …
AlexAUT Oct 6, 2025
d5f3f23
[AMD] Refactor FP conversion mode setting (#8351)
ravil-mobile Oct 6, 2025
0173f75
[AMD] Add Tests for MXFP GEMM Gluon Kernel for GFX1250 (#8371)
knwng Oct 6, 2025
8868aca
Merge commit '210c7b5bb29c01781c3e3053fe6bf28eb178347f'
whitneywhtsang Oct 7, 2025
5c020ef
[WIN] Fix error C7555: use of designated initializers requires at lea…
whitneywhtsang Oct 7, 2025
633d32d
Merge commit '43dbdd1685625ce71daea1caf8a4d90fdea6457f'
whitneywhtsang Oct 7, 2025
609e327
Merge commit '5d84a9122b519251d1453fc7e7f31e2e304dc1d6'
whitneywhtsang Oct 7, 2025
e52429a
Merge commit '0173f7524d8cfc9a5b4b52dec0010eaedef14526'
whitneywhtsang Oct 7, 2025
f21b341
Revert "[mxfp] remove col-major assert for mx weight (#8249)"
whitneywhtsang Oct 8, 2025
35 changes: 35 additions & 0 deletions include/triton/Tools/LayoutUtils.h
@@ -147,6 +147,41 @@ std::pair<int, ColumnAction>
largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
std::optional<int> maybeMaxVecElems = std::nullopt);

// Close cousin of doing zerosLike(tile) * divideLeft(cvt, tile).
// This one is a tad more general in the sense that it allows dividing
// cvt:
// - register=1 -> (0, 1)
// register=2 -> (8, 0)
// register=4 -> (0, 8)
// register=8 -> (0, 16)
// register=16 -> (0, 32)
// register=32 -> (0, 64)
// register=64 -> (16, 0)
// - lane=1 -> (0, 2)
// lane=2 -> (0, 4)
// lane=4 -> (1, 0)
// lane=8 -> (2, 0)
// lane=16 -> (4, 0)
// - warp=1 -> (32, 0)
// warp=2 -> (64, 0)
// - block is a size 1 dimension
// where out dims are: [row (size 128), col (size 128)]
// tile:
// - register=1 -> (0, 1)
// register=2 -> (8, 0)
// - lane=1 -> (0, 2)
// lane=2 -> (0, 4)
// lane=4 -> (1, 0)
// lane=8 -> (2, 0)
// lane=16 -> (4, 0)
// - warp=1 -> (32, 0)
// warp=2 -> (64, 0)
// where out dims are: [row (size 128), col (size 8)]
// which would not be possible to lower via the divideLeft approach as we
// cannot divide by the tile given the `register=64 -> (16, 0)` basis.
std::optional<LinearLayout> getReps(const LinearLayout &cvt,
const LinearLayout &tile);

} // namespace mlir::triton

#endif // TRITON_TOOLS_LAYOUTUTILS_H
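A minimal usage sketch of the new helper (editorial, not part of this diff; lowerPerTileRep and lowerFallback are hypothetical caller-side functions):

#include <optional>
#include "triton/Tools/LayoutUtils.h"

// Assumed helpers, for illustration only.
void lowerPerTileRep(const mlir::triton::LinearLayout &tile,
                     const mlir::triton::LinearLayout &reps);
void lowerFallback(const mlir::triton::LinearLayout &cvt);

// Split `cvt` into repetitions of `tile`: when the tile embeds into `cvt`,
// lower one hardware primitive per repetition, otherwise take a slow path.
void lowerWithReps(const mlir::triton::LinearLayout &cvt,
                   const mlir::triton::LinearLayout &tile) {
  if (auto reps = mlir::triton::getReps(cvt, tile))
    lowerPerTileRep(tile, *reps); // e.g. one ds_read_b64_tr per repetition
  else
    lowerFallback(cvt);
}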
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -43,6 +43,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
"TRITON_F32_DEFAULT",
"TRITON_PREFER_TMEM_16x256_LAYOUT",
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
"TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
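The point of this one-line addition (per commit #8332 above) is that every env var listed here contributes to the kernel-cache key, so flipping TRITON_ENABLE_EXPERIMENTAL_CONSAN forces a recompile instead of silently reusing a stale binary. A sketch of that idea, assuming the caller hashes the returned string into the cache key (an assumption, not the actual Triton implementation):

#include <cstdlib>
#include <set>
#include <string>

std::string cacheInvalidatingEnvSuffix(const std::set<std::string> &vars) {
  std::string suffix;
  for (const std::string &name : vars) {
    const char *value = std::getenv(name.c_str());
    // An unset variable is folded like the empty string in this sketch.
    suffix += name + "=" + (value ? value : "") + ";";
  }
  return suffix; // hashed into the kernel cache key by the caller
}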
260 changes: 112 additions & 148 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -477,23 +477,87 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape);
}

LinearLayout chooseLLDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
int32_t elemBitWidth) {
using BaseTy = std::vector<std::vector<int32_t>>;
// This function derives the layout for the ds_read_b64_tr instruction
// based on the input layout (LL/DotLayout/...).
// ds_read_b64_tr works on 64 bits per lane and in groups of 16 lanes.

// Using an M-contiguous 16-bit input tensor A as an example: each lane
// loads 4 consecutive elements (64 bits in total) along M, and there are 4
// consecutive lanes in total along M. The loaded elements are then exchanged
// within the MxK=16x4 "base unit".
// K0 K1 K2 K3
// +---+---+---+---+
// M0 | | | | | M0, K[0-3]: T0
// M1 | T | T | T | T | M1, K[0-3]: T1
// M2 | 0 | 4 | 8 |12 | M2, K[0-3]: T2
// M3 | | | | | M3, K[0-3]: T3
// +---+---+---+---+
// M4 | | | | | M4, K[0-3]: T4
// M5 | T | T | T | T | M5, K[0-3]: T5
// M6 | 1 | 5 | 9 |13 | M6, K[0-3]: T6
// M7 | | | | | M7, K[0-3]: T7
// +---+---+---+---+ ==>
// M8 | | | | | M8, K[0-3]: T8
// M9 | T | T | T | T | M9, K[0-3]: T9
// M10 | 2 | 6 |10 |14 | M10, K[0-3]: T10
// M11 | | | | | M11, K[0-3]: T11
// +---+---+---+---+
// M12 | | | | | M12, K[0-3]: T12
// M13 | T | T | T | T | M13, K[0-3]: T13
// M14 | 3 | 7 |11 |15 | M14, K[0-3]: T14
// M15 | | | | | M15, K[0-3]: T15
// +---+---+---+---+

// Given the layout represented by `enc` and the shape, we can derive the
// layout that ds_read_b64_tr needs to have in order to perform a vectorized
// load of the elements. This is done by rearranging the inner 4x16-element
// base unit of the LL, i.e. the first numReg register bases and the first
// numLane lane bases.
auto rotatePrefixes = [](BaseTy &regBase, std::size_t numReg,
BaseTy &laneBase, std::size_t numLane) {
// Concatenate the prefixes of the two vectors, lanes first and then regs:
// C D E F | A B
// then copy the first numReg entries back into regBase and the remaining
// numLane entries into laneBase:
// C D | E F A B
BaseTy baseUnit(laneBase.begin(), laneBase.begin() + numLane);
llvm::append_range(
baseUnit, llvm::make_range(regBase.begin(), regBase.begin() + numReg));

std::copy(baseUnit.begin(), baseUnit.begin() + numReg, regBase.begin());
std::copy(baseUnit.begin() + numReg, baseUnit.end(), laneBase.begin());
};

auto ctx = enc.getContext();
assert(elemBitWidth == 8 || elemBitWidth == 16);
// Get how many reg bases the ds_read_tr tile spans
unsigned numRegBases = llvm::Log2_32(64 / elemBitWidth);
// 4 lane bases describe 16 lanes.
unsigned numLaneBases = 4;

auto ldsTransLayout = triton::gpu::toLinearLayout(shape, enc);
auto bases = ldsTransLayout.getBases();
auto kRegister = S("register");
auto kLane = S("lane");
rotatePrefixes(bases[kRegister], numRegBases, bases[kLane], numLaneBases);

return LinearLayout(bases, ldsTransLayout.getOutDims(), false);
}
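// Editorial sketch (not part of this diff): the same prefix rotation on
// plain vectors with made-up bases, as a self-contained check (assumes
// <vector>, <algorithm>, and <cassert> are available). With
// regs = {A, B, G} and lanes = {C, D, E, F, H}, numReg = 2 and numLane = 4
// give regs = {C, D, G} and lanes = {E, F, A, B, H}; suffixes are untouched.
static void rotatePrefixesDemo() {
  using Base = std::vector<std::vector<int32_t>>;
  Base regs = {{0, 1}, {8, 0}, {16, 0}};                 // A, B, G
  Base lanes = {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}; // C, D, E, F, H
  std::size_t numReg = 2, numLane = 4;
  Base unit(lanes.begin(), lanes.begin() + numLane);            // C D E F
  unit.insert(unit.end(), regs.begin(), regs.begin() + numReg); // C D E F A B
  std::copy(unit.begin(), unit.begin() + numReg, regs.begin()); // regs: C D G
  std::copy(unit.begin() + numReg, unit.end(), lanes.begin());  // lanes: E F A B H
  assert(regs[0] == (std::vector<int32_t>{0, 2}));
  assert(lanes[2] == (std::vector<int32_t>{0, 1}));
}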

LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
ArrayRef<int64_t> shape,
int32_t elemBitWidth) {
auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
auto mDim = mfmaLayout.getInstrShape()[0];
assert(mDim == 16 || mDim == 32);

bool isFP4 = false;
if (elemBitWidth == 4) {
// When doing ds_read_tr4 we actually write the LL as if it were on i8
// elements. This is because the LL needs to be described for the i8
// tensor elements.
elemBitWidth = 8;
isFP4 = true;
}

assert(elemBitWidth == 16 || elemBitWidth == 8);
assert(elemBitWidth == 4);
// When doing ds_read_tr4 we actually write the LL as if it were on i8
// elements. This is because the LL needs to be described for the i8
// tensor elements.
elemBitWidth = 8;

auto rank = shape.size();
bool hasBatchDim = rank == 3;
@@ -520,143 +584,39 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,

std::vector<std::vector<int32_t>> registerBase;
std::vector<std::vector<int32_t>> laneBase;
auto populateFP4LL = [&registerBase, &laneBase](int kSize, int mDim) {
const bool isMfma32 = (mDim == 32);
// ds_read_b64_tr4 operates on FP4 values, swapping their packing. Look at
// i8 values for the ownership of register/lane since i8 is the data type
// of the tensor. Register dimension: which i8 values in the tile are held
// by thread 0? Lane dimension: which i8 values in the tile are held in
// register 0 of each thread?
registerBase.push_back({1, 0});
registerBase.push_back({2, 0});
registerBase.push_back({4, 0});
registerBase.push_back({0, 16});

// If more than one tile needs to be loaded, populate registerBase
// dimension for the other tiles
const int kTileSize = isMfma32 ? 64 : 128;
for (int reg = kTileSize; reg < kSize; reg *= 2) {
registerBase.push_back({0, reg});
}

// When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
// The LL for the two is different
laneBase.push_back({0, 1});
laneBase.push_back({0, 2});
laneBase.push_back({0, 4});
laneBase.push_back({0, 8});
if (mDim == 16) {
laneBase.push_back({0, 32});
laneBase.push_back({0, 64});
} else {
assert(mDim == 32);
laneBase.push_back({8, 0});
laneBase.push_back({0, 32});
}
};
auto populateLL = [&registerBase, &laneBase](int elemBitWidth, int kSize,
int kWidthDot, int mDim) {
// Number of bits loaded by an LDS read. ds_read_tr primarily supports
// 64-bit loads for most element sizes (16b, 8b, 4b).
const int32_t ldsReadWidth = 64;
int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
const int elemByteWidth = elemBitWidth / 8;
const bool isMfma32 = (mDim == 32);

// For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
// of data. The smallest unit for transposition is a
// [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
// where each thread reads kWidthTransRead elements along the non-K
// dimension. Due to the transposition mechanism, each thread ends up with
// kWidthTransRead elements along the K dimension.
//
// The MFMA selection logic prioritizes double-rate MFMA instructions
// whenever possible:
//
// - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
// is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
//
// - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
// selected; otherwise (blockK ≤ k), mfma32x32xk is used.
//
// NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
// instructions are used.
//
// In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
// elements along the K dimension:
// - The first kWidthTransRead elements belong to the first sub-tile.
// - The next kWidthTransRead elements belong to the second sub-tile.
//
// These elements are then grouped into larger tiles, each consisting of
// 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
// for one MFMA instruction. The shape of these tiles depends on the MFMA
// instruction used.
//
// For single-rate MFMA instructions, each thread holds kWidthTransRead
// elements along the K dimension. This means that the larger tile
// (corresponding to one MFMA instruction) consists of 4 {16,
// kWidthTransRead} sub-tiles.

// Populate register base for first subtile
for (int i = 1; i < kWidthTransRead; i *= 2) {
registerBase.push_back({i, 0});
}

const int threadsPerSubtileNonK = 16 / kWidthTransRead;
const int threadsPerSubtileK = kWidthTransRead;

// Populate lane base for first subtile
for (int i = 1; i < threadsPerSubtileNonK; i *= 2) {
laneBase.push_back({i * kWidthTransRead, 0});
}
for (int i = 1; i < threadsPerSubtileK; i *= 2) {
laneBase.push_back({0, i});
}

// Extend the register base along the K dim when multiple tiles are loaded.
auto extendRegisterBaseForKDim = [&](int kTileSize,
int numSubtilesPerTile) {
const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
int totalRegs = (kSize / kTileSize) * regsPerTile;

for (int reg = regsPerTile; reg < totalRegs; reg *= 2) {
registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
}
};

// kDoubleTileSize is the k dimension of a tile when double-rate
// mfma instructions are used.
const int kDoubleTileSize =
isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
// kTileSize is the actual k dimension of a tile, which is
// determined by kWidthDot.
const int kTileSize = kWidthDot * 64 / mDim;
// We use kDoubleTileSize as a reference to check whether the given
// kWidthDot leads to double or single sub-tiles in each tile.
const int numSubtilesPerTile = (kTileSize == kDoubleTileSize) ? 2 : 1;

// Extend register base for large K sizes.
if (numSubtilesPerTile == 2)
registerBase.push_back({0, threadsPerSubtileK}); // Second subtile

extendRegisterBaseForKDim(kTileSize, numSubtilesPerTile);

// Extend lane base based on MFMA size.
std::vector<std::vector<int32_t>> laneBaseExt;

if (isMfma32) {
laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}};
} else {
laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK},
{0, 2 * numSubtilesPerTile * threadsPerSubtileK}};
}
laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
};
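// Editorial worked example (not part of this diff): for elemBitWidth = 16,
// kWidthTransRead = 64 / 16 = 4, so the sub-tile is [non-K, K] = {16, 4}
// with threadsPerSubtileNonK = 16 / 4 = 4 and threadsPerSubtileK = 4. For
// mfma32 (mDim == 32) with kWidthDot = 8, kTileSize = 8 * 64 / 32 = 16 and
// kDoubleTileSize = 32 / 2 = 16, so numSubtilesPerTile = 2 (double-rate).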

if (isFP4)
populateFP4LL(kSize, mDim);
else
populateLL(elemBitWidth, kSize, kWidthDot, mDim);
const bool isMfma32 = (mDim == 32);
// ds_read_b64_tr4 operates on FP4 values, swapping their packing. Look at
// i8 values for the ownership of register/lane since i8 is the data type
// of the tensor. Register dimension: which i8 values in the tile are held
// by thread 0? Lane dimension: which i8 values in the tile are held in
// register 0 of each thread?
registerBase.push_back({1, 0});
registerBase.push_back({2, 0});
registerBase.push_back({4, 0});
registerBase.push_back({0, 16});

// If more than one tile needs to be loaded, populate registerBase
// dimension for the other tiles
const int kTileSize = isMfma32 ? 64 : 128;
for (int reg = kTileSize; reg < kSize; reg *= 2) {
registerBase.push_back({0, reg});
}

// When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
// The LL for the two is different
laneBase.push_back({0, 1});
laneBase.push_back({0, 2});
laneBase.push_back({0, 4});
laneBase.push_back({0, 8});
if (mDim == 16) {
laneBase.push_back({0, 32});
laneBase.push_back({0, 64});
} else {
assert(mDim == 32);
laneBase.push_back({8, 0});
laneBase.push_back({0, 32});
}
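// Editorial note (not part of this diff): reading these as linear-layout
// bases, the (row, col) offset of (register r, lane l) is the XOR of the
// bases whose index bit is set in r or l. E.g. for mDim == 16, register
// 5 = 0b101 at lane 0 maps to {1, 0} ^ {4, 0} = (5, 0), and lane 3 = 0b011
// at register 0 maps to {0, 1} ^ {0, 2} = (0, 3).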

// Base vectors above are defined in a fixed order [non-k-dim, k-dim].
// To assign them to actual matrix dimensions we associate with register
Expand Down Expand Up @@ -1444,8 +1404,12 @@ LinearLayout chooseShemLayoutForRegToRegConversion(

LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
int32_t elemBitWidth) {
auto dot = cast<DotOperandEncodingAttr>(enc);
return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
if (elemBitWidth == 4) {
auto dot = cast<DotOperandEncodingAttr>(enc);
return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
} else {
return chooseLLDsReadB64TrLayout(enc, shape, elemBitWidth);
}
}
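// Editorial usage sketch (not part of this diff): callers pick the
// transposed-load layout per element width; only the 4-bit (FP4) path
// requires the encoding to be a DotOperandEncodingAttr, matching the
// cast above.
//   LinearLayout ll16 = chooseDsReadB64TrLayout(enc, shape, /*elemBitWidth=*/16);
//   LinearLayout ll4 = chooseDsReadB64TrLayout(dotEnc, shape, /*elemBitWidth=*/4);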

LinearLayout chooseScaledWmmaScaleLayout(