diff --git a/.github/workflows/integration-tests-amd.yml b/.github/workflows/integration-tests-amd.yml index 76b4964a59..bc97c536b8 100644 --- a/.github/workflows/integration-tests-amd.yml +++ b/.github/workflows/integration-tests-amd.yml @@ -13,6 +13,7 @@ jobs: integration-tests-amd: runs-on: ${{ matrix.runner }} timeout-minutes: 45 + continue-on-error: ${{ matrix.runner[1] == 'gfx90a' }} strategy: matrix: runner: ${{ fromJson(inputs.matrix) }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 564feeee6a..17e76db64f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,10 +89,6 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() -if(NOT WIN32) - find_library(TERMINFO_LIBRARY tinfo) -endif() - if(TRITON_BUILD_UT) # This is an aggregate target for all unit tests. add_custom_target(TritonUnitTests) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 25469c2d33..08bb3ff093 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -528,32 +528,6 @@ Value emitPadding(Location loc, RewriterBase &rewriter, triton::gpu::PaddedSharedEncodingAttr layout, unsigned bitwidth, Value smemOffset, bool offsetInBytes); -// Emits IR to load data from shared memory into registers, or to store data -// from registers into shared memory. -// -// You supply perVectorCallback, which is called once per group of register -// elements to transfer. You can use this callback to emit IR to load or store -// data from or to shared memory. -// -// elemLlvmTy should be dstTy's element type converted to an LLVM-dialect type. -// -// If maxVecElems is provided, we won't vectorize more than this many elements. -// -// Returns true on success. -[[nodiscard]] bool emitTransferBetweenRegistersAndShared( - RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, - Type elemLlvmTy, std::optional maxVecElems, - const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, - std::function perVectorCallback); - -[[nodiscard]] bool emitTransferBetweenRegistersAndShared( - LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, - std::optional maxVecElems, const SharedMemoryObject &smemObj, - Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - Value laneId, Value warpId, - std::function perVectorCallback); - // Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp // We might want to merge them at some point, but having to support // ldmatrix.trans makes the code in lowerLdStMatrix a bit specific diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h index feab0160bc..258762bdde 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -54,6 +54,7 @@ class CoarseSchedule { iterator end() { return orderClusters.end(); } const_iterator end() const { return orderClusters.end(); } size_t size() const { return orderClusters.size(); } + void clear() { orderClusters.clear(); } iterator newAtBack() { orderClusters.push_back(orderClusters.size()); return std::prev(orderClusters.end()); @@ -157,7 +158,10 @@ class CoarseSchedule { // Set based on CoarseSchedule. void serialize(scf::ForOp &forOp) const; // Create a CoarseSchedule based on forOp's . 
- LogicalResult deSerialize(scf::ForOp &forOp); + // If normalizeClusterId is true, clusters [minClusterId, maxClusterId] will + // be remapped to [0, maxClusterId - minClusterId]. + // If false, it won't remap and clusters [0, maxClusterId] will be created. + LogicalResult deSerialize(scf::ForOp &forOp, bool normalizeClusterId = true); static ClusterHash hashCluster(Cluster cluster) { return reinterpret_cast(&*cluster); diff --git a/include/triton/Tools/LayoutUtils.h b/include/triton/Tools/LayoutUtils.h index 1d1ca52864..15fc661f90 100644 --- a/include/triton/Tools/LayoutUtils.h +++ b/include/triton/Tools/LayoutUtils.h @@ -147,6 +147,41 @@ std::pair largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth, std::optional maybeMaxVecElems = std::nullopt); +// Close cousin of doing zerosLike(tile) * divideLeft(cvt, tile) +// This one is a tad more general in the sense that it allows to divide +// cvt: +// - register=1 -> (0, 1) +// register=2 -> (8, 0) +// register=4 -> (0, 8) +// register=8 -> (0, 16) +// register=16 -> (0, 32) +// register=32 -> (0, 64) +// register=64 -> (16, 0) +// - lane=1 -> (0, 2) +// lane=2 -> (0, 4) +// lane=4 -> (1, 0) +// lane=8 -> (2, 0) +// lane=16 -> (4, 0) +// - warp=1 -> (32, 0) +// warp=2 -> (64, 0) +// - block is a size 1 dimension +// where out dims are: [row (size 128), col (size 128)] +// tile: +// - register=1 -> (0, 1) +// register=2 -> (8, 0) +// - lane=1 -> (0, 2) +// lane=2 -> (0, 4) +// lane=4 -> (1, 0) +// lane=8 -> (2, 0) +// lane=16 -> (4, 0) +// - warp=1 -> (32, 0) +// warp=2 -> (64, 0) +// where out dims are: [row (size 128), col (size 8)] +// which would not be possible to lower via the divideLeft approach as we +// cannot divide by the tile given the `register=64 -> (16, 0)` basis. 
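To make the intended semantics of getReps concrete, here is a small Python sketch (not the MLIR implementation): layouts are modeled as plain dicts of XOR bases, the size-1 block dimension is dropped, and the helper mirrors the three steps the real code performs — match the leading bases against the tile, check the remaining bases against the tile's bit mask, and zero out the matched prefix. It is exercised with the example from the comment above.

def get_reps(cvt, tile):
    # Bit positions (one mask entry per output dim) touched by any tile basis.
    mask = None
    for bases in tile.values():
        for basis in bases:
            if mask is None:
                mask = [0] * len(basis)
            mask = [m | b for m, b in zip(mask, basis)]

    reps = {}
    for in_dim, bases in cvt.items():
        tile_bases = tile.get(in_dim, [])
        if len(tile_bases) > len(bases):
            return None
        # 1) cvt must start with exactly the tile's bases.
        if bases[:len(tile_bases)] != tile_bases:
            return None
        # 2) the remaining bases must not touch any bit used by the tile.
        for basis in bases[len(tile_bases):]:
            if mask is not None and any(b & m for b, m in zip(basis, mask)):
                return None
        # 3) zero the matched prefix, keep the remaining bases unchanged.
        reps[in_dim] = [[0] * len(b) for b in tile_bases] + bases[len(tile_bases):]
    return reps

# The example from the comment above, with bases written as [row, col].
cvt = {
    "register": [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [16, 0]],
    "lane": [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
    "warp": [[32, 0], [64, 0]],
}
tile = {
    "register": [[0, 1], [8, 0]],
    "lane": cvt["lane"],
    "warp": cvt["warp"],
}
reps = get_reps(cvt, tile)
assert reps["register"][:2] == [[0, 0], [0, 0]]      # tile prefix zeroed
assert reps["register"][2:] == cvt["register"][2:]   # repetitions kept

The `register=64 -> (16, 0)` basis survives the overlap check because bit 16 of the row dimension is never produced by a tile basis, which is exactly why this case works here but not with the divideLeft approach.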
+std::optional getReps(const LinearLayout &cvt, + const LinearLayout &tile); + } // namespace mlir::triton #endif // TRITON_TOOLS_LAYOUTUTILS_H diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 449c9334ad..12e9312c65 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -43,6 +43,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "ALLOW_LHS_TMEM_LAYOUT_CONVERSION", "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", + "TRITON_ENABLE_EXPERIMENTAL_CONSAN", "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE", "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS", "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32", diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index b5f003ad49..4196b85901 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -706,110 +706,6 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx, maybeMaxVecElems, localLoadOp); } -bool emitTransferBetweenRegistersAndShared( - LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, - std::optional maxVecElems, const SharedMemoryObject &smemObj, - Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - Value laneId, Value warpId, - std::function perVectorCallback) { - MLIRContext *ctx = rewriter.getContext(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - - StringAttr kBlock = str_attr("block"); - StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kOffset = str_attr("offset"); - - auto shape = sharedTy.getShape(); - auto paddedEnc = - dyn_cast(sharedTy.getEncoding()); - LinearLayout regToSharedLayout = LinearLayout::empty(); - if (paddedEnc) { - const auto &sharedLL = paddedEnc.getLinearComponent(); - regToSharedLayout = regLayout.invertAndCompose(sharedLL); - } else { - auto sharedLL = triton::gpu::toLinearLayout(sharedTy); - regToSharedLayout = regLayout.invertAndCompose(sharedLL); - } - - // TODO(jlebar): We don't currently support loading from shared memory in a - // different CTA. We'd need to emit `mapa.shared::cluster` instructions. - if (regToSharedLayout.hasInDim(kBlock) && - regToSharedLayout.hasOutDim(kBlock) && - !regToSharedLayout.isTrivialOver({kBlock})) { - return false; - } - - // Determine how many consecutive registers map to consecutive shmem elements - // in out-dimension offsetN. This is our load instruction's vector width. - // - // It's OK if the vector width we choose here is wider than the hardware - // supports; LLVM will legalize it. - int vecElems = - std::min({regToSharedLayout.getNumConsecutiveInOut(), - maxVecElems.value_or(std::numeric_limits::max())}); - if (paddedEnc) { - vecElems = std::min(vecElems, int(paddedEnc.getMinInterval())); - } - - auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1; - Value blockId = - withCTAOffset ? 
target.getClusterCTAId(rewriter, loc) : b.i32_val(0); - - int numElems = regToSharedLayout.getInDimSize(kRegister); - auto vecTy = vec_ty(elemLlvmTy, vecElems); - SmallVector regIds; - for (int i = 0; i < numElems / vecElems; i++) { - regIds.push_back(i * vecElems); - } - - auto smemBase = smemObj.getBase(); - - auto indicesVec = applyLinearLayoutVec(loc, rewriter, regToSharedLayout, - {{kRegister, b.i32_val(0)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, blockId}}, - regIds); - - // Compute affine offset given by memdesc_subslice - auto offset = smemObj.getShmemOffset(loc, rewriter, sharedTy); - SmallVector vecAddrVec; - for (auto &indices : indicesVec) { - Value smemOffset = indices[0].second; - smemOffset = b.xor_(smemOffset, offset); - if (paddedEnc) { - // Apply the offset needed for padding. - auto bitwidth = elemLlvmTy.getIntOrFloatBitWidth(); - Value padOffset = emitPadding(loc, rewriter, paddedEnc, bitwidth, - smemOffset, /*offsetInBytes=*/false); - smemOffset = b.add(smemOffset, padOffset); - } - auto vecAddr = b.gep(smemBase.getType(), elemLlvmTy, smemBase, smemOffset, - LLVM::GEPNoWrapFlags::inbounds); - vecAddrVec.push_back(vecAddr); - } - - for (Value &vecAddr : vecAddrVec) { - perVectorCallback(vecTy, vecAddr); - } - return true; -} - -bool emitTransferBetweenRegistersAndShared( - RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, - Type elemLlvmTy, std::optional maxVecElems, - const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, - std::function perVectorCallback) { - auto regLayout = triton::gpu::toLinearLayout(registerTy); - auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); - return emitTransferBetweenRegistersAndShared( - regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, - target, laneId, warpId, perVectorCallback); -} - SmallVector unpackLLElements(Location loc, Value llvmStruct, RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 1e83c0d307..ee1cfcdbe4 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -477,6 +477,75 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape); } +LinearLayout chooseLLDsReadB64TrLayout(Attribute enc, ArrayRef shape, + int32_t elemBitWidth) { + using BaseTy = std::vector>; + // This function will derive the layout for the ds_read_b64_tr instruction + // based on the input layout (LL/DotLayout/...) + // The ds_read_b64_tr works on 64 bits per lane and in groups of 16 lanes. + + // Using M-continuous 16-bit input tensor A as an example. Each lane will + // load 4 consecutive elements (64-bit in total) along M. There are 4 + // consecutive lanes in total along M. Then the loaded elements are exchanged + // withthin the MxK=16x4 "base unit". 
+ // K0 K1 K2 K3 + // +---+---+---+---+ + // M0 | | | | | M0, K[0-3]: T0 + // M1 | T | T | T | T | M1, K[0-3]: T1 + // M2 | 0 | 4 | 8 |12 | M2, K[0-3]: T2 + // M3 | | | | | M3, K[0-3]: T3 + // +---+---+---+---+ + // M4 | | | | | M4, K[0-3]: T4 + // M5 | T | T | T | T | M5, K[0-3]: T5 + // M6 | 1 | 5 | 9 |13 | M6, K[0-3]: T6 + // M7 | | | | | M7, K[0-3]: T7 + // +---+---+---+---+ ==> + // M8 | | | | | M8, K[0-3]: T8 + // M9 | T | T | T | T | M9, K[0-3]: T9 + // M10 | 2 | 6 |10 |14 | M10, K[0-3]: T10 + // M11 | | | | | M11, K[0-3]: T11 + // +---+---+---+---+ + // M12 | | | | | M12, K[0-3]: T12 + // M13 | T | T | T | T | M13, K[0-3]: T13 + // M14 | 3 | 7 |11 |15 | M14, K[0-3]: T14 + // M15 | | | | | M15, K[0-3]: T15 + // +---+---+---+---+ + + // Given the layout represented by `enc` and shape, we can derive the layout + // that ds_read_b64_tr need to have in order to perform a vectorized load of + // the elements. This can be done by rearranging the inner 4x16 element base + // unit in the LL by rearranging the first numReg register bases and the + // first numLane lane bases. + auto rotatePrefixes = [](BaseTy ®Base, std::size_t numReg, + BaseTy &laneBase, std::size_t numLane) { + // Concatenate prefixes of the two vectors. Lane first and then regs. + // C D E F | A B + // Then copy over numReg to the regBase and numLane to laneBase + // C D | E F A B + BaseTy baseUnit(laneBase.begin(), laneBase.begin() + numLane); + llvm::append_range( + baseUnit, llvm::make_range(regBase.begin(), regBase.begin() + numReg)); + + std::copy(baseUnit.begin(), baseUnit.begin() + numReg, regBase.begin()); + std::copy(baseUnit.begin() + numReg, baseUnit.end(), laneBase.begin()); + }; + + auto ctx = enc.getContext(); + assert(elemBitWidth == 8 || elemBitWidth == 16); + // Get how many reg bases the ds_read_tr tile spans + unsigned numRegBases = llvm::Log2_32(64 / elemBitWidth); + // 4 lane bases describe 16 lanes. + unsigned numLaneBases = 4; + + auto ldsTransLayout = triton::gpu::toLinearLayout(shape, enc); + auto bases = ldsTransLayout.getBases(); + auto kRegister = S("register"); + auto kLane = S("lane"); + rotatePrefixes(bases[kRegister], numRegBases, bases[kLane], numLaneBases); + + return LinearLayout(bases, ldsTransLayout.getOutDims(), false); +} + LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout, ArrayRef shape, int32_t elemBitWidth) { @@ -484,16 +553,11 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout, auto mDim = mfmaLayout.getInstrShape()[0]; assert(mDim == 16 || mDim == 32); - bool isFP4 = false; - if (elemBitWidth == 4) { - // When doing ds_read_tr4 we actually write the LL as if it were on i8 - // elements this is becasue LL needs to be described for the i8 tensor - // elements. - elemBitWidth = 8; - isFP4 = true; - } - - assert(elemBitWidth == 16 || elemBitWidth == 8); + assert(elemBitWidth == 4); + // When doing ds_read_tr4 we actually write the LL as if it were on i8 + // elements this is becasue LL needs to be described for the i8 tensor + // elements. + elemBitWidth = 8; auto rank = shape.size(); bool hasBatchDim = rank == 3; @@ -520,143 +584,39 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout, std::vector> registerBase; std::vector> laneBase; - auto populateFP4LL = [®isterBase, &laneBase](int kSize, int mDim) { - const bool isMfma32 = (mDim == 32); - // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. 
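As a sanity check of the prefix rotation used by chooseLLDsReadB64TrLayout above, here is a minimal Python sketch mirroring the rotatePrefixes lambda; the basis values below are hypothetical, and for 16-bit elements numReg = log2(64/16) = 2 while numLane = 4 (16 lanes).

def rotate_prefixes(reg_bases, num_reg, lane_bases, num_lane):
    # Concatenate the prefixes, lane bases first, then register bases:
    #   C D E F | A B
    base_unit = lane_bases[:num_lane] + reg_bases[:num_reg]
    # The first num_reg entries become the new register prefix and the
    # remaining num_lane entries become the new lane prefix:
    #   registers: C D      lanes: E F A B
    reg_bases[:num_reg] = base_unit[:num_reg]
    lane_bases[:num_lane] = base_unit[num_reg:]

# Hypothetical 16-bit case: 2 register bases and 4 lane bases span one
# transposed base unit; bases past the prefixes are left untouched.
reg = [[1, 0], [2, 0], [0, 8]]
lane = [[0, 1], [0, 2], [4, 0], [8, 0], [0, 4]]
rotate_prefixes(reg, 2, lane, 4)
print(reg)   # [[0, 1], [0, 2], [0, 8]]
print(lane)  # [[4, 0], [8, 0], [1, 0], [2, 0], [0, 4]]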
Look - // at i8 values for the ownership of register/lane since it's the data type - // of the tensor. Register dimension: what i8 in the tile are held by thread - // 0? Lane dimension: what i8 in the tile are held in register 0 of each - // thread? - registerBase.push_back({1, 0}); - registerBase.push_back({2, 0}); - registerBase.push_back({4, 0}); - registerBase.push_back({0, 16}); - - // If more than one tile needs to be loaded, populate registerBase - // dimension for the other tiles - const int kTileSize = isMfma32 ? 64 : 128; - for (int reg = kTileSize; reg < kSize; reg *= 2) { - registerBase.push_back({0, reg}); - } - - // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64 - // The LL for the two is different - laneBase.push_back({0, 1}); - laneBase.push_back({0, 2}); - laneBase.push_back({0, 4}); - laneBase.push_back({0, 8}); - if (mDim == 16) { - laneBase.push_back({0, 32}); - laneBase.push_back({0, 64}); - } else { - assert(mDim == 32); - laneBase.push_back({8, 0}); - laneBase.push_back({0, 32}); - } - }; - auto populateLL = [®isterBase, &laneBase](int elemBitWidth, int kSize, - int kWidthDot, int mDim) { - // Number of bits loaded by an LDS read. ds_read_tr primarily supports - // 64-bit loads for most element sizes (16b, 8b, 4b). - const int32_t ldsReadWidth = 64; - int32_t kWidthTransRead = ldsReadWidth / elemBitWidth; - const int elemByteWidth = elemBitWidth / 8; - const bool isMfma32 = (mDim == 32); - - // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes) - // of data. The smallest unit for transposition is a - // [non-K, K] = {16, kWidthTransRead} sub-tile of elements, - // where each thread reads kWidthTransRead elements along the non-K - // dimension. Due to the transposition mechanism, each thread ends up with - // kWidthTransRead elements along the K dimension. - // - // The MFMA selection logic prioritizes double-rate MFMA instructions - // whenever possible: - // - // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k - // is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice. - // - // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is - // selected; otherwise (blockK ≤ k), mfma32x32xk is used. - // - // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA - // instructions are used. - // - // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead - // elements along the K dimension: - // - The first kWidthTransRead elements belong to the first sub-tile. - // - The next kWidthTransRead elements belong to the second sub-tile. - // - // These elements are then grouped into larger tiles, each consisting of - // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data - // for one MFMA instruction. The shape of these tiles depends on the MFMA - // instruction used. - // - // For single-rate MFMA instructions, each thread holds kWidthTransRead - // elements along the K dimension. This means that the larger tile - // (corresponding to one MFMA instruction) consists of 4 {16, - // kWidthTransRead} sub-tiles. 
- - // Populate register base for first subtile - for (int i = 1; i < kWidthTransRead; i *= 2) { - registerBase.push_back({i, 0}); - } - - const int threadsPerSubtileNonK = 16 / kWidthTransRead; - const int threadsPerSubtileK = kWidthTransRead; - - // Populate lane base for first subtile - for (int i = 1; i < threadsPerSubtileNonK; i *= 2) { - laneBase.push_back({i * kWidthTransRead, 0}); - } - for (int i = 1; i < threadsPerSubtileK; i *= 2) { - laneBase.push_back({0, i}); - } - - // Function to extend register base for multiple tiles K dim. - auto extendRegisterBaseForKDim = [&](int kTileSize, - int numSubtilesPerTile) { - const int regsPerTile = kWidthTransRead * numSubtilesPerTile; - int totalRegs = (kSize / kTileSize) * regsPerTile; - - for (int reg = regsPerTile; reg < totalRegs; reg *= 2) { - registerBase.push_back({0, (reg / regsPerTile) * kTileSize}); - } - }; - - // kDoubleTileSize is the k dimension of a tile when double rated - // mfma instructions are used. - const int kDoubleTileSize = - isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth; - // kTileSize is the actually k dimention of a tile, which is - // determined by kWidthDot. - const int kTileSize = kWidthDot * 64 / mDim; - // We use kDoubleTileSize as a reference to check whether the given - // kWidthDot leads to double or single sub-tiles in each tile. - const int numSubtilesPerTile = (kTileSize == kDoubleTileSize) ? 2 : 1; - - // Extend register base for large K sizes. - if (numSubtilesPerTile == 2) - registerBase.push_back({0, threadsPerSubtileK}); // Second subtile - - extendRegisterBaseForKDim(kTileSize, numSubtilesPerTile); - - // Extend lane base based on MFMA size. - std::vector> laneBaseExt; - - if (isMfma32) { - laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}}; - } else { - laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK}, - {0, 2 * numSubtilesPerTile * threadsPerSubtileK}}; - } - laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end()); - }; - if (isFP4) - populateFP4LL(kSize, mDim); - else - populateLL(elemBitWidth, kSize, kWidthDot, mDim); + const bool isMfma32 = (mDim == 32); + // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look + // at i8 values for the ownership of register/lane since it's the data type + // of the tensor. Register dimension: what i8 in the tile are held by thread + // 0? Lane dimension: what i8 in the tile are held in register 0 of each + // thread? + registerBase.push_back({1, 0}); + registerBase.push_back({2, 0}); + registerBase.push_back({4, 0}); + registerBase.push_back({0, 16}); + + // If more than one tile needs to be loaded, populate registerBase + // dimension for the other tiles + const int kTileSize = isMfma32 ? 64 : 128; + for (int reg = kTileSize; reg < kSize; reg *= 2) { + registerBase.push_back({0, reg}); + } + + // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64 + // The LL for the two is different + laneBase.push_back({0, 1}); + laneBase.push_back({0, 2}); + laneBase.push_back({0, 4}); + laneBase.push_back({0, 8}); + if (mDim == 16) { + laneBase.push_back({0, 32}); + laneBase.push_back({0, 64}); + } else { + assert(mDim == 32); + laneBase.push_back({8, 0}); + laneBase.push_back({0, 32}); + } // Base vectors above are defined in a fixed order [non-k-dim, k-dim]. 
// To assign them to actual matrix dimensions we associate with register @@ -1444,8 +1404,12 @@ LinearLayout chooseShemLayoutForRegToRegConversion( LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef shape, int32_t elemBitWidth) { - auto dot = cast(enc); - return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth); + if (elemBitWidth == 4) { + auto dot = cast(enc); + return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth); + } else { + return chooseLLDsReadB64TrLayout(enc, shape, elemBitWidth); + } } LinearLayout chooseScaledWmmaScaleLayout( diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 4d237caf1e..74dfa46aa0 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -73,6 +73,31 @@ bool isConvertTrivial(ConvertLayoutOp op) { // Canonicalizer //===----------------------------------------------------------------------===// +// tmem_store(cvt) -> tmem_store +struct CanonicalizeConvertFromTMEMStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(nvidia_gpu::TMEMStoreOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + + // bail for incompatible layouts + auto cvtSrcType = convert.getSrc().getType(); + if (!nvidia_gpu::isDistributedLayoutTMemCompatible( + op.getOperation(), cvtSrcType, op.getDst().getType())) { + return failure(); + } + + rewriter.modifyOpInPlace( + op, [&]() { op.getSrcMutable().assign(convert.getSrc()); }); + return mlir::success(); + } +}; + // reshape(cvt) -> reshape struct CanonicalizeConvertFromReshape : public mlir::OpRewritePattern { @@ -373,6 +398,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(context); patterns.add(context); patterns.add(context); + patterns.add(context); } LogicalResult Fp4ToFpOp::verify() { diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp index 58a49a5b3a..7b9b0ca2fd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -263,7 +263,8 @@ void tt::CoarseSchedule::serialize(scf::ForOp &forOp) const { } // Create a CoarseSchedule based on forOp's . 
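The normalizeClusterId flag documented in Schedule.h above only changes how the cluster map is seeded before ops are placed; a tiny sketch with hypothetical cluster ids:

# Hypothetical cluster ids read back from the loop's ops.
cluster_ids = [2, 3, 4]
lo, hi = min(cluster_ids), max(cluster_ids)

# normalizeClusterId=True: remap [lo, hi] onto [0, hi - lo].
normalized = {i: i - lo for i in range(lo, hi + 1)}   # {2: 0, 3: 1, 4: 2}

# normalizeClusterId=False: keep the ids and create clusters [0, hi].
kept = {i: i for i in range(hi + 1)}                  # {0: 0, 1: 1, ..., 4: 4}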
-LogicalResult tt::CoarseSchedule::deSerialize(scf::ForOp &forOp) { +LogicalResult tt::CoarseSchedule::deSerialize(scf::ForOp &forOp, + bool normalizeClusterId) { auto [minClusterId, maxClusterId] = getMinMaxCluster(forOp); std::optional maxStage = tryGetMaxStage(forOp); if (!maxStage) { @@ -272,9 +273,16 @@ LogicalResult tt::CoarseSchedule::deSerialize(scf::ForOp &forOp) { numStages = *maxStage + 1; DenseMap clustersMap; - for (int i = minClusterId; i < maxClusterId + 1; i++) { - clustersMap.insert({i, clusters.newAtBack()}); + if (normalizeClusterId) { + for (int i = minClusterId; i < maxClusterId + 1; i++) { + clustersMap.insert({i, clusters.newAtBack()}); + } + } else { + for (int i = 0; i < maxClusterId + 1; i++) { + clustersMap.insert({i, clusters.newAtBack()}); + } } + for (Operation &op : forOp.getBody()->without_terminator()) { if (!op.hasAttr(mlir::triton::kLoopStageAttrName)) continue; diff --git a/lib/Tools/LayoutUtils.cpp b/lib/Tools/LayoutUtils.cpp index 76067b467e..02cd30781a 100644 --- a/lib/Tools/LayoutUtils.cpp +++ b/lib/Tools/LayoutUtils.cpp @@ -477,4 +477,89 @@ largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth, llvm_unreachable("Vectorization < 1 is not valid"); } +std::optional getReps(const LinearLayout &cvt, + const LinearLayout &tile) { + + // Ensure tile out-dims are subset of cvt out-dims. + for (auto od : tile.getOutDimNames()) + assert(cvt.hasOutDim(od) && "tile out-dims must be contained in cvt"); + + // Precompute tile out-dim bit-widths. + llvm::SmallDenseMap outBLog2; + for (StringAttr od : cvt.getOutDimNames()) + outBLog2[od] = tile.hasOutDim(od) ? tile.getOutDimSizeLog2(od) : 0; + + // Build a per-out-dimension mask by OR-ing all tile bases that touch it. + llvm::SmallDenseMap tileMaskPerOutDim; + for (StringAttr od : cvt.getOutDimNames()) + tileMaskPerOutDim[od] = 0; + for (auto &[inDim, inBases] : tile.getBases()) { + (void)inDim; + for (auto &basis : inBases) { + int idx = 0; + for (StringAttr od : tile.getOutDimNames()) { + tileMaskPerOutDim[od] |= basis[idx++]; + } + } + } + + // Build reps with the same in/out dims as cvt, but zeroing out the leading + // inB bases (per in-dim) and keeping the remainder bases unchanged from cvt. + LinearLayout::BasesT repsBases; + for (StringAttr id : cvt.getInDimNames()) { + int inA = cvt.getInDimSizeLog2(id); + int inB = tile.hasInDim(id) ? tile.getInDimSizeLog2(id) : 0; + if (inB > inA) { + return std::nullopt; + } + + std::vector> basesForDim; + basesForDim.reserve(inA); + + // 1) Validate the starting bases match exactly. + for (int i = 0; i < inB; ++i) { + for (StringAttr od : cvt.getOutDimNames()) { + int a = cvt.getBasis(id, i, od); + int b = tile.getBasis(id, i, od); + if (a != b) { + return std::nullopt; + } + } + } + + // 2) Validate no overlap: the remaining cvt bases must have zeros in all + // tile-bit positions (computed as OR of all tile bases) for each + // out-dim. + for (int i = inB; i < inA; ++i) { + for (StringAttr od : cvt.getOutDimNames()) { + int32_t mask = tileMaskPerOutDim.lookup(od); + if (mask == 0) + continue; + int v = cvt.getBasis(id, i, od); + if ((v & mask) != 0) { + return std::nullopt; + } + } + } + + // 3) Emit reps bases: first inB as all-zeros; remainder copied from cvt. 
+ for (int i = 0; i < inB; ++i) { + std::vector zero(cvt.getNumOutDims(), 0); + basesForDim.push_back(std::move(zero)); + } + for (int i = inB; i < inA; ++i) { + std::vector keep; + keep.reserve(cvt.getNumOutDims()); + for (StringAttr od : cvt.getOutDimNames()) + keep.push_back(cvt.getBasis(id, i, od)); + basesForDim.push_back(std::move(keep)); + } + + repsBases[id] = std::move(basesForDim); + } + + return LinearLayout(std::move(repsBases), cvt.getOutDims(), + /*requireSurjective=*/false); +} + } // namespace mlir::triton diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index b313f9cceb..a2418d8e6b 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -762,7 +762,25 @@ void init_gluon_ir(py::module &&m) { tt::CacheModifier cacheModifier) { self.create( dest, ptr, offsets, mask, other, stride, cacheModifier); - }); + }) + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Type resultTy, Value &base, + std::vector &shape, std::vector &strides, + tt::PaddingOption paddingOption) -> Value { + return self.create(resultTy, base, shape, + strides, paddingOption); + }) + .def("create_async_tdm_copy_global_to_local", + [](GluonOpBuilder &self, Value descPtr, std::vector &indices, + Value result) { + Value pred = self.create(1, 1); + self.create(descPtr, indices, + result, pred); + }) + .def("create_async_tdm_wait", [](GluonOpBuilder &self, int num) { + ValueRange tokens; + self.create(tokens, num); + }); py::class_(m, "WarpSpecializeOp", py::module_local()) diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 046c0465b7..d0ea6c3a23 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -69,19 +69,77 @@ def run_in_process(client_fn, args=(), kwargs={}): XBLOCK = ttgl.constexpr(128) +@gluon.jit +def failing_kernel(input): + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2) + smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1]) + offs_m = ttgl.arange(0, XBLOCK, layout=ttgl.SliceLayout(dim=1, parent=blocked_layout))[:, None] + offs_n = ttgl.arange(0, XBLOCK, layout=ttgl.SliceLayout(dim=0, parent=blocked_layout))[None, :] + offs = offs_m * XBLOCK + offs_n + ampere.async_copy.async_copy_global_to_shared(smem, input + offs) + ampere.async_copy.commit_group() + + ampere.async_copy.async_copy_global_to_shared(smem, input + offs) + ampere.async_copy.commit_group() + ampere.async_copy.wait_group(0) + + +def run_failing_kernel(device, enable_consan, mode): + # ConSan requires a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + if enable_consan: + if mode == "env": + os.environ["TRITON_INSTRUMENTATION_MODE"] = "consan" + knobs.refresh_knobs() + elif mode == "knob": + knobs.compilation.instrumentation_mode = "consan" + + input = torch.randn((XBLOCK, XBLOCK), device=device, dtype=torch.float16) + failing_kernel[(1, )](input) + + +@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) +def test_cache_miss_knob(device, monkeypatch): + # First run without consan + run_in_process(run_failing_kernel, (device, False, "knob")) + + # Then run with consan and assert that if fails + 
monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + result = run_in_process(run_failing_kernel, (device, True, "knob")) + assert "device-side assert" in str(result.exc) + + +@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) +def test_cache_miss_env(device, monkeypatch): + # First run without consan + run_in_process(run_failing_kernel, (device, False, "env")) + + # Then run with consan and assert that if fails + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + result = run_in_process(run_failing_kernel, (device, True, "env")) + assert "device-side assert" in str(result.exc) + + @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_async_tma_kernel(FAILURE, device, run_wrapper): +def test_async_tma_kernel(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_async_tma_kernel, (FAILURE, device, False)) + result = run_in_process(test_async_tma_kernel, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -119,9 +177,9 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_tma_interleave_kernel(FAILURE, device, run_wrapper): +def test_tma_interleave_kernel(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_tma_interleave_kernel, (FAILURE, device, False)) + result = run_in_process(test_tma_interleave_kernel, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output @@ -130,8 +188,9 @@ def test_tma_interleave_kernel(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -178,9 +237,9 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires ampere or newer", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_async_copy(FAILURE, device, run_wrapper): +def test_async_copy(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_async_copy, (FAILURE, device, False)) + result = run_in_process(test_async_copy, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Accessing buffer with pending access. 
Pending access type: async_copy_global_to_shared" in result.driver_stderr_output @@ -189,8 +248,9 @@ def test_async_copy(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -226,9 +286,9 @@ def kernel(input, FAILURE: ttgl.constexpr): run=False) @pytest.mark.parametrize("FAILURE", [True, False]) @pytest.mark.parametrize("MEM_ACCESS_KIND", ["tma_cp", "local_store", "tmem_load", "tmem_store"]) -def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, device, run_wrapper): +def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_tcgen5_mma, (FAILURE, MEM_ACCESS_KIND, device, False)) + result = run_in_process(test_tcgen5_mma, (FAILURE, MEM_ACCESS_KIND, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) if MEM_ACCESS_KIND == "tma_cp": @@ -242,8 +302,9 @@ def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -294,9 +355,9 @@ def kernel(input_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: ttgl.constexpr) @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_warpgroup_mma(FAILURE, device, run_wrapper): +def test_warpgroup_mma(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_warpgroup_mma, (FAILURE, device, False)) + result = run_in_process(test_warpgroup_mma, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Accessing buffer with pending access. Pending access type: warpgroup_mma operand read" in result.driver_stderr_output @@ -305,8 +366,9 @@ def test_warpgroup_mma(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -338,9 +400,9 @@ def kernel(input, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_warpgroup_mma2(FAILURE, device, run_wrapper): +def test_warpgroup_mma2(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_warpgroup_mma2, (FAILURE, device, False)) + result = run_in_process(test_warpgroup_mma2, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Accessing buffer with pending access. 
Pending access type: warpgroup_mma operand read" in result.driver_stderr_output @@ -349,8 +411,9 @@ def test_warpgroup_mma2(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -386,11 +449,11 @@ def kernel(input, FAILURE: ttgl.constexpr): run=False) @pytest.mark.parametrize("BUF_IDX", [0, 1]) @pytest.mark.parametrize("BAR_IDX", [0, 1, 2, 3]) -def test_tcgen5_mma_multibar(BUF_IDX, BAR_IDX, device, run_wrapper): +def test_tcgen5_mma_multibar(BUF_IDX, BAR_IDX, device, run_wrapper, monkeypatch): if BAR_IDX == 0: pytest.skip("Skipping due to wait on false-predicated barrier - not supported yet") if run_wrapper: - result = run_in_process(test_tcgen5_mma_multibar, (BUF_IDX, BAR_IDX, device, False)) + result = run_in_process(test_tcgen5_mma_multibar, (BUF_IDX, BAR_IDX, device, False, monkeypatch)) if BAR_IDX // 2 < BUF_IDX: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output @@ -398,8 +461,9 @@ def test_tcgen5_mma_multibar(BUF_IDX, BAR_IDX, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -446,9 +510,9 @@ def inc_mod(x, mod): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_multibuffered_loop(FAILURE, device, run_wrapper): +def test_multibuffered_loop(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_multibuffered_loop, (FAILURE, device, False)) + result = run_in_process(test_multibuffered_loop, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding reads" in result.driver_stderr_output @@ -457,8 +521,9 @@ def test_multibuffered_loop(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -560,9 +625,9 @@ def kernel(input_desc, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_multibuffered_wgmma_loop(FAILURE, device, run_wrapper): +def test_multibuffered_wgmma_loop(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_multibuffered_wgmma_loop, (FAILURE, device, False)) + result = run_in_process(test_multibuffered_wgmma_loop, (FAILURE, device, False, 
monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Accessing buffer with pending access. Pending access type: warpgroup_mma operand read" in result.driver_stderr_output @@ -571,8 +636,9 @@ def test_multibuffered_wgmma_loop(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -641,9 +707,9 @@ def kernel(input_desc, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_ws_store_wait_load(FAILURE, device, run_wrapper): +def test_ws_store_wait_load(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_store_wait_load, (FAILURE, device, False)) + result = run_in_process(test_ws_store_wait_load, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output @@ -651,8 +717,9 @@ def test_ws_store_wait_load(FAILURE, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -694,9 +761,9 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_ws_load_wait_store(FAILURE, device, run_wrapper): +def test_ws_load_wait_store(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_load_wait_store, (FAILURE, device, False)) + result = run_in_process(test_ws_load_wait_store, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding reads" in result.driver_stderr_output @@ -704,8 +771,9 @@ def test_ws_load_wait_store(FAILURE, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -747,9 +815,9 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("MISSING_BAR", ["none", "1", "2"]) -def test_ws_two_loads_two_bars(MISSING_BAR, device, run_wrapper): +def test_ws_two_loads_two_bars(MISSING_BAR, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_two_loads_two_bars, (MISSING_BAR, device, False)) + result 
= run_in_process(test_ws_two_loads_two_bars, (MISSING_BAR, device, False, monkeypatch)) if MISSING_BAR != "none": assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding reads" in result.driver_stderr_output @@ -757,8 +825,9 @@ def test_ws_two_loads_two_bars(MISSING_BAR, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -809,9 +878,9 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_ws_two_loads_one_bar(FAILURE, device, run_wrapper): +def test_ws_two_loads_one_bar(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_two_loads_one_bar, (FAILURE, device, False)) + result = run_in_process(test_ws_two_loads_one_bar, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding reads" in result.driver_stderr_output @@ -819,8 +888,9 @@ def test_ws_two_loads_one_bar(FAILURE, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -868,9 +938,9 @@ def kernel(output, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("MISSING_BAR", ["none", "0", "1", "2", "3"]) -def test_ws_two_loads_two_bars_loop(MISSING_BAR, device, run_wrapper): +def test_ws_two_loads_two_bars_loop(MISSING_BAR, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_two_loads_two_bars_loop, (MISSING_BAR, device, False)) + result = run_in_process(test_ws_two_loads_two_bars_loop, (MISSING_BAR, device, False, monkeypatch)) if MISSING_BAR != "none": assert "device-side assert" in str(result.exc) if MISSING_BAR in ["0", "1"]: @@ -881,8 +951,9 @@ def test_ws_two_loads_two_bars_loop(MISSING_BAR, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -951,9 +1022,9 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_ws_load_ordering(FAILURE, device, run_wrapper): +def test_ws_load_ordering(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: 
- result = run_in_process(test_ws_load_ordering, (FAILURE, device, False)) + result = run_in_process(test_ws_load_ordering, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output @@ -961,8 +1032,9 @@ def test_ws_load_ordering(FAILURE, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1015,9 +1087,9 @@ def kernel(output, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("MISSING_BAR", ["none", "T2", "T3"]) -def test_ws_two_producers_two_consumers(MISSING_BAR, device, run_wrapper): +def test_ws_two_producers_two_consumers(MISSING_BAR, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_two_producers_two_consumers, (MISSING_BAR, device, False)) + result = run_in_process(test_ws_two_producers_two_consumers, (MISSING_BAR, device, False, monkeypatch)) if MISSING_BAR != "none": assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding" in result.driver_stderr_output @@ -1025,8 +1097,9 @@ def test_ws_two_producers_two_consumers(MISSING_BAR, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1102,9 +1175,9 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("MISSING_BAR", ["none", "1", "2"]) -def test_ws_different_warp_sizes(MISSING_BAR, device, run_wrapper): +def test_ws_different_warp_sizes(MISSING_BAR, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_different_warp_sizes, (MISSING_BAR, device, False)) + result = run_in_process(test_ws_different_warp_sizes, (MISSING_BAR, device, False, monkeypatch)) if MISSING_BAR != "none": assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding reads" in result.driver_stderr_output @@ -1112,8 +1185,9 @@ def test_ws_different_warp_sizes(MISSING_BAR, device, run_wrapper): assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1171,9 +1245,9 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer", run=False) 
@pytest.mark.parametrize("FAILURE", [True, False]) -def test_ws_async_copy_commits(FAILURE, device, run_wrapper): +def test_ws_async_copy_commits(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_async_copy_commits, (FAILURE, device, False)) + result = run_in_process(test_ws_async_copy_commits, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output @@ -1182,8 +1256,9 @@ def test_ws_async_copy_commits(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() def alloc_fn(size: int, alignment: int, stream: Optional[int]): return torch.empty(size, device="cuda", dtype=torch.int8) @@ -1233,9 +1308,9 @@ def kernel(input, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_ws_async_copy_wait_visibility(FAILURE, device, run_wrapper): +def test_ws_async_copy_wait_visibility(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_async_copy_wait_visibility, (FAILURE, device, False)) + result = run_in_process(test_ws_async_copy_wait_visibility, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert (("Buffer being accessed has outstanding writes" in result.driver_stderr_output) @@ -1246,8 +1321,9 @@ def test_ws_async_copy_wait_visibility(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() def alloc_fn(size: int, alignment: int, stream: Optional[int]): return torch.empty(size, device="cuda", dtype=torch.int8) @@ -1287,9 +1363,9 @@ def kernel(input, FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper", run=False) @pytest.mark.parametrize("FAILURE", [True, False]) -def test_ws_wgmma_wait_visibility(FAILURE, device, run_wrapper): +def test_ws_wgmma_wait_visibility(FAILURE, device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_ws_wgmma_wait_visibility, (FAILURE, device, False)) + result = run_in_process(test_ws_wgmma_wait_visibility, (FAILURE, device, False, monkeypatch)) if FAILURE: assert "device-side assert" in str(result.exc) assert "Accessing buffer with pending access. 
Pending access type: warpgroup_mma operand read" in result.driver_stderr_output @@ -1298,8 +1374,9 @@ def test_ws_wgmma_wait_visibility(FAILURE, device, run_wrapper): assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() def alloc_fn(size: int, alignment: int, stream: Optional[int]): return torch.empty(size, device="cuda", dtype=torch.int8) @@ -1340,14 +1417,15 @@ def kernel(FAILURE: ttgl.constexpr): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) -def test_deadlock_two_partitions(device, run_wrapper): +def test_deadlock_two_partitions(device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_deadlock_two_partitions, (device, False)) + result = run_in_process(test_deadlock_two_partitions, (device, False, monkeypatch)) assert "device-side assert" in str(result.exc) assert "Deadlock detected" in result.driver_stderr_output return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1374,14 +1452,15 @@ def kernel(): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) -def test_deadlock_overarrival(device, run_wrapper): +def test_deadlock_overarrival(device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_deadlock_overarrival, (device, False)) + result = run_in_process(test_deadlock_overarrival, (device, False, monkeypatch)) assert "device-side assert" in str(result.exc) assert "Deadlock detected" in result.driver_stderr_output return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1403,14 +1482,15 @@ def kernel(): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) -def test_deadlock_underarrival(device, run_wrapper): +def test_deadlock_underarrival(device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_deadlock_underarrival, (device, False)) + result = run_in_process(test_deadlock_underarrival, (device, False, monkeypatch)) assert "device-side assert" in str(result.exc) assert "Deadlock detected" in result.driver_stderr_output return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1439,14 +1519,15 @@ def kernel(): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) -def test_deadlock_different_phases(device, run_wrapper): +def test_deadlock_different_phases(device, 
run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_deadlock_different_phases, (device, False)) + result = run_in_process(test_deadlock_different_phases, (device, False, monkeypatch)) assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1474,14 +1555,15 @@ def kernel(): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) -def test_deadlock_exempt_when_tma_signals(device, run_wrapper): +def test_deadlock_exempt_when_tma_signals(device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_deadlock_exempt_when_tma_signals, (device, False)) + result = run_in_process(test_deadlock_exempt_when_tma_signals, (device, False, monkeypatch)) assert result.exc is None assert result.driver_stderr_output == "" return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): @@ -1517,14 +1599,15 @@ def kernel(input_desc): @pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper", run=False) -def test_barrier_underflow(device, run_wrapper): +def test_barrier_underflow(device, run_wrapper, monkeypatch): if run_wrapper: - result = run_in_process(test_barrier_underflow, (device, False)) + result = run_in_process(test_barrier_underflow, (device, False, monkeypatch)) assert "device-side assert" in str(result.exc) assert "Barrier arrive underflow: current count would become negative" in result.driver_stderr_output return - knobs.compilation.enable_experimental_consan = True - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() # ConSan requires a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index bc534c3372..d60ae68a58 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -1,6 +1,7 @@ import torch import pytest import re +from itertools import product import triton import triton.language as tl @@ -125,29 +126,28 @@ def test_async_copy_mbarrier(device): @gluon.jit -def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, ASYNC: ttgl.constexpr): - block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]) - mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], - instr_shape=[16, 32, 16]) - nvmma_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2) - +def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, + block_layout: ttgl.constexpr, mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr, + shared_layout_b: ttgl.constexpr, acc_dtype: ttgl.constexpr, ASYNC: 
ttgl.constexpr): a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None] - a_offs_n = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :] - b_offs_m = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None] + a_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :] + b_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None] b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, block_layout))[None, :] out_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, mma_layout))[:, None] out_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, mma_layout))[None, :] - acc = ttgl.zeros([M, N], dtype=a.dtype.element_ty, layout=mma_layout) - A = ttgl.load(a + a_offs_m * K + a_offs_n) - B = ttgl.load(b + b_offs_m * N + b_offs_n) + operand_dtype = a.dtype.element_ty + a_tile = ttgl.load(a + a_offs_m * K + a_offs_k) + b_tile = ttgl.load(b + b_offs_k * N + b_offs_n) - a_shmem = ttgl.allocate_shared_memory(ttgl.float16, [M, K], nvmma_layout, A) - b_shmem = ttgl.allocate_shared_memory(ttgl.float16, [K, N], nvmma_layout, B) + smem_a = ttgl.allocate_shared_memory(operand_dtype, [M, K], shared_layout_a, a_tile) + smem_b = ttgl.allocate_shared_memory(operand_dtype, [K, N], shared_layout_b, b_tile) fence_async_shared() - acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc, is_async=ASYNC) + + acc = ttgl.zeros([M, N], dtype=acc_dtype, layout=mma_layout) + acc = hopper.warpgroup_mma(smem_a, smem_b, acc, is_async=ASYNC) if ASYNC: acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc]) @@ -160,16 +160,152 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg def test_warpgroup_mma(ASYNC): torch.manual_seed(0) M, N, K = 64, 32, 32 + warps = [4, 1] + block_layout = ttgl.BlockedLayout([1, 1], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0]) + mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=warps, instr_shape=[16, 32, 16]) + shared_layout_a = ttgl.NVMMASharedLayout.get_default_for([M, K], ttgl.float16) + shared_layout_b = ttgl.NVMMASharedLayout.get_default_for([K, N], ttgl.float16) a = torch.randn((M, K), device="cuda", dtype=torch.float16) b = torch.randn((K, N), device="cuda", dtype=torch.float16) out = torch.zeros((M, N), device="cuda", dtype=torch.float16) - warpgroup_mma_kernel[(1, )](a, b, out, M, N, K, ASYNC) + warpgroup_mma_kernel[(1, )]( + a, + b, + out, + M, + N, + K, + block_layout, + mma_layout, + shared_layout_a, + shared_layout_b, + ttgl.float16, + ASYNC, + num_warps=warps[0] * warps[1], + ) ref = torch.matmul(a, b) torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1) +@pytest.mark.xfail(not is_hopper(), reason="Requires Hopper", run=False) +@pytest.mark.parametrize("bitwidth, transpose_a, transpose_b, acc_dtype", + [(bitwidth, transpose_a, transpose_b, acc_dtype) + for bitwidth in [8, 16, 32] + for (transpose_a, transpose_b) in product([False, True], repeat=2) + for acc_dtype in [torch.float16, torch.float32] + if bitwidth == 16 or (acc_dtype == torch.float32 and not transpose_a and transpose_b)]) +@pytest.mark.parametrize("warps", ([8, 1], [4, 2], [4, 1])) +# Swizzling 0 does not map to a valid memory descriptor lol +@pytest.mark.parametrize("swizzling_a, swizzling_b", product([32, 64, 128], repeat=2)) +@pytest.mark.parametrize("shape_m, shape_n, shape_k", [(1, 1, 1), (2, 4, 1), (2, 2, 4)]) +def test_warpgroup_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dtype, warps, swizzling_a, swizzling_b, + shape_m, shape_n, 
shape_k): + + torch_dtype_map = { + 8: torch.float8_e4m3fn, + 16: torch.float16, + 32: torch.float32, + } + acc_dtype_map = { + torch.float16: ttgl.float16, + torch.float32: ttgl.float32, + } + + # We'll choose a larger instr shape along N, but sure + instr_shape_k_map = {8: 32, 16: 16, 32: 8} + instr_shape = [16, 32, instr_shape_k_map[bitwidth]] + M = instr_shape[0] * warps[0] + N = instr_shape[1] * warps[1] + K = instr_shape[2] + + def min_shape(swizzling, dim0, dim1, trans): + tile_cols = (8 * max(16, swizzling)) // bitwidth + outer_dim, contig_dim = (dim0, dim1) + if trans: + outer_dim, contig_dim = contig_dim, outer_dim + contig_dim = max(contig_dim, tile_cols) + outer_dim = max(outer_dim, 8) + if trans: + outer_dim, contig_dim = contig_dim, outer_dim + return outer_dim, contig_dim + + # Get the minimum shape for the given swizzling / transpose + M, K = min_shape(swizzling_a, M, K, transpose_a) + K, N = min_shape(swizzling_b, K, N, transpose_b) + M *= shape_m + N *= shape_n + K *= shape_k + instr_shape[1] *= shape_n + + shared_mem_accum = M * K * bitwidth // 8 + K * N * bitwidth // 8 + if triton.runtime.driver.active.utils.get_device_properties( + triton.runtime.driver.active.get_current_device())["max_shared_mem"] < shared_mem_accum: + pytest.skip("Skipped due to insufficient shared memory on this GPU.") + + torch_dtype = torch_dtype_map[bitwidth] + gl_acc_dtype = acc_dtype_map[acc_dtype] + out_dtype = torch.float32 + + block_layout = ttgl.BlockedLayout([1, 1], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0]) + shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2, + transposed=transpose_a) + shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2, + transposed=transpose_b) + mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=warps, instr_shape=instr_shape) + + torch.manual_seed(0) + + def cast(x, dtype): + if dtype != torch.float32: + return x.to(torch_dtype) + else: + # zero-out the lower 13 bits + x = x.view(torch.int32) + x = x & ~((1 << 13) - 1) + return x.view(dtype) + + # Sample bf16 as tf32 does not use the full range + a = cast(torch.randn((M, K), device="cuda", dtype=torch.float32), torch_dtype) + b = cast(torch.randn((K, N), device="cuda", dtype=torch.float32), torch_dtype) + out = torch.zeros((M, N), device="cuda", dtype=out_dtype) + + warpgroup_mma_kernel[(1, )]( + a, + b, + out, + M, + N, + K, + block_layout, + mma_layout, + shared_layout_a, + shared_layout_b, + gl_acc_dtype, + False, + num_warps=warps[0] * warps[1], + ) + + try: + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = True + allow_fp16_red = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = acc_dtype == torch.float16 + ref = torch.matmul(a.to(acc_dtype), b.to(acc_dtype)).to(out_dtype) + finally: + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow_fp16_red + + if bitwidth == 8: + atol, rtol = 0.5, 0.5 + elif bitwidth == 16: + atol, rtol = 3e-2, 1e-1 + else: + atol, rtol = 5e-4, 5e-3 + torch.testing.assert_close(out, ref, atol=atol, rtol=rtol) + + @pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4", run=False) @pytest.mark.parametrize("use_buffer_load", [True, False]) def test_amd_direct_load_to_shared(use_buffer_load): @@ -267,7 +403,7 @@ def kernel(a_ptr, b_ptr, 
c_ptr, # torch.testing.assert_close(ref, triton_output) -@pytest.mark.skipif(not (is_hip_cdna3() or is_hip_cdna4()), reason="Requires CDNA3 or CDNA4") +@pytest.mark.xfail(not (is_hip_cdna3() or is_hip_cdna4()), reason="Requires CDNA3 or CDNA4", run=False) @pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)]) @pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16']) @pytest.mark.parametrize("num_warps", [4, 8]) @@ -560,8 +696,6 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_ tmem_alias: ttgl.constexpr = TensorMemoryLayout((num_rows, num_cols), col_stride=1) tmem = tmem._reinterpret(ttgl.int8, (num_rows, num_cols), tmem_alias) value = tmem.load(blocked) - ttgl.static_print(ttgl.to_linear_layout(blocked, (smem_h, smem_w))) - ttgl.static_print(ttgl.to_linear_layout(blocked, (num_rows, num_cols))) ttgl.store(ttgl.set_auto_layout(out_ptrs, blocked), value) torch.manual_seed(0) @@ -1042,7 +1176,7 @@ def kernel(a, BLOCK: ttgl.constexpr, SIZE_PER_THREAD: ttgl.constexpr): blocked: ttgl.constexpr = ttgl.BlockedLayout([SIZE_PER_THREAD], [64], [4], [0]) offsets = ttgl.arange(0, BLOCK, layout=blocked) val = ttgl.full([BLOCK], 1.0, ttgl.bfloat16, layout=blocked) - ttgl.amd.cdna4.buffer_atomic_rmw("fadd", a, offsets, val, mask=1, scope="cta", sem="relaxed") + ttgl.amd.cdna4.buffer_atomic_add(a, offsets, val, mask=1, scope="cta", sem="relaxed") a = torch.randn((BLOCK), dtype=elem_type, device="cuda") origin_a = a.clone() diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 1066e475ce..71e5a1a370 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -1436,6 +1436,9 @@ def test_zeros(): # CHECK: arith.constant dense<7> : tensor<8x8xi16, [[BLOCKED2D]]> ttgl.full_like(a, 7, shape=[8, 8], dtype=ttgl.int16, layout=layout_2d) + # CHECK: arith.constant 0.000000e+00 : f32 + ttgl.zeros((), ttgl.float32, layout) + @filecheck_test @gluon.jit @@ -2504,33 +2507,33 @@ def kernel(int32_ptr, uint32_ptr, int64_ptr, fp16_ptr, fp32_ptr): offsets = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout()) val = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout()) - ttgl.amd.cdna3.buffer_atomic_rmw("smax", int32_ptr, offsets, val) - ttgl.amd.cdna3.buffer_atomic_rmw("smin", int32_ptr, offsets, val) - ttgl.amd.cdna3.buffer_atomic_rmw("and", int32_ptr, offsets, val) - ttgl.amd.cdna3.buffer_atomic_rmw("or", int32_ptr, offsets, val) + ttgl.amd.cdna3.buffer_atomic_max(int32_ptr, offsets, val) + ttgl.amd.cdna3.buffer_atomic_min(int32_ptr, offsets, val) + ttgl.amd.cdna3.buffer_atomic_and(int32_ptr, offsets, val) + ttgl.amd.cdna3.buffer_atomic_or(int32_ptr, offsets, val) #value broadcast - ttgl.amd.cdna3.buffer_atomic_rmw("xor", int32_ptr, offsets, value=1) + ttgl.amd.cdna3.buffer_atomic_xor(int32_ptr, offsets, value=1) # operands should be unsigned val = ttgl.full([BLOCK], 1, ttgl.uint32, layout=ttgl.AutoLayout()) - ttgl.amd.cdna3.buffer_atomic_rmw("umax", uint32_ptr, offsets, val) - ttgl.amd.cdna3.buffer_atomic_rmw("umin", uint32_ptr, offsets, val) - ttgl.amd.cdna3.buffer_atomic_rmw("iadd", uint32_ptr, offsets, val) + ttgl.amd.cdna3.buffer_atomic_max(uint32_ptr, offsets, val) + ttgl.amd.cdna3.buffer_atomic_min(uint32_ptr, offsets, val) + ttgl.amd.cdna3.buffer_atomic_add(uint32_ptr, offsets, val) val = val.cast(ttgl.int64) #mask broadcast - ttgl.amd.cdna3.buffer_atomic_rmw("xchg", int64_ptr, offsets, val, mask=0) + ttgl.amd.cdna3.buffer_atomic_xchg(int64_ptr, offsets, val, mask=0) mask = ttgl.full([BLOCK], True, 
ttgl.int32, layout=ttgl.AutoLayout()) val = ttgl.zeros([BLOCK], ttgl.float16, layout=ttgl.AutoLayout()) - ttgl.amd.cdna3.buffer_atomic_rmw("fadd", fp16_ptr, offsets, val, mask=mask) - ttgl.amd.cdna3.buffer_atomic_rmw("fadd", fp16_ptr, offsets, val, mask=mask, scope="sys") - ttgl.amd.cdna3.buffer_atomic_rmw("fadd", fp16_ptr, offsets, val, mask=mask, scope="cta", sem="relaxed") + ttgl.amd.cdna3.buffer_atomic_add(fp16_ptr, offsets, val, mask=mask) + ttgl.amd.cdna3.buffer_atomic_add(fp16_ptr, offsets, val, mask=mask, scope="sys") + ttgl.amd.cdna3.buffer_atomic_add(fp16_ptr, offsets, val, mask=mask, scope="cta", sem="relaxed") val = val.cast(ttgl.float32) - ttgl.amd.cdna3.buffer_atomic_rmw("fadd", fp32_ptr, offsets, val, mask=mask) - ttgl.amd.cdna3.buffer_atomic_rmw("fadd", fp32_ptr, offsets, val, mask=mask, scope="sys") - ttgl.amd.cdna3.buffer_atomic_rmw("fadd", fp32_ptr, offsets, val, mask=mask, scope="cta", sem="relaxed") + ttgl.amd.cdna3.buffer_atomic_add(fp32_ptr, offsets, val, mask=mask) + ttgl.amd.cdna3.buffer_atomic_add(fp32_ptr, offsets, val, mask=mask, scope="sys") + ttgl.amd.cdna3.buffer_atomic_add(fp32_ptr, offsets, val, mask=mask, scope="cta", sem="relaxed") fp16_ptr = MockTensor(ttgl.float16) fp32_ptr = MockTensor(ttgl.float32) @@ -2612,10 +2615,10 @@ def test_buffer_atomic_rmw_bf16(target): def kernel(bf16_ptr): offsets = ttgl.arange(0, 1, layout=ttgl.AutoLayout()) val = ttgl.zeros([1], ttgl.bfloat16, layout=ttgl.AutoLayout()) - ttgl.amd.cdna4.buffer_atomic_rmw("fadd", bf16_ptr, offsets, val, mask=0) + ttgl.amd.cdna4.buffer_atomic_add(bf16_ptr, offsets, val, mask=0) mask = ttgl.full([1], True, ttgl.int32, layout=ttgl.AutoLayout()) - ttgl.amd.cdna4.buffer_atomic_rmw("fadd", bf16_ptr, offsets, val, mask=mask, scope="sys") - ttgl.amd.cdna4.buffer_atomic_rmw("fadd", bf16_ptr, offsets, val, mask=mask, scope="cta", sem="relaxed") + ttgl.amd.cdna4.buffer_atomic_add(bf16_ptr, offsets, val, mask=mask, scope="sys") + ttgl.amd.cdna4.buffer_atomic_add(bf16_ptr, offsets, val, mask=mask, scope="cta", sem="relaxed") bf16_ptr = MockTensor(ttgl.bfloat16) module = run_parser(kernel, *make_args(bf16_ptr), target=target) @@ -2727,3 +2730,48 @@ def kernel(): run_parser(kernel) assert "For step must be a scalar, got" in str(e.value) + + +@gluon.jit +def amd_tdm_kernel(ptr): + SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0]) + BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) + + desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=ptr, shape=(32, 128), strides=(128, 1), + block_shape=(16, 64), layout=SHARED_LAYOUT) + + buffer = ttgl.allocate_shared_memory(desc.dtype, shape=desc.block_shape, layout=desc.layout) + ttgl.amd.gfx1250.tdm.async_load(desc, offsets=[0, 2], dest=buffer) + + ttgl.amd.gfx1250.tdm.async_wait(0) + buffer.load(layout=BLOCKED_LAYOUT) + + +@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250]) +def test_amd_tdm(target): + + ptr = MockTensor(ttgl.float16) + module = run_parser(amd_tdm_kernel, *make_args(ptr), target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @amd_tdm_kernel(%arg0: !tt.ptr 
{tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %1 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %true = arith.constant true + %2 = amdgpu.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %1, %true : !tt.tensordesc> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> + %3 = amdgpu.async_tdm_wait {num = 0 : i32} + %4 = ttg.local_load %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked> + tt.return + } +} +""") diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c9ed274081..30904cbe61 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1659,6 +1659,8 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): @pytest.mark.parametrize("num_ctas", num_ctas_list) @pytest.mark.parametrize("dtype_str", ["int32", "int64"]) def test_atomic_cas(sem, num_ctas, dtype_str, device): + if is_hip_cdna2(): + pytest.skip("Disabled due to being flaky on CDNA2") # 1. make sure that atomic_cas changes the original value (Lock) @triton.jit def change_value(Lock, triton_dtype: tl.constexpr): diff --git a/python/test/unit/test_debuginfo.py b/python/test/unit/test_debuginfo.py index dc93dfb235..bcd355a55b 100644 --- a/python/test/unit/test_debuginfo.py +++ b/python/test/unit/test_debuginfo.py @@ -1,40 +1,65 @@ import os -import subprocess -all_names = ["offsets", "pid", "block_start", "mask", "x", "y", "output"] +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) def checkDbgInfo(llir, hasDbgInfo): assert hasDbgInfo == ('dbg_value' in llir) - for name in all_names: + for name in ["offsets", "pid", "block_start", "mask", "x", "y", "output"]: assert hasDbgInfo == ('!DILocalVariable(name: \"' + name + '\"' in llir) -def test_triton_debuginfo_on(): - lineInfoKey = "TRITON_DISABLE_LINE_INFO" - diLocalVarKey = "LLVM_EXTRACT_DI_LOCAL_VARIABLES" +@pytest.mark.parametrize("lineInfoKey, diLocalVarKey, hasDbgInfo", [ + (None, None, False), + # expect dbginfo based on parent proccess' TRITON_DISABLE_LINE_INFO + (None, "1", "infer"), + ("0", "1", True), + ("1", "1", False), + ("0", "0", False), + ("1", "0", False), +]) +def test_triton_debuginfo_on(lineInfoKey, diLocalVarKey, hasDbgInfo, device, monkeypatch): + lineInfoKeyName = "TRITON_DISABLE_LINE_INFO" + diLocalVarKeyName = "LLVM_EXTRACT_DI_LOCAL_VARIABLES" + if lineInfoKey is not None: + monkeypatch.setenv(lineInfoKeyName, lineInfoKey) + if diLocalVarKey is not None: + monkeypatch.setenv(diLocalVarKeyName, diLocalVarKey) isEnvSet = lambda env, str: env.get(str, None) is not None - hasOrigLineInfo = (not isEnvSet(os.environ, lineInfoKey) - or os.environ[lineInfoKey].lower() not in ["on", "true", "1"]) - envs = [ - # expect no dbginfo if unset - {lineInfoKey: None, 
diLocalVarKey: None, "hasDbgInfo": False}, - # expect dbginfo based on parent proccess' TRITON_DISABLE_LINE_INFO - {lineInfoKey: None, diLocalVarKey: "1", "hasDbgInfo": hasOrigLineInfo}, - {lineInfoKey: "0", diLocalVarKey: "1", "hasDbgInfo": True}, - {lineInfoKey: "1", diLocalVarKey: "1", "hasDbgInfo": False}, - {lineInfoKey: "0", diLocalVarKey: "0", "hasDbgInfo": False}, - {lineInfoKey: "1", diLocalVarKey: "0", "hasDbgInfo": False}, - ] - - _run_test = lambda test_env: subprocess.run([ - "python3", os.path.dirname(os.path.realpath(__file__)) + "/test_debuginfo_helper.py" - ], env=test_env, capture_output=True, text=True) - for env in envs: - test_env = os.environ.copy() - test_env["TRITON_ALWAYS_COMPILE"] = "1" - for entry in env: - if not isEnvSet(env, entry): continue - test_env[entry] = str(env[entry]) - checkDbgInfo(str(_run_test(test_env).stdout), hasDbgInfo=env["hasDbgInfo"]) + if hasDbgInfo == "infer": + hasDbgInfo = (not isEnvSet(os.environ, lineInfoKeyName) + or os.environ[lineInfoKeyName].lower() not in ["on", "true", "1"]) + + size = 98432 + torch.manual_seed(0) + x = torch.rand(size, device=device) + y = torch.rand(size, device=device) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel.device_caches.clear() + h = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + checkDbgInfo(h.asm['llir'], hasDbgInfo) diff --git a/python/test/unit/test_debuginfo_helper.py b/python/test/unit/test_debuginfo_helper.py deleted file mode 100644 index b96e703236..0000000000 --- a/python/test/unit/test_debuginfo_helper.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - -import triton -import triton.language as tl - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - - -@triton.jit -def add_kernel( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - -size = 98432 -torch.manual_seed(0) -x = torch.rand(size, device=DEVICE) -y = torch.rand(size, device=DEVICE) -all_names = ["offsets", "pid", "block_start", "mask", "x", "y", "output"] -output = torch.empty_like(x) -n_elements = output.numel() -grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) -h = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) -llir = h.asm['llir'] -print(llir) diff --git a/python/triton/experimental/gluon/language/_core.py b/python/triton/experimental/gluon/language/_core.py index de0e9f896a..e2b986b643 100644 --- a/python/triton/experimental/gluon/language/_core.py +++ b/python/triton/experimental/gluon/language/_core.py @@ -509,10 +509,6 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti """ worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps] worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs] - if not isinstance(default_args, tuple): - default_args = (default_args, ) - if not isinstance(worker_args, tuple): - worker_args = (worker_args, ) return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs, _generator) diff --git a/python/triton/experimental/gluon/language/_semantic.py 
b/python/triton/experimental/gluon/language/_semantic.py index 3dc8a2ff56..a50c5ad5d0 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -156,6 +156,8 @@ def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool): return self._wrap_tensor_infer_layout(value) def splat(self, value, shape, layout): + if len(shape) == 0: + return value ret_ty = ttgl.distributed_type(value.dtype, shape, layout) handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle) return ttgl.tensor(handle, ret_ty) @@ -418,6 +420,10 @@ def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions, worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator): num_partitions = len(worker_partitions) + _check(isinstance(default_args, (tuple, ttgl.tuple)), + lambda: f"default_args must be a tuple of arguments, but got {type(default_args)}") + _check(isinstance(worker_args, (tuple, ttgl.tuple)), + lambda: f"worker_args must be a tuple of arguments, but got {type(worker_args)}") assert num_partitions == len( worker_num_warps ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts" diff --git a/python/triton/experimental/gluon/language/amd/cdna3/__init__.py b/python/triton/experimental/gluon/language/amd/cdna3/__init__.py index b514480288..7d88a62b84 100644 --- a/python/triton/experimental/gluon/language/amd/cdna3/__init__.py +++ b/python/triton/experimental/gluon/language/amd/cdna3/__init__.py @@ -9,7 +9,10 @@ if TYPE_CHECKING: from ..._semantic import GluonSemantic -__all__ = ["buffer_atomic_rmw", "buffer_load", "buffer_store", "mfma"] +__all__ = [ + "buffer_atomic_add", "buffer_atomic_and", "buffer_atomic_min", "buffer_atomic_max", "buffer_atomic_or", + "buffer_atomic_xor", "buffer_atomic_xor", "buffer_load", "buffer_store", "mfma" +] _atomic_op_str_to_op = { "smax": ir.ATOMIC_OP.MAX, "smin": ir.ATOMIC_OP.MIN, "umax": ir.ATOMIC_OP.UMAX, "umin": ir.ATOMIC_OP.UMIN, "fadd": @@ -28,33 +31,48 @@ def _verify_buffer_ops(ptr, offsets, mask=None, other=None): assert mask is not None, "when other is not None, mask should not be None" -def _verify_element_type_in_buffer_atomic(op, elem_type, arch): +def _verify_element_type_and_dispatch_op(op, elem_type, arch): supported_types = [ ttgl.float16, ttgl.float32, ttgl.bfloat16, ttgl.float64, ttgl.int32, ttgl.int64, ttgl.uint32, ttgl.uint64 ] assert elem_type in supported_types, f"{elem_type} is not supported in buffer atomic on {arch}." 
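For readers migrating call sites, a minimal before/after sketch of the renamed CDNA3 buffer-atomic builtins (hypothetical kernel; imports follow the usual Gluon test convention):

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def atomic_max_kernel(int32_ptr, BLOCK: ttgl.constexpr):
    offsets = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout())
    val = ttgl.full([BLOCK], 1, ttgl.int32, layout=ttgl.AutoLayout())
    # before: the RMW kind was a string and the caller spelled out the
    # signed/unsigned/float flavor:
    #   ttgl.amd.cdna3.buffer_atomic_rmw("smax", int32_ptr, offsets, val)
    # after: one builtin per RMW kind; "smax" vs "umax" is picked from the
    # element type inside _buffer_atomic_rmw_impl:
    ttgl.amd.cdna3.buffer_atomic_max(int32_ptr, offsets, val)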
- op = _atomic_op_str_to_op[_unwrap_if_constexpr(op)] - if op in [ir.ATOMIC_OP.AND, ir.ATOMIC_OP.OR, ir.ATOMIC_OP.XOR, ir.ATOMIC_OP.XCHG]: - assert elem_type in [ttgl.int32, ttgl.int64], f"{op} with {elem_type} is not supported on CDNA3 and CDNA4" - - if op in [ir.ATOMIC_OP.UMAX, ir.ATOMIC_OP.UMIN, ir.ATOMIC_OP.ADD]: - assert elem_type in [ttgl.uint32, ttgl.uint64], f"{op} with {elem_type} is not supported on CDNA3 and CDNA4" - - if op in [ir.ATOMIC_OP.MAX, ir.ATOMIC_OP.MIN]: - assert elem_type in [ttgl.int32, ttgl.int64, - ttgl.float64], "smax only support int32, int64 and fp64 on CDNA3 and CDNA4" - - if op == ir.ATOMIC_OP.FADD: - if elem_type is ttgl.bfloat16: + if op in ['and', 'or', 'xor', 'xchg']: + assert elem_type in [ttgl.int32, ttgl.int64], f"{op} with {elem_type} is not supported on CDNA3 or CDNA4" + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + + if op in ['max', 'min']: + if elem_type in [ttgl.int32, ttgl.int64, ttgl.float64]: + op = 's' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + elif elem_type in [ttgl.uint32, ttgl.uint64]: + op = 'u' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + else: + raise ValueError(f"{op} with {elem_type} is not supported on CDNA3 and CDNA4") + + if op == 'add': + if elem_type in [ttgl.uint32, ttgl.uint64]: + op = 'i' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + elif elem_type in [ttgl.float16, ttgl.float32, ttgl.float64]: + op = 'f' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + elif elem_type is ttgl.bfloat16: assert arch == "cdna4", "Buffer atomic fadd with bf16 is only supported on CDNA4 for now." + op = 'f' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] else: - assert elem_type in [ttgl.float16, ttgl.float32, ttgl.float64] + raise ValueError(f"{op} with {elem_type} is not supported on CDNA3 and CDNA4") + + raise ValueError(f"Unknown {op} on CDNA3 or CDNA4") def _buffer_atomic_rmw_impl(op, ptr, offsets, value, arch, mask, sem, scope, _semantic): _verify_buffer_ops(ptr, offsets, mask) + op = _verify_element_type_and_dispatch_op(op, ptr.type.scalar.element_ty, arch) + mask = _unwrap_if_constexpr(mask) if mask is not None: mask = _semantic.to_tensor(mask) @@ -66,11 +84,8 @@ def _buffer_atomic_rmw_impl(op, ptr, offsets, value, arch, mask, sem, scope, _se value = _semantic.to_tensor(value) _, value = _semantic.broadcast_impl_value(offsets, value) - _verify_element_type_in_buffer_atomic(op, value.dtype, arch) - sem = _semantic._str_to_sem(sem) scope = _semantic._str_to_scope(scope) - op = _atomic_op_str_to_op[_unwrap_if_constexpr(op)] return _semantic.tensor( _semantic.builder.create_buffer_atomic_rmw(op, ptr.handle, offsets.handle, value.handle, sem, scope, mask), value.type) @@ -153,28 +168,71 @@ def mfma(a, b, acc, _semantic: GluonSemantic = None): return ttgl.tensor(handle, ret_type) +""" +AMD Buffer Atomic RMW operations. +The supported operatios are max, min, add, and, or, xor, xchg. +Similar to normal atomic ops: it loads data at ptr plus offsets, do `op` with `value`, and store result to `ptr` plus `offsets` with +the specified memory semantics and scope. + +Buffer atomics access global memory via a scalar base pointer and a tensor of offsets instead of a tensor of pointers. +Similar to other buffer ops, the `mask` is a boolean vector that determines if a given element should be processed with +the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed). 
+ +Buffer Atomic RMW ops return the pre-op value in the global memory. + +Args: + ptr (pointer to scalar): Global memory scalar base pointer to load from. + offsets (tensor): Offsets tensor for the load operation. + value (tensor): Another operand of `op`. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + sem (str, optional): Memory Semantic Descriptor. Default is None which means acq_rel memory semantic. + scope (str, optional): Memory Sync Scope for atomic accesses. Default is None and it will be mapped to `gpu`, which is called `agent` for AMDGPU. Please ref https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942 for details. +""" + + @builtin -def buffer_atomic_rmw(op, ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): - """ - AMD Buffer Atomic RMW operation. - Similar to normal atomic ops: it loads data at ptr plus offsets, do `op` with `value`, and store result to `ptr` plus `offsets` with - the specified memory semantics and scope. +def buffer_atomic_max(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + return _buffer_atomic_rmw_impl('max', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) - Buffer atomics access global memory via a scalar base pointer and a tensor of offsets instead of a tensor of pointers. - Similar to other buffer ops, the `mask` is a boolean vector that determines if a given element should be processed with - the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed). - Buffer Atomic RMW ops return the pre-op value in the global memory. +@builtin +def buffer_atomic_min(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): - Args: - op (str) : The operator to be executed atomically. - ptr (pointer to scalar): Global memory scalar base pointer to load from. - offsets (tensor): Offsets tensor for the load operation. - value (tensor): Another operand of `op`. - mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. - sem (str, optional): Memory Semantic Descriptor. Default is None which means acq_rel memory semantic. - scope (str, optional): Memory Sync Scope for atomic accesses. Default is None and it will be mapped to `gpu`, which is called `agent` for AMDGPU. Please ref https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942 for details. 
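To pin down the mask and pre-op-return semantics documented above, a rough host-side emulation (hypothetical helper, single-threaded, so duplicate offsets behave differently than real atomics):

import torch


def buffer_atomic_rmw_ref(base, offsets, value, op, mask=None):
    # The builtins return the values that were in memory before the RMW.
    old = base[offsets].clone()
    keep = torch.ones_like(offsets, dtype=torch.bool) if mask is None else mask.bool()
    # Lanes with mask[i] == 0 are dropped: no read-modify-write happens for them.
    base[offsets[keep]] = op(base[offsets[keep]], value[keep])
    return old


# e.g. buffer_atomic_rmw_ref(buf, offs, vals, torch.maximum, mask) mirrors
# buffer_atomic_max for unique offsets.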
- """ + return _buffer_atomic_rmw_impl('min', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_add(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('add', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_and(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('and', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_or(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('or', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xor(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('xor', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xchg(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): - return _buffer_atomic_rmw_impl(op, ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + return _buffer_atomic_rmw_impl('xchg', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, _semantic=_semantic) diff --git a/python/triton/experimental/gluon/language/amd/cdna4/__init__.py b/python/triton/experimental/gluon/language/amd/cdna4/__init__.py index edeb3506f5..b284b8a273 100644 --- a/python/triton/experimental/gluon/language/amd/cdna4/__init__.py +++ b/python/triton/experimental/gluon/language/amd/cdna4/__init__.py @@ -51,11 +51,55 @@ def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None) return ttgl.tensor(tensor.handle, ret_ty) +""" +buffer_atomic_rmw of cnda4 shares the same signature and functionalities as cdna3.buffer_atomic_rmw. +The cdna4 version additionally supports `fadd` with `bf16`. +""" + + @builtin -def buffer_atomic_rmw(op, ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): - """ - buffer_atomic_rmw of cnda4 shares the same signature and functionalities as cdna3.buffer_atomic_rmw. - The cdna4 version additionally supports `fadd` with `bf16`. 
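A short usage sketch of the one case CDNA4 adds over CDNA3, fadd on bf16 (hypothetical kernel, mirroring test_buffer_atomic_rmw_bf16 earlier in this patch; same imports as the CDNA3 sketch above):

@gluon.jit
def bf16_atomic_add_kernel(bf16_ptr, BLOCK: ttgl.constexpr):
    offsets = ttgl.arange(0, BLOCK, layout=ttgl.AutoLayout())
    val = ttgl.zeros([BLOCK], ttgl.bfloat16, layout=ttgl.AutoLayout())
    # bf16 add lowers to fadd and is only accepted for the "cdna4" arch;
    # the same call through ttgl.amd.cdna3 trips the bf16 assertion.
    ttgl.amd.cdna4.buffer_atomic_add(bf16_ptr, offsets, val)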
- """ - return _buffer_atomic_rmw_impl(op, ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, +def buffer_atomic_max(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + return _buffer_atomic_rmw_impl('max', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_min(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('min', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_add(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('add', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_and(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('and', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_or(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('or', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xor(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('xor', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xchg(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('xchg', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, _semantic=_semantic) diff --git a/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py b/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py index dc604bccec..d13c5fecac 100644 --- a/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py +++ b/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py @@ -4,8 +4,9 @@ from triton.experimental.gluon.language._semantic import _check from ..._layouts import DotOperandLayout from .._layouts import AMDWMMALayout +from . 
import tdm -__all__ = ["wmma", "wmma_scaled"] +__all__ = ["tdm", "wmma", "wmma_scaled"] @builtin @@ -58,8 +59,8 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None) "accumulator tensor's layout must be (16, 16, 128)" # TODO: Add more formats - assert a_format.value in {"e2m1"}, f"Unsupported lhs_format: {a_format.value}" - assert b_format.value in {"e2m1"}, f"Unsupported rhs_format: {b_format.value}" + assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}" + assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}" assert a_scale is not None and b_scale is not None, "Scales must not be None" diff --git a/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py b/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py new file mode 100644 index 0000000000..e50d1e25a3 --- /dev/null +++ b/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py @@ -0,0 +1,148 @@ +from __future__ import annotations +from typing import List, Tuple, TYPE_CHECKING +from dataclasses import dataclass + +import triton.experimental.gluon.language._core as ttgl +from triton.experimental.gluon.language._layouts import PaddedSharedLayout +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +if TYPE_CHECKING: + from triton._C import ir + from triton.experimental.gluon.language._core import shared_memory_descriptor + +__all__ = ["async_load", "async_wait", "make_tensor_descriptor", "tensor_descriptor", "tensor_descriptor_type"] + + +@dataclass(eq=True) +class tensor_descriptor_type(ttgl.base_type): + """The type for a tensor descriptor.""" + + block_type: ttgl.block_type + shape_type: ttgl.tuple_type + strides_type: ttgl.tuple_type + layout: PaddedSharedLayout + + def __str__(self) -> str: + return f"tensor_descriptor<{self.block_type}, {self.layout}>" + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + value = tensor_descriptor(handle, shape, strides, self) + return value, cursor + + def _to_ir(self, builder: ir.builder) -> ir.type: + is_signed = self.block_type.element_ty.is_int_signed() + return builder.get_tensor_descriptor_layout_type( + self.block_type.to_ir(builder), + is_signed, + self.layout._to_ir(builder), + ) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + out.append(self._to_ir(builder)) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}_{self.shape_type.mangle()}_{self.strides_type.mangle()}_{self.layout.mangle()}TD" + + +@dataclass +class tensor_descriptor(ttgl.base_value): + """A descriptor representing a tensor in global memory.""" + + handle: ir.value + shape: ttgl.tuple + strides: ttgl.tuple + type: tensor_descriptor_type + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + @property + def layout(self): + return self.type.layout + + +@builtin +def 
make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.tensor], + strides: List[ttgl.constexpr | ttgl.tensor], block_shape: List[ttgl.constexpr], + layout: PaddedSharedLayout, _semantic=None) -> tensor_descriptor: + """Make a tensor descriptor object. + + Args: + base (tensor): base pointer of the tensor in global memory. + shape (List[int]): shape of the tensor. + strides (List[int]): strides of the tensor. + block_shape (List[int]): block shape of the tensor. + layout (PaddedSharedLayout): the layout of the tensor in shared memory. + + Returns: + tensor_descriptor: the created tensor descriptor object + """ + ndim = len(shape) + # TODO: support 1D-5D tensor descriptors + assert ndim == 2, f"Expected 2 dimensions but got {ndim} dimensions" + assert len(strides) == ndim, f"Expected {ndim} strides but got {len(strides)}" + assert len(block_shape) == ndim, f"Expected block_shape to have {ndim} dimensions but got {len(strides)}" + assert isinstance(base.dtype, ttgl.pointer_type), "Expected base to be a pointer" + + layout = _unwrap_if_constexpr(layout) + assert isinstance(layout, PaddedSharedLayout), "Expected layout to be a PaddedSharedLayout" + + base_handle = base.handle + shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False) # i32 shape + stride_handles = _semantic._convert_to_ir_values(strides, require_i64=True) # i64 stride + + shape = ttgl.tuple(shape) + strides = ttgl.tuple(strides) + block_type = ttgl.block_type(base.type.element_ty, block_shape) + type = tensor_descriptor_type(block_type, shape.type, strides.type, layout) + + padding = _semantic._str_to_padding_option("zero") + handle = _semantic.builder.create_make_tensor_descriptor(type._to_ir(_semantic.builder), base_handle, shape_handles, + stride_handles, padding) + + return tensor_descriptor(handle, shape, strides, type) + + +@builtin +def async_load(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], dest: shared_memory_descriptor, + _semantic=None) -> None: + """Load a block of tensor specified in tensor descriptor from global memory to shared memory asynchronously. + + Args: + src (tensor_descriptor): the source tensor descriptor. + offsets (List[int]): the offsets from the base pointer in the tensor descriptor. + dest (shared_memory_descriptor): the shared memory destination to store the loaded data. + """ + offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False) + _semantic.builder.create_async_tdm_copy_global_to_local(src.handle, offset_handles, dest.handle) + + +@builtin +def async_wait(num_outstanding=0, _semantic=None) -> None: + """Wait for the outstanding asynchronous tensor operations to complete. + + Args: + num_outstanding (int): number of outstanding async tensor operations to wait for. + """ + num_outstanding = _unwrap_if_constexpr(num_outstanding) + _semantic.builder.create_async_tdm_wait(num_outstanding) diff --git a/python/triton/knobs.py b/python/triton/knobs.py index 0ae11f51b2..83c914c291 100644 --- a/python/triton/knobs.py +++ b/python/triton/knobs.py @@ -411,7 +411,9 @@ class compilation_knobs(base_knobs): disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO") front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING") allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS") - enable_experimental_consan: env_bool = env_bool("TRITON_ENABLE_EXPERIMENTAL_CONSAN") + # Instrumentation mode is checked on every run, which is expensive. 
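For context on this knob: with enable_experimental_consan removed, turning ConSan on from user code is now an env-var-plus-refresh sequence, matching what the tests above do with monkeypatch. A hypothetical standalone snippet:

import os
from triton import knobs

os.environ["TRITON_INSTRUMENTATION_MODE"] = "consan"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # surface device-side asserts at the failing launch
knobs.refresh_knobs()  # re-reads TRITON_INSTRUMENTATION_MODE into the cached knob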
+ # We cache the value here to avoid the expensive check on every run. + instrumentation_mode: str = env_str("TRITON_INSTRUMENTATION_MODE", "").get() listener: Union[CompilationListener, None] = None @@ -579,6 +581,7 @@ class amd_knobs(base_knobs): class proton_knobs(base_knobs): + disable: env_bool = env_bool("TRITON_PROTON_DISABLE", False) cupti_lib_dir: env_str = env_str( "TRITON_CUPTI_LIB_PATH", str(pathlib.Path(__file__).parent.absolute() / "backends" / "nvidia" / "lib" / "cupti")) @@ -600,3 +603,4 @@ class proton_knobs(base_knobs): def refresh_knobs(): runtime.debug = env_bool("TRITON_DEBUG").get() + compilation.instrumentation_mode = env_str("TRITON_INSTRUMENTATION_MODE", "").get() diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index fe468ea75d..ae286bfbac 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -63,11 +63,7 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di if cc is None: raise RuntimeError( "Failed to find C compiler. Please specify via CC environment variable or set triton.knobs.build.impl.") - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() # type: ignore + scheme = sysconfig.get_default_scheme() # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install # path changes to include 'local'. This change is required to use triton with system-wide python. if scheme == 'posix_local': diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 3efa9eea73..a0b5d43c69 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -650,6 +650,7 @@ def _pack_args(self, backend, kwargs, bound_args, specialization, options): def run(self, *args, grid, warmup, **kwargs): kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug + kwargs["instrumentation_mode"] = knobs.compilation.instrumentation_mode # parse options device = driver.active.get_current_device() diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index bc8cefdbf4..001627ba15 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -114,7 +114,7 @@ def _apply_padding_and_fill_unused_part_with_nan(t, is_padded): # --------------- -def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, expt_is_inner=False, device="cuda"): +def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, mode, n_expts_tot=1, expt_is_inner=False, device="cuda"): weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp # flexpoint make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) + @@ -133,8 +133,8 @@ def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_ex ) if weight_use_flexpoint else InFlexData(), out_data=OutFlexData( dtype=out_dtype, - expected_scale=make(4.00, 5.00, expt_is_inner), - actual_scale=make(0, 0, expt_is_inner), + expected_scale=make(4.00, 5.00, mode == "batched" or expt_is_inner), + actual_scale=make(0, 0, mode == "batched" or expt_is_inner), checksum_scale=None, ) if act_use_flexpoint else OutFlexData(), ) @@ -233,6 +233,7 @@ class Case: Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9), Case(16, 16, 1000, "batched", "float16", "float16", 5, 1, split_k=None), 
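The batched float8 split_k case added just below is the first to exercise the batched split-k scratchpad handled later in this patch: init_allocation now shapes it as (split_k, batch, M, N) and reduce_grouped flattens the batch into rows. A tiny shape sketch of that view trick (illustrative only; the real kernel also applies flexpoint scales and fused activations):

import torch

split_k, batch, M, N = 5, 6, 16, 32
x = torch.randn(split_k, batch, M, N)                              # batched split-k scratchpad
x2 = x.view(x.shape[0], 1, x.shape[1] * x.shape[2], x.shape[3])    # (K, 1, B*M, N), as in reduce_grouped
out = x2.sum(dim=0).view(batch, M, N)                              # reducing over split_k restores (B, M, N)
assert torch.allclose(out, x.sum(dim=0))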
Case(16, 16, 1000, "batched", "float8_e5m2", "float8_e5m2", 5, 1, split_k=None), + Case(16, 16, 2048, "batched", "float8_e5m2", "float8_e5m2", 6, 1, split_k=5), # mx types: Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1), Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), @@ -316,7 +317,7 @@ class Case: def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile, x_transpose, w_transpose, y_transpose, - device, opt_flags_scope, fresh_knobs): + device, opt_flags_scope): # TODO: remove when Triton FP8 supports proper RTNE if is_cuda(): if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: @@ -413,7 +414,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o weight_dtype = dtype_str_to_torch(weight_dtype_str) act_dtype = dtype_str_to_torch(act_dtype_str) precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, - n_expts_tot, expt_is_inner, device=device) + mode, n_expts_tot, expt_is_inner, device=device) # precision_opt.x_pad_trans_requires_flexpoint = False if mode == "ragged": m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, @@ -668,7 +669,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, else: rdata = gindx = sindx = None - precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot, device=device) + precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, mode, n_expts_tot, device=device) x, w, bias, _, _ = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, act_dtype, weight_dtype, False, requires_grad=False, device=device) diff --git a/python/triton_kernels/tests/test_mxfp.py b/python/triton_kernels/tests/test_mxfp.py index 354e4b1c50..db9402eb86 100644 --- a/python/triton_kernels/tests/test_mxfp.py +++ b/python/triton_kernels/tests/test_mxfp.py @@ -45,6 +45,22 @@ def test_mxfp4_rounding_cases(dst_dtype, device): assert_equal(dequant_torch, dequant) +@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"]) +@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"]) +def test_mxfp_extreme_values(src_dtype, dst_dtype, device): + if "float8" in src_dtype and (is_cuda() and torch.cuda.get_device_capability()[0] < 9): + pytest.skip("Float8 not tested on A100") + src_dtype = dtype_str_to_torch(src_dtype) + dst_dtype = dtype_str_to_torch(dst_dtype) + BIG_VALUE = 65470 if dst_dtype == torch.float16 else 3.3895e38 + x = torch.tensor([BIG_VALUE, BIG_VALUE], dtype=dst_dtype, device=device) + xq_value, xq_scale = downcast_to_mxfp(x, src_dtype, axis=-1) + xdq = upcast_from_mxfp(xq_value, xq_scale, dst_dtype, axis=-1) + xdq_ref = upcast_from_mxfp_torch(xq_value, xq_scale, dst_dtype, axis=-1) + assert_equal(xdq_ref, xdq) + assert not xdq.isinf().any() + + @pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"]) @pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"]) def test_mxfp_quant_dequant(src_dtype, dst_dtype, device): diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index a58b7811fb..a5fdf65a1d 100644 --- 
a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -52,6 +52,13 @@ class FnName(Enum): QUANTIZE_MXFP8 = auto() +@dataclass(frozen=True) +class FusedComm: + out_handles: torch.Tensor + scatter_shard_indx: torch.Tensor | None = None + reduce_rank: int = 0 + n_reduce_shards: int = 1 + EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated _kernels = dict() @@ -219,7 +226,8 @@ class MatmulAllocation: scratchpads: dict[str, tuple] def init_allocation(x, w, precision_config, fused_activation, - routing_data, gather_indx, scatter_indx, inner_routing_data, opt_flags): + routing_data, gather_indx, scatter_indx, inner_routing_data, + n_reduce_shards, opt_flags): # ---- output ------ N = w.shape[-1] # by default - M is number of rows in the activations @@ -233,6 +241,7 @@ def init_allocation(x, w, precision_config, fused_activation, else: Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows y_rows = Mc + y_rows *= n_reduce_shards if inner_routing_data is not None: batch_dim = inner_routing_data.base.n_expts_tot else: @@ -244,8 +253,9 @@ def init_allocation(x, w, precision_config, fused_activation, scratchpad = dict() if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter): scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype - scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), scratch_out_dtype) + scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N), scratch_out_dtype) if "matmul" in scratchpad and precision_config.out_scale is not None: + assert batch_dim == 1, "batch_dim > 1 not supported yet" scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8) return MatmulAllocation(x.device, output, scratchpad) @@ -323,11 +333,14 @@ def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_m Returns - The input tensor `x` (modified in place). """ + M = x.shape[2] # Only used for per-batch flex scale. if indx is None and x.shape[0] == 1: return x.squeeze(0), None if indx is not None: num_groups = indx.shape[0] else: + # Handle batched matmul (K, B, M, N) by pretending it to be (K, 1, B*M, N). + x = x.view(x.shape[0], 1, x.shape[1] * x.shape[2], x.shape[3]) num_groups = x.shape[-2] if x_flex is None: x_flex = InFlexData() @@ -351,8 +364,10 @@ def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_m x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3), # x_expected_scale, # scalar input scale out_flex.reinterpret(out), out.stride(1), out.stride(2), # - out_expected_scale, out_actual_scale, out_checksum_scale, indx, # - x.shape[0], x.shape[-1], # + out_expected_scale, out_actual_scale, out_checksum_scale, + out_flex is not None and out_flex.is_per_batch, + indx, + x.shape[0], M, x.shape[-1], # x_mx_scale, stride_mxb, stride_mxs, # out_mx_scale, stride_omxs, # *fused_activation.fn_args, fused_activation.reduction_n, @@ -383,6 +398,7 @@ def matmul_ogs(x, w, bias, gammas: torch.Tensor | None = None, out_alpha: float | None = None, y: torch.Tensor | None = None, + fused_comm: FusedComm | None = None, fused_activation: FusedActivation | None = None, epilogue: Epilogue | None = None, y_acc_in: torch.Tensor | None = None, @@ -392,6 +408,12 @@ def matmul_ogs(x, w, bias, Y[:, :] = 0. 
for e in num_experts: Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :]) + + matmul can be optionally fused with all gather or scatter at the end for the output. When fused_comm is specified, the m-th row of the output will be stored to (m * n_reduce_shards + reduce_rank) -th row + of each rank id in range [scatter_shard_indx[m] * n_reduce_shards, (scatter_shard_indx[m] + 1) * n_reduce_shards) if scatter_shard_indx is not None, otherwise the output will be all gathered across all reduce ranks. + When scatter_shard_indx is specified, the caller should ensure that the indices of different shards do not conflict. + + The output buffer for fused comm should be pre-allocated and passed in via fused_comm.out_handles, which contains ipc handles to the output tensors, each with shape (n_rows * n_reduce_shards, n_cols). """ is_input_batched = x.ndim == 3 if is_input_batched: @@ -399,6 +421,7 @@ def matmul_ogs(x, w, bias, assert scatter_indx is None, "scatter not supported in batched mode" assert routing_data is None, "routing not supported in batched mode" assert inner_routing_data is None, "routing not supported in batched mode" + assert fused_comm is None, "fused comm is not supported in batched mode" assert w.ndim == 3 and w.shape[0] == x.shape[0] if inner_routing_data is not None: assert routing_data is None @@ -509,7 +532,7 @@ def matmul_ogs(x, w, bias, matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation # allocate output/scratchpad memory allocation = init_allocation(x, w, precision_config, fused_activation, - routing_data, gather_indx, scatter_indx, inner_routing_data, opt_flags) + routing_data, gather_indx, scatter_indx, inner_routing_data, fused_comm.n_reduce_shards if fused_comm is not None else 1, opt_flags) memory = apply_allocation(allocation, y) # early exit if batch_size * M * N == 0: @@ -569,7 +592,7 @@ def matmul_ogs(x, w, bias, # create tma descriptor for y y_has_tma = ( opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter) - and (y_acc_in is None or y_acc_is_y) + and (y_acc_in is None or y_acc_is_y) and fused_comm is None ) block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n] @@ -597,7 +620,12 @@ def matmul_ogs(x, w, bias, # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs. 
# w_transpose = w_storage.data.stride()[-1] != 1 w_transpose = w_storage.data.stride()[-2] == 1 - + fused_comm_kwargs = { + "pYPtrs": fused_comm.out_handles, + "ScatterShardIndx": fused_comm.scatter_shard_indx, + "reduce_rank": fused_comm.reduce_rank, + "n_reduce_shards": fused_comm.n_reduce_shards, + } if fused_comm is not None else {} (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)]( y_tensor_or_tma, y_storage.data, *out_matmul.stride(), *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex), @@ -629,7 +657,7 @@ def matmul_ogs(x, w, bias, precision_config.allow_tf32, precision_config.flexpoint_saturate_inf, flex.rhs_data.is_per_batch, - flex.out_data.is_per_batch, + out_matmul_flex.is_per_batch, flex.acc_data.is_per_batch, opt_flags.block_m, opt_flags.block_n, @@ -652,6 +680,7 @@ def matmul_ogs(x, w, bias, SWAP_XW=get_swap_xw(precision_config, opt_flags), IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name, NUM_SMS = grid if opt_flags.is_persistent else 0, + **fused_comm_kwargs, **opt_flags.target_kernel_kwargs) # Build grouped reduction inputs in a uniform way group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act) @@ -764,3 +793,53 @@ def matmul_ogs_torch(x, w, bias, msk = dst_idx != -1 out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float() return out + + +def post_matmul_comm_torch(y: torch.Tensor, rank: int, n_reduce_shards: int, + world_size: int, + scatter_shard_indx: torch.Tensor | None = None, +): + """ + Reference implementation of post matmul communication. + + y: the local matmul output + rank: the global rank + n_reduce_shards: the number of reduce shards + world_size: the world size + scatter_shard_indx: the shard indices for the scatter. None if all gather. + + Output shape: + (batch_size, n_rows, n_cols) -> (batch_size, n_rows * n_reduce_shards, n_cols) if batched, otherwise + (n_rows, n_cols) -> (n_rows * n_reduce_shards, n_cols) + """ + from torch import distributed as dist + # if n_reduce_shards == 1: + # return y + + ys = [torch.empty_like(y) for _ in range(world_size)] + dist.all_gather(ys, y) + out_shape = (*y.shape[:-2], y.shape[-2] * n_reduce_shards, y.shape[-1]) + + if scatter_shard_indx is None: + # all gather + assert n_reduce_shards == world_size + return torch.cat(ys, dim=-1).reshape(out_shape) + else: + # Note: when multiple ranks scatter to the same destination, the result is undefined. 
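# Editor's note -- illustrative sketch only, not part of the patch: a worked example of the fused-comm
# row mapping described in the matmul_ogs docstring, assuming reduce_rank = rank % n_reduce_shards.
# With world_size = 4 and n_reduce_shards = 2, ranks {0, 1} form reduce shard 0 and ranks {2, 3} form
# reduce shard 1. A producer with reduce_rank = 1 that routes row m with scatter_shard_indx[m] = 0
# writes that row to row (m * n_reduce_shards + reduce_rank) = 2 * m + 1 on every rank in
# [0 * n_reduce_shards, 1 * n_reduce_shards) = {0, 1}. Destination row j satisfies j % n_reduce_shards
# == reduce_rank, so producers with different reduce_rank values never collide.
n_reduce_shards_ex, reduce_rank_ex, shard_ex = 2, 1, 0  # hypothetical values for the example
dst_row = lambda m: m * n_reduce_shards_ex + reduce_rank_ex  # e.g. source row 2 -> destination row 5
dst_ranks = range(shard_ex * n_reduce_shards_ex, (shard_ex + 1) * n_reduce_shards_ex)  # ranks 0 and 1
# The strided out_slice in the reference code below reproduces exactly this interleaving on the
# all-gathered copies.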
+ scatter_shard_indx_global = torch.empty((world_size, *scatter_shard_indx.shape), device=scatter_shard_indx.device, dtype=scatter_shard_indx.dtype) + dist.all_gather([scatter_shard_indx_global[i] for i in range(world_size)], scatter_shard_indx) + + assert len(out_shape) == 2, "batched mode not supported" + result = torch.zeros(out_shape, device=y.device, dtype=y.dtype) + reduce_shard_id = rank // n_reduce_shards + + for i in range(world_size // n_reduce_shards): + scatter_mask = scatter_shard_indx_global[i * n_reduce_shards, :] == reduce_shard_id + for j in range(n_reduce_shards): + out_slice = result.as_strided( + (result.shape[0] // n_reduce_shards, result.shape[1]), + (result.stride(0) * n_reduce_shards, result.stride(1)), + storage_offset=j * result.stride(0), + ) + out_slice[scatter_mask, :] = ys[i * n_reduce_shards + j][scatter_mask, :] + return result diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py index 4ec5ce648f..aa91aaba68 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py @@ -1,5 +1,4 @@ import torch - import triton import triton.language as tl @@ -55,10 +54,27 @@ def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr): @triton.jit -def _load_tile_attrs(tile_id, num_tiles, unpadded_m, grid_n, M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs, - EXPT_IS_INNER: tl.constexpr, X_IS_PADDED: tl.constexpr, W_IS_PADDED: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, PACKED_BLOCK_K_W: tl.constexpr, - SPLIT_K: tl.constexpr, GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr): +def _load_tile_attrs( + tile_id, + num_tiles, + unpadded_m, + grid_n, + M, + K, + ExptData, + ExptHist, + ExptOffs, + ExptTileOffs, + EXPT_IS_INNER: tl.constexpr, + X_IS_PADDED: tl.constexpr, + W_IS_PADDED: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + PACKED_BLOCK_K_W: tl.constexpr, + SPLIT_K: tl.constexpr, + GROUP_M: tl.constexpr, + XCD_SWIZZLE: tl.constexpr, +): # unpack and swizzle program ids pid_emnk = tile_id if XCD_SWIZZLE != 1: @@ -234,7 +250,7 @@ def matmul_launch_metadata(grid, kernel, args): # recreate inverse GatherIndx. 
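# Editor's note -- illustrative sketch only, not part of the patch: the launch-metadata code below
# rebuilds the inverse of GatherIndx with a masked scatter assignment. For a hypothetical
# gindx = [2, 0, -1, 1], the assignment dst[gindx[mask]] = idx[mask] yields dst = [1, 3, 0, -1]:
# dst[g] records which gathered row reads source row g, and -1 marks rows that are never read.
import torch
gindx_ex = torch.tensor([2, 0, -1, 1], dtype=torch.int32)
dst_ex = torch.full_like(gindx_ex, -1)
idx_ex = torch.arange(len(gindx_ex), dtype=torch.int32)
mask_ex = gindx_ex != -1
dst_ex[gindx_ex[mask_ex].long()] = idx_ex[mask_ex]
assert dst_ex.tolist() == [1, 3, 0, -1]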
dst = torch.full_like(gindx, -1) idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32) - mask = (gindx != -1) + mask = gindx != -1 dst[gindx[mask]] = idx[mask] n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum() else: @@ -252,3 +268,9 @@ def matmul_launch_metadata(grid, kernel, args): ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes) return ret + + +@triton.jit +def threadfence_system(): + tl.inline_asm_elementwise("mov.u32 $0, 0x0; fence.sc.sys;", args=(), dtype=(tl.int32, ), is_pure=False, pack=1, + constraints="=r") diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py index fe5fcfa008..ba634857eb 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -15,6 +15,7 @@ matmul_launch_metadata, swizzle2d, xcd_swizzle, + threadfence_system, ) @@ -94,6 +95,10 @@ def _matmul_ogs( UPCAST_INDICES: tl.constexpr = False, SWAP_XW: tl.constexpr = False, IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False, + pYPtrs=None, + ScatterShardIndx=None, + reduce_rank = 0, + n_reduce_shards: tl.constexpr = 1, ): tl.assume(stride_y_k >= 0) tl.assume(stride_y_z >= 0) @@ -366,6 +371,8 @@ def _matmul_ogs( if is_x_microscaled: XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k else: + # if w.dtype.is_fp8() and not x.dtype.is_fp8(): + # w = w.to(x.dtype) acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k @@ -457,4 +464,26 @@ def _matmul_ogs( out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF) if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8: out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty) - tl.store(YPtrs, out, mask=mask) + if pYPtrs is None: + tl.store(YPtrs, out, mask=mask) + else: + tl.static_assert(Y_TMA_MODE is None, "TMA is not supported with fused comms") + if ScatterShardIndx is not None: + dst_shard_idx = tl.load(ScatterShardIndx + offs_y_m, mask=mask_m) + for i in tl.static_range(n_reduce_shards): + peer = dst_shard_idx * n_reduce_shards + (reduce_rank + i) % n_reduce_shards + peer_Y_ptr = tl.load(pYPtrs + peer).to(tl.pointer_type(Y.type.element_ty)) + tl.multiple_of(peer_Y_ptr, 16) + offs_y_mn = offs_y_m.to(index_type)[:, None] * stride_y_m * n_reduce_shards + reduce_rank * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n + tl.store(peer_Y_ptr[:, None] + offs_y_mn, out, mask=mask) + else: + # full all gather + for i in tl.static_range(n_reduce_shards): + peer = (reduce_rank + i) % n_reduce_shards + peer_Y_ptr = tl.load(pYPtrs + peer).to(tl.pointer_type(Y.type.element_ty)) + tl.multiple_of(peer_Y_ptr, 16) + offs_y_mn = offs_y_m.to(index_type)[:, None] * stride_y_m * n_reduce_shards + reduce_rank * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n + tl.store(peer_Y_ptr + offs_y_mn, out, mask=mask) + + if pYPtrs is not None: + threadfence_system() diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py index 7b3f85c2d1..0dbbb60af6 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -19,6 +19,7 @@ 
make_matmul_repr, matmul_launch_metadata, swizzle2d, + threadfence_system, ) @@ -100,9 +101,13 @@ def _p_matmul_ogs( X_TMA_MODE: tl.constexpr, Y_TMA_MODE: tl.constexpr, TOKENS_PER_EXPT_FOR_ANNOTATION=None, - UPCAST_INDICES:tl.constexpr=False, + UPCAST_INDICES: tl.constexpr=False, SWAP_XW: tl.constexpr = False, IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False, + pYPtrs=None, + ScatterShardIndx=None, + reduce_rank=0, + n_reduce_shards: tl.constexpr = 1, ): # tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling") @@ -199,7 +204,7 @@ def _p_matmul_ogs( tile_id1 = tl.program_id(0) - NUM_SMS # Keep track of local max for updating flexpoint scales. - USE_LOCAL_ABSMAX: tl.constexpr = (YActualScale is not None) and (not PER_BATCH_OUT_SCALE) and (not is_out_microscaled) + USE_LOCAL_ABSMAX: tl.constexpr = (YActualScale is not None) and (not PER_BATCH_OUT_SCALE) and (not is_out_microscaled) and (pYPtrs is None) if USE_LOCAL_ABSMAX: THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads() local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32) @@ -306,7 +311,8 @@ def _p_matmul_ogs( mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) else: mask_k_scale = off_k_mx + tl.arange(0, MX_SCALE_BLOCK_K) < tl.cdiv(K, MX_PACK_DIVISOR) - x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :], other=0.0) + mask_m = off_m + tl.arange(0, BLOCK_M) < eM + x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :] & mask_m[:, None], other=0.0) elif x_format == "fp16" or x_format == "bf16": x_scales: tl.constexpr = None else: @@ -360,7 +366,7 @@ def _p_matmul_ogs( tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m)) split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m) offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m) - elif Y_TMA_MODE is None: + elif Y_TMA_MODE is None and pYPtrs is None: tl.static_assert(HAS_SCATTER) offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m) MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE @@ -512,31 +518,54 @@ def _p_matmul_ogs( out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i) out = out.to(YPtr.dtype.element_ty) - if USE_SCATTER_TMA: - # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that - # there shouldn't be any other negative values. - offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True) - Y.scatter(out, offs_y_m, out_off_n) - elif Y_TMA_MODE == "dense": - out = tl.reshape(out, [1] + out.shape) - off_kz = pid_k * batch_size + start_z1 - Y.store([off_kz, off_m1, out_off_n], out) - elif Y_TMA_MODE == "ragged": - out = tl.reshape(out, [1] + out.shape) - store_ragged(Y, start_m1, eM1, [pid_k, off_m1, out_off_n], out, ragged_dim=1) + + if pYPtrs is None: + if USE_SCATTER_TMA: + # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that + # there shouldn't be any other negative values. 
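# Editor's note -- illustrative only, not part of the patch: the "clearing the leading bit" comment
# above relies on two's-complement representation. A masked-out offset of -1 bitcast to uint32 is
# 0xFFFFFFFF, and 0xFFFFFFFF & 0x7FFFFFFF == 0x7FFFFFFF == 2**31 - 1 (INT32_MAX), while any valid
# non-negative offset already has bit 31 clear and passes through the mask unchanged.
assert (0xFFFFFFFF & 0x7FFFFFFF) == 2**31 - 1 == 2147483647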
+ offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True) + Y.scatter(out, offs_y_m, out_off_n) + elif Y_TMA_MODE == "dense": + out = tl.reshape(out, [1] + out.shape) + off_kz = pid_k * batch_size + start_z1 + Y.store([off_kz, off_m1, out_off_n], out) + elif Y_TMA_MODE == "ragged": + out = tl.reshape(out, [1] + out.shape) + store_ragged(Y, start_m1, eM1, [pid_k, off_m1, out_off_n], out, ragged_dim=1) + else: + tl.static_assert(Y_TMA_MODE is None) + offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N) + mask_n = offs_y_n < yN + mask = mask_m[:, None] & mask_n[None, :] + offs_kzmn = pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n + tl.store(YPtr + offs_kzmn, out, mask=mask) else: - tl.static_assert(Y_TMA_MODE is None) + tl.static_assert(Y_TMA_MODE is None, "TMA is not supported with fused comms") offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N) mask_n = offs_y_n < yN - - YPtrs = YPtr + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n mask = mask_m[:, None] & mask_n[None, :] - tl.store(YPtrs, out, mask=mask) + offs_kzmn = pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_n[None, :] * stride_y_n +offs_y_m.to(index_type)[:, None] * stride_y_m * n_reduce_shards + reduce_rank * stride_y_m + if ScatterShardIndx is not None: + dst_shard_idx = tl.load(ScatterShardIndx + offs_y_m, mask=mask_m) + for i in tl.static_range(n_reduce_shards): + peer = dst_shard_idx * n_reduce_shards + (reduce_rank + i) % n_reduce_shards + peer_Y_ptr = tl.load(pYPtrs + peer).to(tl.pointer_type(YPtr.type.element_ty)) + tl.multiple_of(peer_Y_ptr, 16) + tl.store(peer_Y_ptr[:, None] + offs_kzmn, out, mask=mask) + else: + # full all gather + for i in tl.static_range(n_reduce_shards): + peer = (reduce_rank + i) % n_reduce_shards + peer_Y_ptr = tl.load(pYPtrs + peer).to(tl.pointer_type(YPtr.type.element_ty)) + tl.multiple_of(peer_Y_ptr, 16) + tl.store(peer_Y_ptr + offs_kzmn, out, mask=mask) # Update the flexpoint scales if USE_LOCAL_ABSMAX: tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed") + if pYPtrs is not None: + threadfence_system() _per_device_alloc_fns = {} def get_per_device_per_stream_alloc_fn(device): diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py index 2d3b86646e..b6eac14b2f 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py @@ -9,7 +9,7 @@ def _reduce_grouped(X, stride_xb: tl.uint64, stride_xm: tl.uint64, stride_xn, # XScale, # input scalar flex scale Out, stride_om: tl.uint64, stride_on, # output tensor OutExpectedScale, OutActualScale, OutChecksumScale, # output scalar flex scales - InIndx, B, N, # + PER_BATCH_OUT_SCALE: tl.constexpr, InIndx, B, M, N, # XMxScale, stride_mxb: tl.uint64, stride_mxs: tl.uint64, # optional per-32-col output MXFP scales (uint8) OutMxScale, stride_omxs: tl.uint64, # optional per-32-col output MXFP scales (uint8) @@ -42,6 +42,12 @@ def _reduce_grouped(X, stride_xb: tl.uint64, stride_xm: tl.uint64, stride_xn, # XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn if HAS_OUT_MX_SCALE: OutScalePtrs = OutMxScale + 
tl.arange(0, BLOCK_N_OUT // 32) * stride_on + if PER_BATCH_OUT_SCALE: + out_batch_idx = pid_t // M + OutExpectedScale += out_batch_idx + OutActualScale += out_batch_idx + if OutChecksumScale is not None: + OutChecksumScale += out_batch_idx x_scale = load_scale(XScale) for n_curr in tl.range(0, N, BLOCK_N, num_stages=4): acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 67fb892ec6..718181cdea 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -171,9 +171,7 @@ def make_default_opt_flags_amd( ) is_persistent = constraints.get("is_persistent", False) # split_k: - if batch_size > 1: - split_k = 1 # currently not supported - elif constraints.get("split_k", None) is not None: + if constraints.get("split_k", None) is not None: split_k = constraints["split_k"] elif is_persistent or enforce_bitwise_invariance: split_k = 1 @@ -303,9 +301,7 @@ def make_default_opt_flags_nvidia( # TODO: swizzle the HBM layout of the weights instead block_n, block_k = block_k, block_n # split_k - if batch_size > 1: - split_k = 1 # currently not supported - elif constraints.get("split_k", None) is not None: + if constraints.get("split_k", None) is not None: split_k = constraints["split_k"] elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None: split_k = 1 diff --git a/python/triton_kernels/triton_kernels/numerics_details/mxfp.py b/python/triton_kernels/triton_kernels/numerics_details/mxfp.py index 8ad4ddb609..15d2251541 100644 --- a/python/triton_kernels/triton_kernels/numerics_details/mxfp.py +++ b/python/triton_kernels/triton_kernels/numerics_details/mxfp.py @@ -297,6 +297,17 @@ def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dty padded_tensor = padded_tensor.view(*new_shape) dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1] out_padded = padded_tensor * dq_scale_padded + # Need to clamp since due to rounding, we can have overflow that was within + # the range before quantization. 
+ # e.g., 3.3895e+38 -> log2(3.3895e+38 / max_fp8e4m3=448) ~= 119.17 -> round + # up to 120 + exp_bias=127 -> scale=247 + # 3.3895e+38 / 2**120 ~= 254.9976 -> round to 256 in fp8e4m3fn + # Dequantization: 256 * 2**120 > 3.4e38 overflowing 3.38953139e38 + finfo = torch.finfo(target_dtype) + out_padded = (padded_tensor * dq_scale_padded).clamp(finfo.min, finfo.max) + if tensor.dtype == torch.float8_e5m2: + # fp8e5m2 can have inf and we want to preserve so separately handle + out_padded = out_padded.where(~padded_tensor.isinf(), padded_tensor.to(target_dtype)) # Flatten back and remove the padded tail out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape) diff --git a/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py b/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py index 5e5f027fa9..3d284c79b0 100644 --- a/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py +++ b/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py @@ -119,6 +119,16 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_ scale = scale.reshape(dst_scale.shape) out_tensor = dst_tensor * dst_scale + if dst_dtype == tl.float32: + max_fin = 3.4028234663852886e+38 + elif dst_dtype == tl.bfloat16: + max_fin = 3.3895313892515355e+38 + else: + tl.static_assert(dst_dtype == tl.float16) + max_fin = 65504 + # TODO: handle infinity same as upcast_from_mxfp_torch together with the + # above FIXME + out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin) # Correct any NaNs encoded via the scale. out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor) out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) diff --git a/python/triton_kernels/triton_kernels/routing.py b/python/triton_kernels/triton_kernels/routing.py index e04a487dc1..59dd2d0317 100644 --- a/python/triton_kernels/triton_kernels/routing.py +++ b/python/triton_kernels/triton_kernels/routing.py @@ -128,16 +128,15 @@ def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix): expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device) combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device) gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device) - token_offs_combined = empty_aligned((block_m_num + 1, n_expts_tot + 1), torch.int32, device, MEMSET_BLOCK_A) - block_pid_map = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates_pad)), torch.int32, device, - MEMSET_BLOCK_A) + token_offs_combined, _ = empty_aligned((block_m_num + 1, n_expts_tot + 1), torch.int32, device, MEMSET_BLOCK_A) + block_pid_map, block_pid_map_n_elts = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates_pad)), + torch.int32, device, MEMSET_BLOCK_A) # slice padded allocations combine_indx = combined_indx[:n_gates_pad] dispatch_indx = combined_indx[n_gates_pad:] token_offs_raw, token_offs_pad = token_offs_combined[0], token_offs_combined[1:] # grid sizes - block_pid_map_n_elts = block_pid_map.untyped_storage().size() // block_pid_map.dtype.itemsize blocks1a = exact_div(block_pid_map_n_elts, MEMSET_BLOCK_A) + token_offs_combined.shape[0] blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1 blocks2a = n_expts_tot * token_offs_pad.shape[0] @@ -198,7 +197,7 @@ def empty_aligned(shape, dtype, device, pad_size): pad = lambda x: cdiv(x, pad_size) * pad_size ret = torch.empty((*shape[:-1], pad(shape[-1])), dtype=dtype, 
device=device) ret_slices = (*[slice(None)] * (len(shape) - 1), slice(0, shape[-1])) - return ret[ret_slices] + return ret[ret_slices], ret.numel() def max_n_tiles(n_expts_tot, n_gates): @@ -217,10 +216,11 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates): MEMSET_BLOCK = 512 dtype = torch.int32 device = expt_hist.device - token_offs_combined = empty_aligned((block_m_num + 1, n_expts_tot + 1), dtype, device, MEMSET_BLOCK) - block_pid_map = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates)), dtype, device, MEMSET_BLOCK) + token_offs_combined, _ = empty_aligned((block_m_num + 1, n_expts_tot + 1), dtype, device, MEMSET_BLOCK) + block_pid_map, block_pid_map_size = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates)), dtype, device, + MEMSET_BLOCK) token_offs_raw, token_offs_pad = token_offs_combined[0], token_offs_combined[1:] - n_memset_blocks = exact_div(block_pid_map.storage().size(), MEMSET_BLOCK) + n_memset_blocks = exact_div(block_pid_map_size, MEMSET_BLOCK) _expt_data_memset[(token_offs_combined.shape[0] + n_memset_blocks, )]( expt_hist, n_expts_tot, # diff --git a/setup.py b/setup.py index fcd858ae61..2742760041 100644 --- a/setup.py +++ b/setup.py @@ -367,8 +367,8 @@ def download_and_copy(name, src_func, dst_path, variable, version, url_func): with zipfile.ZipFile(file_bytes, "r") as file: file.extractall(path=tmp_path) else: - file = tarfile.open(fileobj=open_url(url), mode="r|*") - file.extractall(path=tmp_path) + with open_url(url) as url_file, tarfile.open(fileobj=url_file, mode="r|*") as tar_file: + tar_file.extractall(path=tmp_path, filter="data") os.makedirs(os.path.split(dst_path)[0], exist_ok=True) print(f'copy {src_path} to {dst_path} ...') if os.path.isdir(src_path): @@ -584,14 +584,15 @@ def download_and_copy_dependencies(): url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvdisasm/{system}-{arch}/cuda_nvdisasm-{system}-{arch}-{version}-archive.tar.xz", ) + crt = "crt" if int(NVIDIA_TOOLCHAIN_VERSION["cudacrt"].split(".")[0]) >= 13 else "nvcc" download_and_copy( name="nvcc", - src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/include", + src_func=lambda system, arch, version: f"cuda_{crt}-{system}-{arch}-{version}-archive/include", dst_path="include", variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda system, arch, version: - f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_{crt}/{system}-{arch}/cuda_{crt}-{system}-{arch}-{version}-archive.tar.xz", ) download_and_copy( name="cudart", diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index edb89d0819..a4a806a9b3 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -138,7 +138,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar tt.func public @async_commit_group(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: llvm.return ttg.async_commit_group tt.return diff --git a/test/Conversion/amd/ds_transpose.mlir b/test/Conversion/amd/ds_transpose.mlir 
index 84eb92c386..4853527273 100644 --- a/test/Conversion/amd/ds_transpose.mlir +++ b/test/Conversion/amd/ds_transpose.mlir @@ -5,8 +5,13 @@ #mma32_scaled = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32, 64], isTransposed = true}> #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#padding = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [128, 64]}> +#padding_vec1 = #ttg.padded_shared<[1:+4] {order = [0, 1], shape = [128, 64]}> #smem = #ttg.shared_memory +#linear_ds_tr_tile_out = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [32, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> +#linear_ds_tr_tile_invalid = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [4, 0], [2, 0], [8, 0], [32, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: ds_transpose_n_t_fp16_mfma_16 tt.func @ds_transpose_n_t_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { @@ -664,4 +669,46 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.store %ptr2, %2 : tensor<128x512x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>> tt.return } + + // CHECK-LABEL: ds_transpose_ll + tt.func @ds_transpose_ll(%arg0: !ttg.memdesc<64x16xbf16, #shared, #smem>, %arg1: !tt.ptr) { + // CHECK-COUNT-4: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xbf16> + // CHECK-NOT: rocdl.ds.read.tr16.b64 + %a1 = ttg.local_load %arg0 : !ttg.memdesc<64x16xbf16, #shared, #smem> -> tensor<64x16xbf16, #linear_ds_tr_tile_out> + + %ptr1 = tt.splat %arg1 : !tt.ptr -> tensor<64x16x!tt.ptr, #linear_ds_tr_tile_out> + tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr, #linear_ds_tr_tile_out> + tt.return + } + + // CHECK-LABEL: ds_transpose_ll_invalid + tt.func @ds_transpose_ll_invalid(%arg0: !ttg.memdesc<64x16xbf16, #shared, #smem>, %arg1: !tt.ptr) { + %a1 = ttg.local_load %arg0 : !ttg.memdesc<64x16xbf16, #shared, #smem> -> tensor<64x16xbf16, #linear_ds_tr_tile_invalid> + // CHECK-NOT: rocdl.ds.read.tr16.b64 + + %ptr1 = tt.splat %arg1 : !tt.ptr -> tensor<64x16x!tt.ptr, #linear_ds_tr_tile_invalid> + tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr, #linear_ds_tr_tile_invalid> + tt.return + } + + // CHECK-LABEL: ds_transpose_with_padding + tt.func @ds_transpose_with_padding(%arg0: !ttg.memdesc<128x64xf16, #padding, #smem, mutable>, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + // CHECK-NOT: rocdl.ds.read.tr16.b64 + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + + %ptr1 = tt.splat %arg2 : !tt.ptr -> tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + tt.store %ptr1, %1 : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_padding_interval_too_small + tt.func @ds_transpose_padding_interval_too_small(%arg0: !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable>, %arg2: !tt.ptr {tt.divisibility = 16 : i32, 
tt.pointer_range = 32 : i32}) { + // CHECK-NOT: rocdl.ds.read.tr16.b64 + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + + %ptr1 = tt.splat %arg2 : !tt.ptr -> tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + tt.store %ptr1, %1 : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + tt.return + } } diff --git a/test/Conversion/amd/minmax.mlir b/test/Conversion/amd/minmax.mlir index f0dbd10c60..691681c47f 100644 --- a/test/Conversion/amd/minmax.mlir +++ b/test/Conversion/amd/minmax.mlir @@ -12,7 +12,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // GFX942: llvm.intr.maxnum // GFX950: llvm.func @min_max -// GFX950-NEXT: llvm.intr.minimum +// GFX950: llvm.intr.minimum // GFX950-NEXT: llvm.intr.maximum tt.func public @min_max(%arg0: f32, %arg1: f32) { %0 = arith.minimumf %arg0, %arg1 : f32 diff --git a/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir new file mode 100644 index 0000000000..9dc27a6e16 --- /dev/null +++ b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir @@ -0,0 +1,25 @@ +// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=GFX1250 + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + // GFX1250-LABEL: tdm_kernel + tt.func public @tdm_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c_shape = arith.constant 128 : i32 + %c_stride0 = arith.constant 128 : i64 + %c_stride1 = arith.constant 1 : i64 + %c_offset = arith.constant 0 : i32 + %c_pred = arith.constant true + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + // GFX1250-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32> + // GFX1250-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32> + // GFX1250: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> () + %2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + // GFX1250: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> () + %3 = amdgpu.async_tdm_wait {num = 0 : i32} + %4 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir index 4f524ba472..febda9c140 100644 --- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir +++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir @@ -41,7 +41,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @tc_gen5_mma_multi_m_n - // CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32 + // CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %arg2{{.*}} : !llvm.ptr<3> to i32 // 
CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]] // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 64 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]] // 1048576 = row << 16 + col = 16 << 16 + 0 @@ -74,7 +74,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @tc_gen5_mma_multi_ctas - // CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32 + // CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %arg2{{.*}} : !llvm.ptr<3> to i32 // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]] // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 32 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]] // 1048576 = row << 16 + col = 16 << 16 + 0 @@ -210,7 +210,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: %[[TRUE:.+]] = llvm.mlir.constant(true) : i1 // CHECK: %[[DESC1:.+]] = llvm.mlir.constant(681579536 : i32) : i32 // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE]], %{{.+}}, %{{.+}}, %[[DESC1]], %{{.+}}, %{{.+}}, %[[TRUE]] - tt.func @tc_gen5_mma_block_scale(%a: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + tt.func @tc_gen5_mma_block_scale(%a: !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory>, %b: !ttg.memdesc<32x128xi8, #shared1, #ttg.shared_memory>, %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, %scale_a: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, @@ -220,7 +220,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, %barrierPred: i1) { ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e4m3 rhs = e2m1, %barrier[%barrierPred] {is_async} : - !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory>, !ttg.memdesc<32x128xi8, #shared1, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 04edfd28cb..b8db8d449b 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -141,13 +141,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: dot_reg_operand_upcast - tt.func @dot_reg_operand_upcast(%a_desc: !ttg.memdesc<128x64xi8, #shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared, #smem>, %acc: tensor<128x64xf32, #mma>) { + tt.func @dot_reg_operand_upcast(%a_desc: !ttg.memdesc<128x64xi8, 
#shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared1, #smem>, %acc: tensor<128x64xf32, #mma>) { %a_dotop = ttg.local_load %a_desc : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %res = ttng.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + %res = ttng.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> tt.return } } diff --git a/test/NVWS/assign_stage_phase.mlir b/test/NVWS/assign_stage_phase.mlir index cf610a71f7..02f2cd9c76 100644 --- a/test/NVWS/assign_stage_phase.mlir +++ b/test/NVWS/assign_stage_phase.mlir @@ -32,11 +32,11 @@ module attributes {"ttg.num-warps" = 4 : i32} { // CHECK: put.exit [[AREF]][[[S0b]]] nvws.aref.put.exit %1[%c0_i32], %token [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>, !ttg.async.token - // CHECK-NEXT: [[S1a:%.*]] = arith.addi [[S1]], [[C1]] {ttg.partition = array} - // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S1a]], [[C3]] {ttg.partition = array} - // CHECK-NEXT: [[S1b:%.*]] = arith.select [[CMP]], [[C0]], [[S1a]] {ttg.partition = array} - // CHECK-NEXT: [[P1a:%.*]] = arith.xori [[P1]], [[C1]] {ttg.partition = array} - // CHECK-NEXT: [[P1b:%.*]] = arith.select [[CMP]], [[P1a]], [[P1]] {ttg.partition = array} + // CHECK-NEXT: [[S1a:%.*]] = arith.addi [[S1]], [[C1]] {ttg.partition = array} + // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S1a]], [[C3]] {ttg.partition = array} + // CHECK-NEXT: [[S1b:%.*]] = arith.select [[CMP]], [[C0]], [[S1a]] {ttg.partition = array} + // CHECK-NEXT: [[P1a:%.*]] = arith.xori [[P1]], [[C1]] {ttg.partition = array} + // CHECK-NEXT: [[P1b:%.*]] = arith.select [[CMP]], [[P1a]], [[P1]] {ttg.partition = array} // CHECK-NEXT: {{.*}}, [[TOK1:%.*]] = nvws.aref.get.enter [[AREF]][[[S1b]], [[P1b]]] {ttg.partition = array} %buffers_0, %token_1 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {ttg.partition = array} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token %3 = ttg.local_load %buffers_0 {ttg.partition = array} : !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1> -> tensor<1xi32, #blocked> @@ -45,11 +45,11 @@ module attributes {"ttg.num-warps" = 4 : i32} { "op_b"(%3) {ttg.partition = array} : (tensor<1xi32, #blocked>) -> () // CHECK: op_b - // CHECK-NEXT: [[S2a:%.*]] = arith.addi [[S2]], [[C1]] {ttg.partition = array} - // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S2a]], [[C3]] {ttg.partition = array} - // CHECK-NEXT: [[S2b:%.*]] = arith.select [[CMP]], [[C0]], [[S2a]] {ttg.partition = array} - // CHECK-NEXT: [[P2a:%.*]] = arith.xori [[P2]], [[C1]] {ttg.partition = array} - // CHECK-NEXT: [[P2b:%.*]] = arith.select [[CMP]], [[P2a]], [[P2]] {ttg.partition = array} + // CHECK-NEXT: [[S2a:%.*]] = arith.addi [[S2]], [[C1]] {ttg.partition = array} + // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S2a]], [[C3]] {ttg.partition = array} + // CHECK-NEXT: [[S2b:%.*]] = arith.select [[CMP]], [[C0]], [[S2a]] {ttg.partition = array} + // CHECK-NEXT: [[P2a:%.*]] = arith.xori [[P2]], [[C1]] {ttg.partition = array} + // CHECK-NEXT: 
[[P2b:%.*]] = arith.select [[CMP]], [[P2a]], [[P2]] {ttg.partition = array} // CHECK-NEXT: {{.*}}, [[TOK2:%.*]] = nvws.aref.get.enter [[AREF]][[[S2b]], [[P2b]]] {ttg.partition = array} %buffers_2, %token_3 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {ttg.partition = array} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token %4 = ttg.local_load %buffers_2 {ttg.partition = array} : !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1> -> tensor<1xi32, #blocked> @@ -310,7 +310,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK: nvws.aref.buffer [[AREF]][[[SPUT]]], [[TOK1]] %9 = nvws.aref.buffer %0, %arg3 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> %10 = ttng.tc_gen5_mma %7, %8, %9[], %true, %true {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> - %11 = arith.cmpi eq, %arg2, %c0_i32 : i32 + %11 = arith.cmpi eq, %arg2, %c0_i32 {ttg.partition = array} : i32 // CHECK: [[RET_IF:%.*]]:5 = scf.if %12 = scf.if %11 -> (!ttg.async.token) { nvws.aref.put.exit %0, %arg3 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token @@ -337,3 +337,307 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { tt.return } } + + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + // CHECK-LABEL: @attention_forward + tt.func public @attention_forward(%arg0: !ttg.memdesc<256x64xf16, #shared, #smem>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: f32, %arg4: i32) { + %cst = arith.constant dense<1.000000e+00> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #blocked> + %cst_1 = arith.constant dense<0xFF800000> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %false = arith.constant false + %true = arith.constant true + %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable> + %0 = nvws.aref.create %result : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> + %buffers, %token = nvws.aref.put.enter %0 : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64>, !ttg.async.token + %result_2 = ttng.tmem_alloc : () -> !ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable> + %1 = nvws.aref.create %result_2 : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> + %buffers_3, %token_4 = nvws.aref.put.enter %1 : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf32, 
#tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.async.token + %2 = nvws.aref.buffer %1, %token_4 : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + %3 = ttng.tmem_store %cst_0, %2[], %true : tensor<256x64xf32, #blocked> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + %result_5 = ttng.tmem_alloc : () -> !ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable> + %4 = nvws.aref.create %result_5 : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]> + // CHECK: [[RET:%.*]]:16 = scf.for + %5:4 = scf.for %arg5 = %c0_i32 to %arg4 step %c64_i32 iter_args(%arg6 = %cst, %arg7 = %cst_1, %arg8 = %token, %arg9 = %token_4) -> (tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token) : i32 { + %7 = tt.descriptor_load %arg1[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x64xf16, #blocked1> + %8 = ttg.local_alloc %7 {ttg.partition = array} : (tensor<64x64xf16, #blocked1>) -> !ttg.memdesc<64x64xf16, #shared, #smem> + %9 = ttg.memdesc_trans %8 {order = array, ttg.partition = array} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared1, #smem> + %10 = nvws.aref.buffer %0, %arg8 {ttg.partition = array} : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64> + %11 = ttng.tc_gen5_mma %arg0, %9, %10[], %false, %true {ttg.partition = array} : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared1, #smem>, !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64> + nvws.aref.put.exit %0, %arg8 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %buffers_10, %token_11 = nvws.aref.get.enter %0 {ttg.partition = array} : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64>, !ttg.async.token + %12 = nvws.aref.buffer %0, %token_11 {ttg.partition = array} : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64> + %result_12, %token_13 = ttng.tmem_load %12[] {ttg.partition = array} : !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64> -> tensor<256x64xf32, #blocked> + nvws.aref.get.exit %0, %token_11 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %13 = "compute_row_max"(%result_12, %arg3) {ttg.partition = array} : (tensor<256x64xf32, #blocked>, f32) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %14 = "sub_row_max"(%result_12, %13, %arg3) {ttg.partition = array} : (tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, f32) -> tensor<256x64xf32, #blocked> + %15 = math.exp2 %14 {ttg.partition = array} : tensor<256x64xf32, #blocked> + %16 = arith.subf %arg7, %13 {ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = arith.subf %arg7, %13 {ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %18 = math.exp2 %16 {ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = 
#blocked}>> + %19 = math.exp2 %17 {ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %20 = "tt.reduce"(%15) <{axis = 1 : i32}> ({ + ^bb0(%arg10: f32, %arg11: f32): + %36 = arith.addf %arg10, %arg11 : f32 + tt.reduce.return %36 : f32 + }) {ttg.partition = array} : (tensor<256x64xf32, #blocked>) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %21 = arith.mulf %arg6, %19 {ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %22 = arith.addf %21, %20 {ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %23 = tt.expand_dims %18 {axis = 1 : i32, ttg.partition = array} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xf32, #blocked> + %24 = tt.broadcast %23 {ttg.partition = array} : tensor<256x1xf32, #blocked> -> tensor<256x64xf32, #blocked> + %25 = nvws.aref.buffer %1, %arg9 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + %result_14, %token_15 = ttng.tmem_load %25[] {ttg.partition = array} : !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> -> tensor<256x64xf32, #blocked> + %26 = arith.mulf %result_14, %24 {ttg.partition = array} : tensor<256x64xf32, #blocked> + %27 = tt.descriptor_load %arg2[%arg5, %c0_i32] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x64xf16, #blocked1> + %28 = ttg.local_alloc %27 {ttg.partition = array} : (tensor<64x64xf16, #blocked1>) -> !ttg.memdesc<64x64xf16, #shared, #smem> + %29 = arith.truncf %15 {ttg.partition = array} : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked> + %buffers_16, %token_17 = nvws.aref.put.enter %4 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.async.token + %30 = nvws.aref.buffer %4, %token_17 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + %31 = ttng.tmem_store %29, %30[%token_17], %true {ttg.partition = array} : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + nvws.aref.put.exit %4, %token_17 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %32 = ttng.tmem_store %26, %25[], %true {ttg.partition = array} : tensor<256x64xf32, #blocked> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + nvws.aref.put.exit %1, %arg9 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + // CHECK: tmem_store + // CHECK: tmem_store + // CHECK: arith.addi {{.*}} {ttg.partition = array} + // CHECK: arith.cmpi {{.*}} {ttg.partition = array} + // CHECK: [[S10:%.*]] = arith.select {{.*}} {ttg.partition = array} + // CHECK: arith.xori {{.*}} {ttg.partition = array} + // CHECK: [[P11:%.*]] = arith.select {{.*}} {ttg.partition = array} + %buffers_18, %token_19 = nvws.aref.get.enter %1 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.async.token + %33 = nvws.aref.buffer %1, %token_19 {ttg.partition = array} : 
<[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + %buffers_20, %token_21 = nvws.aref.get.enter %4 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.async.token + %34 = nvws.aref.buffer %4, %token_21 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + %35 = ttng.tc_gen5_mma %34, %28, %33[], %true, %true {ttg.partition = array} : !ttg.memdesc<256x64xf16, #tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.memdesc<64x64xf16, #shared, #smem>, !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + nvws.aref.get.exit %4, %token_21 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf16, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + nvws.aref.get.exit %1, %token_19 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + // CHECK: tc_gen5_mma {{.*}} %true, %true + // CHECK-NEXT: aref.get.exit {{.*}} {ttg.partition = array} + // CHECK-NEXT: aref.get.exit {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.addi {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.cmpi {{.*}} {ttg.partition = array} + // CHECK-NEXT: [[S4:%.*]] = arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.xori {{.*}} {ttg.partition = array} + // CHECK-NEXT: [[P0:%.*]] = arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: aref.put.enter {{.*}}[[[S4]], [[P0]]] {ttg.partition = array} + // CHECK-NEXT: arith.addi {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.cmpi {{.*}} {ttg.partition = array} + // CHECK-NEXT: [[S8:%.*]] = arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.xori {{.*}} {ttg.partition = array} + // CHECK-NEXT: [[P1:%.*]] = arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: aref.put.enter {{.*}}[[[S8]], [[P1]]] {ttg.partition = array} + %buffers_22, %token_23 = nvws.aref.put.enter %0 {ttg.partition = array} : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 2x256x64>, !ttg.async.token + %buffers_24, %token_25 = nvws.aref.put.enter %1 {ttg.partition = array} : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.async.token + // CHECK: scf.yield [[X0:%.*]], [[X1:%.*]], [[X2:%.*]], [[X3:%.*]], [[S4]], [[X5:%.*]], [[X6:%.*]], [[X7:%.*]], [[S8]], [[X9:%.*]], [[S10]], [[P11]] + scf.yield %22, %13, %token_23, %token_25 : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token + } {tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32, 1 : i32], ttg.warp_specialize.tag = 0 : i32} + // CHECK-NEXT: } {tt.warp_specialize + // CHECK-NEXT: aref.put.exit {{.*}}[[RET]]#8 + // CHECK-NEXT: aref.put.exit {{.*}}[[RET]]#4 + // CHECK-NEXT: arith.addi [[RET]]#10 + // CHECK-NEXT: arith.cmpi + // CHECK-NEXT: arith.select + // CHECK-NEXT: arith.xori [[RET]]#11 + nvws.aref.put.exit %1, %5#3 [#nvws.async_op] : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + 
nvws.aref.put.exit %0, %5#2 [#nvws.async_op] : <[!ttg.memdesc<2x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %buffers_6, %token_7 = nvws.aref.get.enter %1 : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64>, !ttg.async.token + %6 = nvws.aref.buffer %1, %token_7 : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> + %result_8, %token_9 = ttng.tmem_load %6[] : !ttg.memdesc<256x64xf32, #tmem, #ttng.tensor_memory, mutable, 1x256x64> -> tensor<256x64xf32, #blocked> + nvws.aref.get.exit %1, %token_7 [#nvws.async_op] : <[!ttg.memdesc<1x256x64xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + "use"(%5#0, %result_8, %5#1) : (tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> () + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[1, 0], [2, 0], [0, 32], [0, 64], [4, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 0], [0, 0]], block = []}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared3 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}> +#shared4 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +#tmem = #ttng.tensor_memory_encoding +#tmem1 = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + // CHECK-LABEL: @matmul_tma_acc_with_conditional_user + tt.func @matmul_tma_acc_with_conditional_user(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) { + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<1.000000e+00> : tensor<128x128xf32, #blocked> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %true = arith.constant true + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %0 = nvws.aref.create %result : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> + %buffers, %token = nvws.aref.put.enter %0 : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128>, !ttg.async.token + %1 = nvws.aref.buffer %0, %token : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> + %2 = ttng.tmem_store %cst_0, 
%1[], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> + %3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %token) -> (!ttg.async.token) : i32 { + %4:3 = "get_offsets"(%arg2) : (i32) -> (i32, i32, i32) + %5 = tt.descriptor_load %arg0[%4#0, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<128x64xf16, #blocked1> + %6 = tt.descriptor_load %arg1[%4#1, %4#2] {ttg.partition = array} : !tt.tensordesc> -> tensor<64x128xf16, #blocked1> + %7 = ttg.local_alloc %5 {ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %8 = ttg.local_alloc %6 {ttg.partition = array} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem> + %9 = nvws.aref.buffer %0, %arg3 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> + %10 = ttng.tc_gen5_mma %7, %8, %9[], %true, %true {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> + // CHECK: tc_gen5_mma + // CHECK-NEXT: arith.cmpi {{.*}} {ttg.partition = array} + // CHECK-NEXT: scf.if + %11 = arith.cmpi eq, %arg2, %c0_i32 {ttg.partition = array} : i32 + %12 = scf.if %11 -> (!ttg.async.token) { + nvws.aref.put.exit %0, %arg3 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %buffers_1, %token_2 = nvws.aref.get.enter %0 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128>, !ttg.async.token + %15 = nvws.aref.buffer %0, %token_2 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> + %result_3, %token_4 = ttng.tmem_load %15[] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> + nvws.aref.get.exit %0, %token_2 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + "acc_user"(%result_3) : (tensor<128x128xf32, #blocked>) -> () + %buffers_5, %token_6 = nvws.aref.put.enter %0 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128>, !ttg.async.token + scf.yield %token_6 : !ttg.async.token + } else { + scf.yield %arg3 : !ttg.async.token + } + %13 = nvws.aref.buffer %0, %12 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> + %14 = ttng.tmem_store %cst, %13[], %true {ttg.partition = array} : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> + scf.yield %12 : !ttg.async.token + } {tt.num_stages = 2 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 5 : i32} + nvws.aref.put.exit %0, %3 [#nvws.async_op] : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + tt.return + } +} + +// ----- +#blocked = 
#ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @matmul_tma_persistent_ws_kernel + tt.func public @matmul_tma_persistent_ws_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %false = arith.constant false + %true = arith.constant true + %c1_i64 = arith.constant 1 : i64 + %c128_i32 = arith.constant 128 : i32 + %c148_i32 = arith.constant 148 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %0 = arith.extsi %arg3 : i32 to i64 + %1 = tt.make_tensor_descriptor %arg0, [%arg6, %arg8], [%0, %c1_i64] : , > + %2 = arith.extsi %arg4 : i32 to i64 + %3 = tt.make_tensor_descriptor %arg1, [%arg7, %arg8], [%2, %c1_i64] : , > + %4 = arith.extsi %arg5 : i32 to i64 + %5 = tt.make_tensor_descriptor %arg2, [%arg6, %arg7], [%4, %c1_i64] : , > + %6 = tt.get_program_id x : i32 + %7 = arith.addi %arg6, %c127_i32 : i32 + %8 = arith.divsi %7, %c128_i32 : i32 + %9 = arith.addi %arg7, %c127_i32 : i32 + %10 = arith.divsi %9, %c128_i32 : i32 + %11 = arith.addi %arg8, %c127_i32 : i32 + %12 = arith.divsi %11, %c128_i32 : i32 + %13 = arith.muli %8, %10 : i32 + %14 = arith.muli %10, %c8_i32 : i32 + %15 = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable> + %16 = nvws.aref.create %15 : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> + %17 = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable> + %18 = nvws.aref.create %17 : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> + %result = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %19 = nvws.aref.create %result : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> + scf.for %arg9 = %6 to %13 step %c148_i32 : i32 { + %20 = arith.divsi %arg9, %14 {ttg.partition = array} : i32 + %21 = arith.muli %20, %c8_i32 {ttg.partition = array} : i32 + %22 = arith.subi %8, %21 {ttg.partition = array} : i32 + %23 = arith.minsi %22, %c8_i32 {ttg.partition = array} : i32 + %24 = arith.remsi %arg9, %23 {ttg.partition = array} : i32 + %25 = arith.addi %21, %24 {ttg.partition = array} : i32 + %26 = arith.remsi %arg9, %14 {ttg.partition = array} : i32 + %27 = arith.divsi %26, %23 {ttg.partition = array} : i32 + %28 = arith.muli %25, %c128_i32 {ttg.partition = array} : i32 + %29 = arith.muli %27, %c128_i32 {ttg.partition = array} : i32 + // CHECK: arith.addi {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.cmpi {{.*}} {ttg.partition = 
array} + // CHECK-NEXT: arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.xori {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: aref.put.enter {{.*}} {ttg.partition = array} + %buffers, %token = nvws.aref.put.enter %19 {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128>, !ttg.async.token + %30 = nvws.aref.buffer %19, %token {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> + %31 = ttng.tmem_store %cst, %30[], %true {ttg.partition = array} : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> + nvws.aref.put.exit %19, %token [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %buffers_0, %token_1 = nvws.aref.get.enter %19 {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128>, !ttg.async.token + // CHECK: tmem_store + // CHECK-NEXT: aref.put.exit + // CHECK-NEXT: arith.addi {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.cmpi {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.xori {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: aref.get.enter + // CHECK-NEXT: scf.for + %32 = scf.for %arg10 = %c0_i32 to %12 step %c1_i32 iter_args(%arg11 = %false) -> (i1) : i32 { + %36 = arith.muli %arg10, %c128_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : i32 + // CHECK-NEXT: arith.muli {{.*}} ttg.partition = array + // CHECK-NEXT: arith.addi {{.*}} ttg.partition = array + // CHECK-NEXT: arith.cmpi {{.*}} ttg.partition = array + // CHECK-NEXT: arith.select {{.*}} ttg.partition = array + // CHECK-NEXT: arith.xori {{.*}} ttg.partition = array + // CHECK-NEXT: arith.select {{.*}} ttg.partition = array + // CHECK-NEXT: aref.put.enter {{.*}} ttg.partition = array + %buffers_8, %token_9 = nvws.aref.put.enter %16 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128>, !ttg.async.token + nvws.descriptor_load %1[%28, %36] 16384 %buffers_8 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128> + nvws.aref.put.exit %16, %token_9 [#nvws.async_op] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]>, !ttg.async.token + // CHECK: aref.put.exit {{.*}} ttg.partition = array + // CHECK-NEXT: arith.addi {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.cmpi {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.xori {{.*}} {ttg.partition = array} + // CHECK-NEXT: arith.select {{.*}} {ttg.partition = array} + // CHECK-NEXT: aref.get.enter {{.*}} {ttg.partition = array} + + // CHECK-NOT: partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> 
!ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, 1x128x128>, !ttg.async.token + %buffers_12, %token_13 = nvws.aref.put.enter %18 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128>, !ttg.async.token + nvws.descriptor_load %3[%29, %36] 16384 %buffers_12 {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, mutable, 1x128x128> + nvws.aref.put.exit %18, %token_13 [#nvws.async_op] {loop.cluster = 2 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]>, !ttg.async.token + %buffers_14, %token_15 = nvws.aref.get.enter %18 {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]> -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, 1x128x128>, !ttg.async.token + %37 = ttg.memdesc_trans %buffers_14 {loop.cluster = 0 : i32, loop.stage = 2 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, 1x128x128> -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem, 1x128x128> + %38 = nvws.aref.buffer %19, %token_1 {loop.cluster = 0 : i32, loop.stage = 2 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> + %39 = ttng.tc_gen5_mma %buffers_10, %37, %38[], %arg11, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem, 1x128x128>, !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem, 1x128x128>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> + nvws.aref.get.exit %18, %token_15 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]>, !ttg.async.token + nvws.aref.get.exit %16, %token_11 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf8E4M3FN, #shared, #smem, mutable>]>, !ttg.async.token + // CHECK: scf.yield + scf.yield %true : i1 + } {tt.scheduled_max_stage = 2 : i32} + nvws.aref.get.exit %19, %token_1 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %buffers_2, %token_3 = nvws.aref.put.enter %19 {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128>, !ttg.async.token + %33 = nvws.aref.buffer %19, %token_3 {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> + %result_4, %token_5 = ttng.tmem_load %33[] {ttg.partition = array} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> -> tensor<128x128xf32, #blocked> + nvws.aref.put.exit %19, %token_3 [#nvws.async_op] {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %buffers_6, %token_7 = nvws.aref.get.enter %19 {ttg.partition = array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128>, !ttg.async.token + nvws.aref.get.exit %19, %token_7 [#nvws.async_op] {ttg.partition = 
array} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token + %34 = tt.fp_to_fp %result_4 {ttg.partition = array}, rounding = rtne : tensor<128x128xf32, #blocked> -> tensor<128x128xf8E4M3FN, #blocked> + %35 = ttg.convert_layout %34 {ttg.partition = array} : tensor<128x128xf8E4M3FN, #blocked> -> tensor<128x128xf8E4M3FN, #blocked1> + tt.descriptor_store %5[%28, %29], %35 {ttg.partition = array} : !tt.tensordesc>, tensor<128x128xf8E4M3FN, #blocked1> + } {tt.num_stages = 3 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32} + tt.return + } +} diff --git a/test/NVWS/lower_aref.mlir b/test/NVWS/lower_aref.mlir index 1da74da187..58d7c1c7c9 100644 --- a/test/NVWS/lower_aref.mlir +++ b/test/NVWS/lower_aref.mlir @@ -386,7 +386,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { // CHECK-NEXT: tc_gen5_mma {{.*}}, {{.*}}, [[VIEW]][] %9 = nvws.aref.buffer %0, %arg3 {ttg.partition = array} : <[!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> %10 = ttng.tc_gen5_mma %7, %8, %9[], %true, %true {ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> - %11 = arith.cmpi eq, %arg2, %c0_i32 : i32 + %11 = arith.cmpi eq, %arg2, %c0_i32 {ttg.partition = array} : i32 // CHECK: [[RET_IF:%.*]]:4 = scf.if %12 = scf.if %11 -> (!ttg.async.token) { // CHECK: tc_gen5_commit diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir index d1f45008d5..a016d39259 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir @@ -1,6 +1,6 @@ // NOTE: Assertions have been autogenerated by mlir/utils/generate-test-checks.py -// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers -verify-diagnostics | FileCheck %s +// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers="enable-large-tensor-ptr-canon=true" -verify-diagnostics | FileCheck %s module attributes {"ttg.num-warps" = 4 : i32} { tt.func @ifOpTwoYields(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>) { diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers-no-large-tensor.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers-no-large-tensor.mlir new file mode 100644 index 0000000000..6cd080d170 --- /dev/null +++ b/test/TritonGPU/amd/amd-canonicalize-pointers-no-large-tensor.mlir @@ -0,0 +1,20 @@ +// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers="enable-large-tensor-ptr-canon=false" -canonicalize -verify-diagnostics | FileCheck %s + +// this case is copied from amd-canonicalize-pointers-no-large-tensor.mlir. With +// enable-large-tensor-ptr-canon=false, the input is not changed at all. 
+module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion1(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.splat %1 : i32 -> tensor<1024xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %4 = tt.addptr %3, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %5 = tt.load %4 : tensor<1024x!tt.ptr> + tt.return %5 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @conversion1 +// CHECK: %[[ADDPTR:.*]] = tt.addptr +// CHECK: = tt.load %[[ADDPTR]] diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index a165b609a1..21000e2f04 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -1,6 +1,6 @@ // NOTE: Assertions have been autogenerated by mlir/utils/generate-test-checks.py -// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize -verify-diagnostics | FileCheck %s +// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers="enable-large-tensor-ptr-canon=true" -canonicalize -verify-diagnostics | FileCheck %s module attributes {"ttg.num-warps" = 4 : i32} { tt.func @conversion1(%arg0: !tt.ptr) -> tensor<1024xf32> { diff --git a/test/TritonGPU/amd/amd-conditional-barrier.mlir b/test/TritonGPU/amd/amd-conditional-barrier.mlir index bdd32a48d3..a55c004c7a 100644 --- a/test/TritonGPU/amd/amd-conditional-barrier.mlir +++ b/test/TritonGPU/amd/amd-conditional-barrier.mlir @@ -4,8 +4,8 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, tt.func @conditional_barrier() { // CHECK-LABEL: llvm.func @conditional_barrier - // CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %3, %1 : i32 - // CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %3, %1 : i32 + // CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %[[OP0:.+]], %[[OP1:.+]] : i32 + // CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %[[OP0]], %[[OP1]] : i32 // CHECK: llvm.cond_br %[[CMP0]], ^bb1, ^bb2 // CHECK: ^bb1: // CHECK: rocdl.s.barrier diff --git a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir index 3eab4f1252..9a967924e4 100644 --- a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir +++ b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=4 use_async_copy=1" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=4" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32, 16], isTransposed = true}> diff --git a/test/TritonGPU/automatic-warp-specialization.mlir b/test/TritonGPU/automatic-warp-specialization.mlir index 5d43c30a68..7cf4e92424 100644 --- a/test/TritonGPU/automatic-warp-specialization.mlir +++ b/test/TritonGPU/automatic-warp-specialization.mlir @@ -194,6 +194,8 @@ tt.func public @attention_forward( %one = arith.constant dense<1.0> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> // CHECK-LABEL: ttg.warp_specialize + // CHECK-LABEL: default + // CHECK: ttng.fence_async_shared // PIPELINE: partition0 // PIPELINE-COUNT-4: 
ttng.tc_gen5_mma // PIPELINE-NOT: ttng.tc_gen5_mma diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 10fba51712..61aa2d326a 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -124,6 +124,26 @@ tt.func @test_canonicalize_convert_local_load(%arg0: !ttg.async.token) -> tensor // ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}> +#tmem = #ttng.tensor_memory_encoding +// CHECK-LABEL: test_canonicalize_convert_tmem_store +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func @test_canonicalize_convert_tmem_store( + %arg0: tensor<128x64xbf16, #linear>, + %arg1: !ttg.memdesc<128x64xbf16, #tmem, #ttng.tensor_memory, mutable> + ) { + %true = arith.constant true + // CHECK-NOT: ttg.convert_layout + %1 = ttg.convert_layout %arg0 : tensor<128x64xbf16, #linear> -> tensor<128x64xbf16, #blocked> + // CHECK: ttng.tmem_store %{{.*}} : tensor<128x64xbf16, #linear> -> + ttng.tmem_store %1, %arg1, %true : tensor<128x64xbf16, #blocked> -> !ttg.memdesc<128x64xbf16, #tmem, #ttng.tensor_memory, mutable> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared = #ttg.swizzled_shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #smem = #ttg.shared_memory diff --git a/test/TritonGPU/loop-pipeline-combine-waits.mlir b/test/TritonGPU/loop-pipeline-combine-waits.mlir index caa1a91ad8..31fbcf9e9f 100644 --- a/test/TritonGPU/loop-pipeline-combine-waits.mlir +++ b/test/TritonGPU/loop-pipeline-combine-waits.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=3 use_async_copy=1 use_pingpong=1" -tritonamdgpu-pipeline="use_async_copy=1" | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=3" -tritonamdgpu-pipeline="use_async_copy=1 use_pingpong=1" | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 7f6ab22305..6952899735 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -1,5 +1,5 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops=num_stages=2 -tritonamdgpu-pipeline -canonicalize | FileCheck %s --check-prefixes=COMMON,SYNC -// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=2 use_async_copy=1" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC +// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=2" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], 
threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> diff --git a/test/lib/Proton/CMakeLists.txt b/test/lib/Proton/CMakeLists.txt index a16e6a6ad3..94d090f478 100644 --- a/test/lib/Proton/CMakeLists.txt +++ b/test/lib/Proton/CMakeLists.txt @@ -3,4 +3,5 @@ add_mlir_library(TritonTestProton LINK_LIBS PUBLIC MLIRPass + ProtonAnalysis ) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 464c4f77dd..e7c5a6674d 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -224,8 +224,8 @@ def make_ttgir(mod, metadata, options): use_async_copy = knobs.amd.use_async_copy use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy) - amd.passes.ttgpuir.add_schedule_loops(pm, options.num_stages, use_async_copy, use_block_pingpong) - amd.passes.ttgpuir.add_pipeline(pm, use_async_copy) + amd.passes.ttgpuir.add_schedule_loops(pm, options.num_stages) + amd.passes.ttgpuir.add_pipeline(pm, use_async_copy, use_block_pingpong) if use_async_copy: amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch) passes.common.add_canonicalizer(pm) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index ddf46ac8cc..91431f2b5c 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -34,6 +34,10 @@ class TritonAMDGPU_Attr traits = [], : AttrDef { } +def SetFP8Clamping : TritonAMDGPU_Attr<"SetFP8Clamping"> { + let mnemonic = "amdgcn.set.fp8.clamping"; +} + class TritonAMDGPU_I32Enum cases> : I32EnumAttr { let genSpecializedAttr = 0; diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 67041ff2db..04cc779c7e 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -696,4 +696,51 @@ def LocalLoadPackedTransposedOp : TT_AMDGPU_Op<"local_load_packed_tranposed", [L let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// AsyncTDMCopyGlobalToLocalOp +//===----------------------------------------------------------------------===// + +def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local"> { + let summary = "Copy data based on descriptor from global memory to local memory asynchronously"; + + let description = [{ + This operation copies data from global memory to local memory + asynchronously. It is analogous to tt.load, except that the data is copied + into the local memory pointed to by `result` instead of a distributed + tensor. The data copied depends on the global memory descriptor pointed to + by `desc`. Setting `pred` to false disables the copy.
+ }]; + + let arguments = (ins + Arg]>:$desc, + Variadic:$indices, + Arg]>:$result, + I1:$pred + ); + + let results = (outs TTG_AsyncToken:$token); + + let assemblyFormat = [{ + $desc `[` $indices `]` `into` $result `,` $pred + attr-dict `:` qualified(type($desc)) `->` qualified(type($result)) + }]; +} + +//===----------------------------------------------------------------------===// +// AsyncTDMWait +//===----------------------------------------------------------------------===// + +def AsyncTDMWait : TT_AMDGPU_Op<"async_tdm_wait"> { + let summary = "Wait until there are less than or equal to the given number of outstanding TDM operations"; + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + let description = [{ + This operation waits until there are less than or equal to the given number + of outstanding TDM operations, including both loads and stores. This is + necessary to ensure that data is available in the LDS before it is used. + }]; + let results = (outs TTG_AsyncToken:$retToken); + + let assemblyFormat = "$asyncToken attr-dict"; +} + #endif diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index c6b0fcc313..bd9fc77e31 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -16,13 +16,7 @@ def TritonAMDGPUScheduleLoops : Pass<"tritonamdgpu-schedule-loops", "mlir::Modul let options = [ Option<"numStages", "num_stages", "int32_t", /*default*/"2", - "Number of Pipeline stages">, - Option<"useAsyncCopy", "use_async_copy", - "bool", /*default*/"false", - "Use AsyncCopyGlobalToLocal to directly load to shared memory">, - Option<"usePingpong", "use_pingpong", - "bool", /*default*/"false", - "Use schedules to enable block ping-pong">, + "Number of Pipeline stages"> ]; } @@ -38,6 +32,9 @@ def TritonAMDGPUPipeline : Pass<"tritonamdgpu-pipeline", "mlir::ModuleOp"> { Option<"useAsyncCopy", "use_async_copy", "bool", /*default*/"false", "Use AsyncCopyGlobalToLocal to directly load to shared memory">, + Option<"usePingpong", "use_pingpong", + "bool", /*default*/"false", + "Use schedules to enable block ping-pong"> ]; } @@ -126,6 +123,11 @@ def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers" let dependentDialects = []; + let options = [ + Option<"enableLargeTensorPtrCanon", "enable-large-tensor-ptr-canon", + "bool", /*default=*/"false", + "Whether to enable canonicalization for pointers pointing to large-tensors (a specialization for tensors over 2GB)"> + ]; } def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp index f92f51c788..fe2142fb88 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp @@ -3,6 +3,7 @@ #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TargetInfo.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" namespace mlir::triton::AMD { namespace { @@ -50,21 +51,20 @@ bool comesFromAsyncWait(Value token) { } // namespace void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) { - SmallVector localLoads; - mod->walk([&](triton::gpu::LocalLoadOp localLoadOp) { - localLoads.emplace_back(localLoadOp); - }); - auto *ctx = mod->getContext(); - for (auto &loadOp : localLoads) { - auto token = loadOp.getToken(); 
- if (loadOp->hasAttr(syncedViaAsyncWaitAttrName)) - continue; - - bool isSyncedViaAsyncWait = token && comesFromAsyncWait(token); - loadOp->setAttr(syncedViaAsyncWaitAttrName, - BoolAttr::get(ctx, isSyncedViaAsyncWait)); - } + + mod->walk([&](Operation *op) { + TypeSwitch(op) + .Case([&](auto loadOp) { + if (loadOp->hasAttr(syncedViaAsyncWaitAttrName)) + return; + Value token = loadOp.getToken(); + bool isSyncedViaAsyncWait = token && comesFromAsyncWait(token); + loadOp->setAttr(syncedViaAsyncWaitAttrName, + BoolAttr::get(ctx, isSyncedViaAsyncWait)); + }); + }); } bool isSyncedViaAsyncWait(Operation *op) { @@ -112,8 +112,10 @@ void addAsyncCopyAliasScope(LLVM::AliasAnalysisOpInterface directToLdsOp) { directToLdsOp.setAliasScopes(b.getArrayAttr(getAsyncCopyScope(ctx))); } -void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp, +void addLocalLoadNoAliasScope(Operation *localLoadOp, LLVM::AliasAnalysisOpInterface llLoadOp) { + if (!localLoadOp->hasTrait()) + return; if (!isSyncedViaAsyncWait(localLoadOp)) return; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h index 32d0197886..e03d35bd91 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h @@ -34,7 +34,7 @@ bool isSyncedViaAsyncWait(Operation *localLoadOp); // If localLoadOp has a token from an AsyncWait: // - Attaches "amdgpu.LocalLoad" alias scope to llLoadOp // - Attaches "amdgpu.AsyncCopies" as *non* alias scope to llLoadOp -void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp, +void addLocalLoadNoAliasScope(Operation *localLoadOp, LLVM::AliasAnalysisOpInterface llLoadOp); // Overload from above without checking the AsyncToken void addLocalLoadNoAliasScope(LLVM::AliasAnalysisOpInterface llLoadOp); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index edee537e68..32534d049b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonAMDGPUToLLVM AtomicRMWOpsEmitter.cpp AllocateSharedMemory.cpp BufferOpsEmitter.cpp + TensorPtrOpsToLLVM.cpp ConvertLayoutOpToLLVM.cpp MemoryOpToLLVM.cpp MaskedOpsToLLVM.cpp diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 93fc09e371..cc03182dc1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1,3 +1,4 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -237,16 +238,6 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter, assert(v.size() == 4); auto b = TritonLLVMOpBuilder(loc, rewriter); - // This is the location of the fp16_ovfl flag in the Mode register. It's - // calculated following this formula: - // (mode register ID = 1) | (Offset << 6) | ((Width - 1) << 11) - // In this case, Offset = 23 and Width = 1. - // When the bit is 0/1, the conversion from fp32/fp16/bf16 to fp8/bf8 is in - // non-saturation/saturation mode. 
- Value fp16OVFLModeRegLoc = b.i32_val(1473); - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.setreg", {}, - {fp16OVFLModeRegLoc, b.i32_val(1)}); - Type v2I16Ty = vec_ty(i16_ty, 2); Value v2I16Vec = b.undef(v2I16Ty); Value scale = b.f32_val(1); @@ -1855,6 +1846,17 @@ struct FpToFpOpConversion } } + if (dstType.isFloat() && (dstType.getIntOrFloatBitWidth() == 8)) { + auto func = op->getParentOfType(); + if (func) { + using attrType = triton::amdgpu::SetFP8ClampingAttr; + auto attrName = attrType::getMnemonic(); + if (!func->hasAttrOfType(attrName)) { + func->setAttr(attrName, attrType::get(op->getContext())); + } + } + } + inVals.resize(numElements, b.undef(typeConverter->convertType(srcType))); SmallVector outVals; if (srcType != dstType) { @@ -2323,10 +2325,41 @@ struct PreciseSqrtOpConversion private: bool ftz; }; - } // namespace namespace mlir::triton::AMD { +void adjustModeRegister(ModuleOp mod, const TargetInfo &targetInfo) { + MLIRContext *ctx = mod->getContext(); + Location loc = mod->getLoc(); + mlir::OpBuilder builder(ctx); + auto auxBuilder = TritonLLVMOpBuilder(loc, builder); + + mod->walk([&](LLVM::LLVMFuncOp func) { + using attrType = triton::amdgpu::SetFP8ClampingAttr; + auto attrName = attrType::getMnemonic(); + if (!func->hasAttrOfType(attrName)) + return; + else + func->removeAttr(attrName); + + if (func.getBody().empty()) + return; + auto &body = func.getBody().front(); + builder.setInsertionPoint(&body.front()); + + // This is the location of the fp16_ovfl flag in the Mode register. It's + // calculated following this formula: + // (mode register ID = 1) | (Offset << 6) | ((Width - 1) << 11) + // In this case, Offset = 23 and Width = 1. + // When the bit is 0/1, the conversion from fp32/fp16/bf16 to fp8/bf8 is + // in non-saturation/saturation mode. 
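The constant emitted on the next line can be re-derived from the formula in the comment. A minimal sanity check, sketched in plain C++ and assuming id = 1, offset = 23, width = 1 as stated above:

// hwreg selector = id | (offset << 6) | ((width - 1) << 11)
static_assert((1 | (23 << 6) | ((1 - 1) << 11)) == 1473,
              "fp16_ovfl selector passed to llvm.amdgcn.s.setreg encodes to 1473");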
+ Value fp16OVFLModeRegLoc = auxBuilder.i32_val(1473); + LLVM::createLLVMIntrinsicCallOp( + builder, loc, "llvm.amdgcn.s.setreg", {}, + {fp16OVFLModeRegLoc, auxBuilder.i32_val(1)}); + }); +} + void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 782606c888..c1e9a9e53c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -996,6 +996,214 @@ struct AsyncCopyGlobalToLocalOpConversion } }; +struct AsyncTDMCopyGlobalToLocalOpConversion + : public ConvertOpToLLVMPattern< + triton::amdgpu::AsyncTDMCopyGlobalToLocalOp>, + public LoadStoreConversionBase { + AsyncTDMCopyGlobalToLocalOpConversion( + LLVMTypeConverter &converter, const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + std::pair createTDMDescriptors( + RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, int64_t elementSizeInBytes, + ArrayRef tensorShape, ArrayRef blockShape, + ArrayRef tensorStride, Value srcPtr, Value dstPtr, Value pred, + Value multicastMask, unsigned padIntervalInDwords, + unsigned padAmountInDwords) const { + assert(tensorShape.size() == 2 && tensorStride.size() == 2 && + blockShape.size() == 2); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Value ldsAddr = b.ptrtoint(i32_ty, dstPtr); + Value globalAddr = b.ptrtoint(i64_ty, srcPtr); + + // group0 (128 bits / 4 dwords) effective bit encoding: + // [1:0]: pred + // [63:32]: lds address + // [120:64]: global address + // [127:126]: type - currently always set to 0x2 + SmallVector group0(4, b.i32_val(0)); + group0[0] = b.zext(i32_ty, pred); + group0[1] = ldsAddr; + group0[2] = b.trunc(i32_ty, globalAddr); + group0[3] = b.trunc(i32_ty, b.lshr(globalAddr, b.i64_val(32))); + group0[3] = b.or_(group0[3], b.i32_val(0x80000000)); + + VectorType vecTy0 = vec_ty(i32_ty, 4); + Value group0Vec = b.undef(vecTy0); + for (unsigned ii = 0; ii < 4; ++ii) { + Value vecIdx = createIndexAttrConstant(rewriter, loc, + typeConverter->getIndexType(), ii); + group0Vec = b.insert_element(vecTy0, group0Vec, group0[ii], vecIdx); + } + + // group1 (256 bits / 8 dwords) effective bit encoding: + // [15:0]: multicast mask + // [17:16]: data size - log2(element size in bytes) + // [20]: enable padding + // [24:22]: pad interval - log2(pad interval in dwords) - 1 + // [31:25]: pad amount - pad amount in dwords - 1 + // [79:48]: tensor shape dim inner + // [111:80]: tensor shape dim outer + // [127:112]: block shape dim inner + // [143:128]: block shape dim outer + // [207:160]: tensor stride dim outer (we only use 32 bits) + SmallVector group1(8, b.i32_val(0)); + int32_t dataSize = log2(elementSizeInBytes); + group1[0] = multicastMask; + group1[0] = b.or_(group1[0], b.i32_val(dataSize << 16)); + if (padIntervalInDwords > 0 && padAmountInDwords > 0) { + assert(llvm::isPowerOf2_32(padIntervalInDwords)); + int32_t log2PadInterval = log2(padIntervalInDwords); + group1[0] = b.or_(group1[0], b.i32_val(1 << 20)); + group1[0] = b.or_(group1[0], b.i32_val((log2PadInterval - 1) << 22)); + group1[0] = b.or_(group1[0], b.i32_val((padAmountInDwords - 1) << 25)); + 
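For concreteness, the padding fields packed into group1[0] above can be spelled out for an assumed padded layout (values chosen only for illustration): an interval of 64 fp16 elements with 8 fp16 elements of padding gives 32 and 4 dwords respectively.

// Assumed example: padInterval = 64, padAmount = 8, elementBitWidth = 16.
constexpr unsigned kDwordSize = 32;
constexpr unsigned kPadIntervalInDwords = 64 * 16 / kDwordSize; // 32
constexpr unsigned kPadAmountInDwords = 8 * 16 / kDwordSize;    // 4
// group1[0] then carries: bit 20 = 1 (padding enabled),
// bits [24:22] = log2(32) - 1 = 4, and bits [31:25] = 4 - 1 = 3.
static_assert(kPadIntervalInDwords == 32 && kPadAmountInDwords == 4,
              "example padding parameters in dwords");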
} + group1[1] = b.shl(tensorShape[1], b.i32_val(16)); + group1[2] = b.lshr(tensorShape[1], b.i32_val(16)); + group1[2] = b.or_(group1[2], b.shl(tensorShape[0], b.i32_val(16))); + group1[3] = b.lshr(tensorShape[0], b.i32_val(16)); + group1[3] = b.or_(group1[3], b.i32_val(blockShape[1] << 16)); + group1[4] = b.i32_val(blockShape[0] & 0xFFFF); + group1[5] = tensorStride[0]; + + VectorType vecTy1 = vec_ty(i32_ty, 8); + Value group1Vec = b.undef(vecTy1); + for (unsigned ii = 0; ii < 8; ++ii) { + Value vecIdx = createIndexAttrConstant(rewriter, loc, + typeConverter->getIndexType(), ii); + group1Vec = b.insert_element(vecTy1, group1Vec, group1[ii], vecIdx); + } + + return {group0Vec, group1Vec}; + } + + LogicalResult + matchAndRewrite(triton::amdgpu::AsyncTDMCopyGlobalToLocalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto mod = op->getParentOfType(); + auto tensorDescTy = op.getDesc().getType(); + auto smemTy = op.getResult().getType(); + + auto swizzledEnc = + llvm::dyn_cast(smemTy.getEncoding()); + if (swizzledEnc && swizzledEnc.getMaxPhase() != 1) + return rewriter.notifyMatchFailure(op, "TDM does not support swizzling"); + + auto paddedEnc = + llvm::dyn_cast(smemTy.getEncoding()); + if (!paddedEnc && !swizzledEnc) + return rewriter.notifyMatchFailure( + op, "Invalid shared memory layout for TDM."); + + Type llvmElemTy = getTypeConverter()->convertType(smemTy.getElementType()); + auto elementBitWidth = llvmElemTy.getIntOrFloatBitWidth(); + + unsigned padInterval = 0; + unsigned padAmount = 0; + if (paddedEnc) { + if (paddedEnc.getIntervals().size() != 1 || + paddedEnc.getPaddings().size() != 1) + return rewriter.notifyMatchFailure( + op, "NYI: Multiple interval-padding pairs in TDM."); + padInterval = paddedEnc.getIntervals()[0]; + padAmount = paddedEnc.getPaddings()[0]; + } + unsigned dwordSize = 32; + auto padIntervalInDwords = padInterval * elementBitWidth / dwordSize; + auto padAmountInDwords = padAmount * elementBitWidth / dwordSize; + if (padInterval > 0 && padIntervalInDwords < 2) + return rewriter.notifyMatchFailure( + op, "TDM padding interval must be at least 2 dwords"); + if (padAmount > 0 && padAmountInDwords < 1) + return rewriter.notifyMatchFailure( + op, "TDM padding amount must be at least 1 dword"); + + // [base, shape0, shape1, stride0, stride1] + SmallVector descriptorFields = + unpackLLElements(loc, adaptor.getDesc(), rewriter); + if (descriptorFields.size() != 5) + return rewriter.notifyMatchFailure(op, "NYI: TDM > 2D cases."); + + Value base = descriptorFields[0]; + SmallVector tensorShape{descriptorFields[1], descriptorFields[2]}; + SmallVector tensorStride{descriptorFields[3], descriptorFields[4]}; + + // Cast strides from i64 to i32 + tensorStride[0] = b.trunc(i32_ty, tensorStride[0]); + tensorStride[1] = b.trunc(i32_ty, tensorStride[1]); + + SmallVector offset = adaptor.getIndices(); + SmallVector blockShape = + llvm::to_vector(tensorDescTy.getBlockType().getShape()); + SmallVector blockShapePerCTA = blockShape; + + int numCTAs = TritonGPUDialect::getNumCTAs(mod); + Value multicastMask = b.i32_val(0); + if (numCTAs > 1) { + return rewriter.notifyMatchFailure(op, "NYI: Support multicast."); + } + + Type globalPtrTy = ptr_ty(ctx, 1); + Type sharedPtrTy = ptr_ty(ctx, 3); + + // For block shape [M, N], each warp will handle shape [M/numWarps, N]. 
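To make the per-warp split described in the comment above concrete (numbers assumed purely for illustration): with a [128, 64] block and 4 warps, each warp copies a [32, 64] slab and starts 32 rows further into both the global tile and the LDS buffer.

// Assumed example: blockShapePerCTA = [128, 64], numWarps = 4.
constexpr int kOuterBlockShape = 128;
constexpr int kNumWarps = 4;
constexpr int kOuterPerWarp = (kOuterBlockShape + kNumWarps - 1) / kNumWarps; // 32
// Warp w adds w * 32 rows to offset[0] and w * 32 * 64 elements to the LDS
// destination offset (before any padding adjustment).
static_assert(kOuterPerWarp == 32, "each of the 4 warps handles 32 rows");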
+ auto numWarps = triton::gpu::lookupNumWarps(op); + auto warpId = getLaneAndWarpId(rewriter, loc).second; + + int outerBlockShape = blockShapePerCTA[0]; + int outerBlockShapePerWarp = ceil(outerBlockShape, numWarps); + int outerBlockStride = blockShapePerCTA[1]; + + // Shift global pointer by offset + Value outerOffset = b.mul(b.i32_val(outerBlockShapePerWarp), warpId); + offset[0] = b.add(offset[0], outerOffset); + + Value baseOffset = b.add(b.mul(tensorStride[0], offset[0]), + b.mul(tensorStride[1], offset[1])); + base = b.gep(globalPtrTy, llvmElemTy, base, baseOffset); + + // Shift shared pointer by offset + auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getResult(), llvmElemTy, rewriter); + Value dstBase = dstMemObj.getBase(); + Value dstOffset = b.mul(b.i32_val(outerBlockStride), outerOffset); + if (paddedEnc) { + Value padding = emitPadding(loc, rewriter, paddedEnc, elementBitWidth, + dstOffset, false); + dstOffset = b.add(dstOffset, padding); + } + dstBase = b.gep(sharedPtrTy, llvmElemTy, dstBase, dstOffset); + + // Update tensor shape and block shape based on offset + Value zero = b.i32_val(0); + tensorShape[0] = b.smax(zero, b.sub(tensorShape[0], offset[0])); + tensorShape[1] = b.smax(zero, b.sub(tensorShape[1], offset[1])); + + blockShapePerCTA[0] = outerBlockShapePerWarp; + + auto elementSizeInBytes = elementBitWidth / 8; + auto [group0, group1] = createTDMDescriptors( + rewriter, loc, getTypeConverter(), elementSizeInBytes, tensorShape, + blockShapePerCTA, tensorStride, base, dstBase, op.getPred(), + multicastMask, padIntervalInDwords, padAmountInDwords); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, + "llvm.amdgcn.tensor.load.to.lds.d2", {}, + {group0, group1, b.i32_val(0)}); + + rewriter.eraseOp(op); + return success(); + } +}; + struct StoreOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { StoreOpConversion(LLVMTypeConverter &converter, @@ -1776,6 +1984,24 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { const AMD::TargetInfo &targetInfo; }; +struct AsyncTDMWaitConversion + : public ConvertOpToLLVMPattern { + AsyncTDMWaitConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::AsyncTDMWait op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, + "llvm.amdgcn.s.wait.tensorcnt", {}, + {b.i16_val(op.getNum())}); + rewriter.eraseOp(op); + return success(); + } +}; + struct AsyncCommitGroupOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1799,13 +2025,15 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, - axisInfoAnalysis, benefit); + patterns + .add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp index c67ad64173..8ca31ea2ac 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp +++ 
b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -4,11 +4,13 @@ #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" -using ::mlir::triton::gpu::AMDMfmaEncodingAttr; -using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MemDescType; namespace { @@ -32,103 +34,73 @@ class TransLocalLoadOpConversion ConversionPatternRewriter &rewriter) const override { MemDescType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - if (isPackedLoad || canUseTransLoad(op, srcTy, dstTy)) { - return lowerSharedToDotOperandTransLL(op, adaptor, - this->getTypeConverter(), rewriter); - } - return failure(); - } - -private: - bool checkLayoutProperties(MemDescType srcTy, RankedTensorType dstTy) const { - // Verify the layout properties required for using the ds_read_tr - // instruction. This instruction is used to load non-k contiguous tensors - // from shared memory into a dot layout with an MFMA layout parent. - auto dotEnc = llvm::dyn_cast(dstTy.getEncoding()); - if (!dotEnc) { - return false; - } - - auto mfmaEnc = llvm::dyn_cast(dotEnc.getParent()); - if (!mfmaEnc) { - return false; - } - - auto tilesPerWarp = mfmaEnc.getTilesPerWarp(); - if (!mfmaEnc.hasUnitTilesPerWarp()) { - return false; + auto typeConverter = this->getTypeConverter(); + auto llvmElemTy = typeConverter->convertType(dstTy.getElementType()); + unsigned bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + // 64 is the number of bytes ds_read_tr + unsigned needContigReg = 64 / bitwidth; + // 16 is the number of lanes that participate in the data shuffle + unsigned needContigLane = 16; + + if (!canUseTransLoad(op, srcTy, dstTy, bitwidth, needContigReg, + needContigLane)) { + return failure(); } - auto sharedEnc = - dyn_cast(srcTy.getEncoding()); - if (!sharedEnc) - return false; - - int rank = dstTy.getRank(); - const int kDim = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2; - return kDim != sharedEnc.getOrder()[0]; + return lowerSharedToDotOperandTransLL(op, needContigReg, adaptor, + typeConverter, rewriter); } - bool checkKWidth(MemDescType srcTy, RankedTensorType dstTy) const { - // Single rate MFMA insts: - // fp16, bf16: mfma32x32x8, mfma16x16x16 - // fp8, bf8: mfma32x32x16, mfma16x16x32 - // int8: mfma32x32x16, mfma16x16x32 - // - // Double rate MFMA insts: - // fp16, bf16: mfma32x32x16, mfma16x16x32 - // fp8, bf8: mfma32x32x64, mfma16x16x128 - // int8: mfma32x32x32, mfma16x16x64 - // - // Check that kWidth of the dst dotOp layout is large enough to - // work with the transposed lds load instructions. 
- auto dotEnc = llvm::cast(dstTy.getEncoding()); - auto mfmaEnc = llvm::cast(dotEnc.getParent()); - - int rank = dstTy.getRank(); - auto bitwidth = this->typeConverter->convertType(dstTy.getElementType()) - .getIntOrFloatBitWidth(); - int32_t kWidth = dotEnc.getKWidth(); - const int32_t mDim = mfmaEnc.getInstrShape()[0]; - if (mDim != 32 && mDim != 16) +private: + bool checkLayoutProperties(MemDescType srcTy, RankedTensorType dstTy, + unsigned needContigReg, + unsigned needContigLane) const { + auto srcOrder = triton::gpu::getOrder(srcTy); + auto dstOrder = triton::gpu::getOrder(dstTy); + + // Check that the contiguity of srcTy and dstTy don't match + // this is because ds_read_tr will reshuffle the data to + // the opposite contiguity + if (dstOrder[0] == srcOrder[0]) return false; - const int kFactor = 16 / bitwidth; - const int kSizeDoubleRateMfma32 = 16 * kFactor; - const int kSizeDoubleRateMfma16 = 32 * kFactor; - int largeTileThreshold = - (mDim == 32) ? kSizeDoubleRateMfma32 : kSizeDoubleRateMfma16; - - // For FP8, wider MFMA instructions (scaled MFMA) have a k-dimension - // that is four times of regular MFMA instructions. - if (dstTy.getElementType().isFloat() && bitwidth == 8) { - largeTileThreshold *= 2; - } - - const auto shape = dstTy.getShape(); - const int kDim = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2; - const bool isLargeTile = shape[kDim] >= largeTileThreshold; - - const int kWidthLargeTile = 8 * kFactor; - const int kWidthSmallTile = 4 * kFactor; - // For largeTile, i.e. double rated mfma is an option, it's accepted to - // have kWidth set for both double and single rated mfma - // For smallTile, it's only accepted to have kWidth set to single rate - // mfma. Smaller kWidth is not allowed to use transposed lds load. - return (isLargeTile && - llvm::is_contained({kWidthLargeTile, kWidthSmallTile}, kWidth)) || - (kWidth == kWidthSmallTile); + auto dstLL = triton::gpu::toLinearLayout(dstTy); + SmallVector outDimNames(dstLL.getOutDimNames()); + std::swap(outDimNames[0], outDimNames[1]); + auto dstTrLL = dstLL.transposeOuts(outDimNames); + + // Check the main requirements for the ds_read_tr instruction: contiguity + // of reg/lane. This is because ds_read_tr works on a block of 16 lanes + // with each holding 64 bits of data. Each lane will load 64 bits of + // contiguous data and then share it among the lane dimension. + // This means that there needs to be a check that each lane owns + // 64 bit of contig data and that the communicating lanes are contiguous. + // In order to do this, we use ll.getNumConsecutiveInOut() which + // can get the contiguity of the first component of the first + // dimension. + // Since the data might be dim0 or dim1 contiguous we need both the + // dstLL and the dstTrLL: one to check the register dimension + // contiguity and the other to check the lane dimension one. + bool dim1Contig = dstOrder[0] == 1; + auto dstLLDim0Contig = dim1Contig ? dstTrLL : dstLL; + auto dstLLDim1Contig = dim1Contig ? 
dstLL : dstTrLL; + int contigRegisters = dstLLDim0Contig.getNumConsecutiveInOut(); + + assert(dstLLDim0Contig.getBases().begin()->first == "register"); + SmallVector subLayoutInDims( + llvm::drop_begin(dstLLDim0Contig.getInDimNames())); + SmallVector subLayoutOutDims(dstLLDim0Contig.getOutDimNames()); + auto dstLLOnlyLaneWarp = + dstLLDim1Contig.sublayout(subLayoutInDims, subLayoutOutDims); + int contigLanes = dstLLOnlyLaneWarp.getNumConsecutiveInOut(); + + // Check that the tile size used by ds_read_tr (KxM/N = 4x16 for 16-bit + // elements) is contiguous both in terms of registers dimension and in + // terms of lane dimension. If that is the case then we can use ds_read_tr + return contigRegisters >= needContigReg && contigLanes >= needContigLane; } - bool checkCurrentLimitation(Operation *localLoad, - RankedTensorType dstTy) const { - - auto bitwidth = this->typeConverter->convertType(dstTy.getElementType()) - .getIntOrFloatBitWidth(); - + bool checkCurrentLimitation(unsigned bitwidth) const { // FP4 is represented as i8 and, when packed along K, can be // transposed using ds_read_tr8 which doesn't change packing. if (bitwidth != 16 && bitwidth != 8) { @@ -139,27 +111,22 @@ class TransLocalLoadOpConversion } bool canUseTransLoad(Operation *localLoad, MemDescType srcTy, - RankedTensorType dstTy) const { - auto bitwidth = this->typeConverter->convertType(dstTy.getElementType()) - .getIntOrFloatBitWidth(); - - // 1. Check GPU arch properties. - if (!targetInfo.canUseLDSTransLoad(bitwidth)) { - return false; + RankedTensorType dstTy, unsigned bitwidth, + unsigned needContigReg, unsigned needContigLane) const { + // Packed loads need to always map to ds_read_tr + if constexpr (isPackedLoad) { + return true; } - // 2. Check layout properties. - if (!checkLayoutProperties(srcTy, dstTy)) { + if (!targetInfo.canUseLDSTransLoad(bitwidth)) { return false; } - // 3. Check current limitations. - if (!checkCurrentLimitation(localLoad, dstTy)) { + if (!checkCurrentLimitation(bitwidth)) { return false; } - // 4. Check kWidth - if (!checkKWidth(srcTy, dstTy)) { + if (!checkLayoutProperties(srcTy, dstTy, needContigReg, needContigLane)) { return false; } @@ -167,7 +134,8 @@ class TransLocalLoadOpConversion } LogicalResult - lowerSharedToDotOperandTransLL(LocalLoadOpType op, OpAdaptor adaptor, + lowerSharedToDotOperandTransLL(LocalLoadOpType op, unsigned needContigReg, + OpAdaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) const { auto ctx = rewriter.getContext(); @@ -175,79 +143,90 @@ class TransLocalLoadOpConversion auto b = TritonLLVMOpBuilder(loc, rewriter); auto dstTy = cast(op.getType()); auto srcTy = cast(op.getSrc().getType()); - auto dotEnc = cast(dstTy.getEncoding()); - auto shape = isPackedLoad ? srcTy.getShape() : dstTy.getShape(); auto llvmElemTy = typeConverter->convertType(dstTy.getElementType()); - auto llBitwidth = isPackedLoad ? 
4 : llvmElemTy.getIntOrFloatBitWidth(); auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); - auto ldsTransLayout = chooseDsReadB64TrLayout(dotEnc, shape, llBitwidth); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); - SmallVector outVals; - SmallVector elemsI32; mlir::Type retTy = dstTy; auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); - bool valid = emitTransferBetweenRegistersAndShared( - ldsTransLayout, srcTy, llvmElemTy, - /*maxVecElems=*/std::nullopt, smemObj, loc, rewriter, targetInfo, - laneId, warpId, [&](VectorType vecTy, Value vecAddr) { - if constexpr (isPackedLoad) { - assert(bitwidth == 8); - auto numElems = vecTy.getNumElements(); - auto numElemsI32 = (numElems * bitwidth / 32); - auto i32VecTy = VectorType::get(numElemsI32, i32_ty); - auto dsReadOp = - rewriter.create(loc, i32VecTy, vecAddr); - auto res = b.bitcast(dsReadOp.getResult(), vecTy); - Value vecVal = res.getResult(); - for (int v = 0; v < vecTy.getNumElements(); v++) { - outVals.push_back( - b.extract_element(llvmElemTy, vecVal, b.i32_val(v))); - } - } else if (bitwidth == 16) { - auto dsReadOp = - rewriter.create(loc, vecTy, vecAddr); - if constexpr (!isPackedLoad) { - AMD::addLocalLoadNoAliasScope(op, dsReadOp); - } - Value vecVal = dsReadOp.getResult(); - for (int v = 0; v < vecTy.getNumElements(); v++) { - outVals.push_back( - b.extract_element(llvmElemTy, vecVal, b.i32_val(v))); - } - } else { - // pack elements in i32 vectors - auto numElems = vecTy.getNumElements(); - auto numElemsI32 = (numElems * bitwidth / 32); - auto i32VecTy = VectorType::get(numElemsI32, i32_ty); - - auto dsReadOp = - rewriter.create(loc, i32VecTy, vecAddr); - if constexpr (!isPackedLoad) { - AMD::addLocalLoadNoAliasScope(op, dsReadOp); - } - Value vecVal = dsReadOp.getResult(); - for (auto i = 0; i < numElemsI32; ++i) { - elemsI32.push_back( - b.extract_element(i32_ty, vecVal, b.i32_val(i))); - } - } - }); - - // unpack i32 vectors and cast to native type - if (bitwidth != 16) { - auto numElemsPerVec = 32 / bitwidth; - auto vecTy = vec_ty(llvmElemTy, numElemsPerVec); - for (int v = 0; v < static_cast(elemsI32.size()); ++v) { - auto vec = b.bitcast(elemsI32[v], vecTy); - for (int i = 0; i < numElemsPerVec; ++i) - outVals.push_back(b.extract_element(llvmElemTy, vec, b.i32_val(i))); + auto affineOffset = smemObj.getShmemOffset(loc, rewriter, srcTy); + auto maskSpanAffineOffset = smemObj.getMaskSpanOffsets(srcTy); + auto calcPaddedOffset = [&](Value smemOffset) { + TritonLLVMOpBuilder b(loc, rewriter); + auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + if (auto paddedLayout = dyn_cast( + srcTy.getEncoding())) { + // Apply the offset needed for padding. + Value padOffset = emitPadding(loc, rewriter, paddedLayout, bitwidth, + smemOffset, /*offsetInBytes=*/true); + smemOffset = b.add(smemOffset, padOffset); } + return smemOffset; + }; - retTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(outVals.size(), llvmElemTy)); + auto shape = srcTy.getShape(); + // FP4 are packed into i8 so the real bitwidth is different + auto llBitwidth = isPackedLoad ? 
4 : llvmElemTy.getIntOrFloatBitWidth(); + auto ldsTransLayout = triton::gpu::chooseDsReadB64TrLayout( + dstTy.getEncoding(), shape, llBitwidth); + auto paddedEnc = + dyn_cast(srcTy.getEncoding()); + LinearLayout cvt = LinearLayout::empty(); + if (paddedEnc) { + const auto &sharedLL = paddedEnc.getLinearComponent(); + cvt = ldsTransLayout.invertAndCompose(sharedLL); + } else { + auto sharedLL = triton::gpu::toLinearLayout(srcTy); + cvt = ldsTransLayout.invertAndCompose(sharedLL); } - assert(valid && "Failed to emit LDS transpose load operations"); + // Check that we will be able to vectorize the load. + // Need to have exactly needContigReg, otherwise we can't use ds_read_tr + auto [elemsPerVec, permutation] = + largestVectorisation(ctx, cvt, bitwidth, needContigReg); + + if (paddedEnc) + elemsPerVec = std::min(elemsPerVec, paddedEnc.getMinInterval()); + + if (elemsPerVec != needContigReg) + return failure(); + + cvt = cvt.sublayout( + {str_attr("register"), str_attr("lane"), str_attr("warp")}, + {str_attr("offset")}); + auto lowerInst = [&](RewriterBase &rewriter, Location loc, + ArrayRef inVals, Value vecAddr, int idx, + VectorType vTy) { + Value dsReadTr; + if (bitwidth == 16) { + dsReadTr = rewriter.create(loc, vTy, vecAddr); + } else { + assert(bitwidth == 8); + auto numElems = vTy.getNumElements(); + auto numElemsI32 = (numElems * bitwidth / 32); + auto ty = VectorType::get(numElemsI32, i32_ty); + if (isPackedLoad) { + dsReadTr = rewriter.create(loc, ty, vecAddr); + } else { + dsReadTr = rewriter.create(loc, ty, vecAddr); + } + } + AMD::addLocalLoadNoAliasScope( + op, cast(dsReadTr.getDefiningOp())); + Value vecVal = b.bitcast(dsReadTr, vTy); + SmallVector loadedVals; + for (int v = 0; v < vTy.getNumElements(); v++) { + loadedVals.push_back( + b.extract_element(llvmElemTy, vecVal, b.i32_val(v))); + } + + return loadedVals; + }; + + SmallVector outVals = lowerLdSt( + loc, rewriter.getContext(), cvt, {}, // Input for store, output for load + llvmElemTy, smemObj.getBase(), calcPaddedOffset, affineOffset, + maskSpanAffineOffset, laneId, warpId, rewriter, targetInfo, + needContigReg, lowerInst); Value result = packLLElements(loc, typeConverter, outVals, rewriter, retTy); rewriter.replaceOp(op, result); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 0c0a514598..9465fce953 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -25,6 +25,12 @@ void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, const TargetInfo &targetInfo, PatternBenefit benefit); + +// Manipulates with execution mode register which is per-wavefront one. +// The register controls execution of instructions - e.g., rounding modes, +// exception handling, etc. 
+void adjustModeRegister(ModuleOp mod, const TargetInfo &targetInfo); + void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, @@ -50,6 +56,9 @@ void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter, void populateMaskedOpsToLLVMPatterns(RewritePatternSet &patterns); +void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); } // namespace mlir::triton::AMD #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp new file mode 100644 index 0000000000..5523c1f7ca --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -0,0 +1,45 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct MakeTensorDescOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::MakeTensorDescOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto tensorShape = adaptor.getShape(); + auto tensorStride = adaptor.getStrides(); + auto basePtr = adaptor.getBase(); + auto result = op.getResult(); + + SmallVector elems; + elems.push_back(basePtr); + llvm::append_range(elems, tensorShape); + llvm::append_range(elems, tensorStride); + + auto newValue = packLLElements(op.getLoc(), getTypeConverter(), elems, + rewriter, result.getType()); + rewriter.replaceOp(op, newValue); + return success(); + } +}; +} // namespace + +void mlir::triton::AMD::populateTensorPtrOpsToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + return; +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 24574d6349..fa5fdb5a9a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -63,6 +63,36 @@ class TritonLLVMConversionTarget : public ConversionTarget { } }; +class TritonAMDGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter { +public: + TritonAMDGPUToLLVMTypeConverter(MLIRContext *ctx, + const LowerToLLVMOptions &options, + const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis = nullptr) + : TritonGPUToLLVMTypeConverter(ctx, options, targetInfo, analysis) { + addConversion([&](TensorDescType type) -> std::optional { + return convertTensorDescType(type); + }); + } + + Type convertTensorDescType(triton::TensorDescType type) { + auto ctx = type.getContext(); + + RankedTensorType rankedTensorType = type.getBlockType(); + auto eleType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + SmallVector types; + // base ptr + types.push_back(LLVM::LLVMPointerType::get(ctx, 1)); + // 32 bit shapes + types.append(shape.size(), IntegerType::get(ctx, 32)); + // 64 bit strides + types.append(shape.size(), IntegerType::get(ctx, 64)); + + return 
LLVM::LLVMStructType::getLiteral(ctx, types); + } +}; + struct ConvertTritonAMDGPUToLLVM : public triton::impl::ConvertTritonAMDGPUToLLVMBase< ConvertTritonAMDGPUToLLVM> { @@ -90,7 +120,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); - TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); + TritonAMDGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMConversionTarget convTarget(*context); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); @@ -175,6 +205,8 @@ struct ConvertTritonAMDGPUToLLVM AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, axisInfoAnalysis, AMDBenefit); AMD::populateMaskedOpsToLLVMPatterns(patterns); + AMD::populateTensorPtrOpsToLLVMPatterns(typeConverter, patterns, + AMDBenefit); populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns, commonBenefit); @@ -234,6 +266,7 @@ struct ConvertTritonAMDGPUToLLVM return signalPassFailure(); } + AMD::adjustModeRegister(mod, targetInfo); fixUpLoopAnnotation(mod); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index bb0d164431..98d39af2ba 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_triton_library(TritonAMDGPUTransforms ReorderInstructions.cpp Pipeline.cpp ScheduleLoops.cpp + LowerLoops.cpp MfmaGroup.cpp WmmaGroup.cpp InThreadTranspose.cpp diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index f7bc6cfbed..9c4561ccbf 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1654,8 +1654,10 @@ static const std::string kInitFuncArgsRewritten = /// (ConvertUnimplementedOpUnrealizedCasts) if it wasn't DCEd (via a user /// extracting the tt.ptr and c0 operands). struct InitFuncPtrArgs : OpRewritePattern { - InitFuncPtrArgs(MLIRContext *context, FatPointers &fatPtrs) - : OpRewritePattern(context, 0), fatPtrs(fatPtrs) {} + InitFuncPtrArgs(MLIRContext *context, FatPointers &fatPtrs, + bool enableLargeTensorPtrCanon_) + : OpRewritePattern(context, 0), fatPtrs(fatPtrs), + enableLargeTensorPtrCanon(enableLargeTensorPtrCanon_) {} LogicalResult matchAndRewrite(tt::FuncOp newOp, PatternRewriter &rewriter) const override { @@ -1673,7 +1675,11 @@ struct InitFuncPtrArgs : OpRewritePattern { newOp.getArgAttrOfType(idx, "tt.pointer_range")) bitness = pointerRangeAttr.getInt(); - LDBG(idx << "-th argument: " << arg << ", bitness: " << bitness << "\n"); + LDBG(idx << "-th argument: " << arg << ", bitness: " << bitness); + if (!enableLargeTensorPtrCanon && (bitness == 64)) { + LDBG("Do not init argument of large-tensor pointer: " << arg); + continue; + } Value zeroOffset = rewriter.create(newOp.getLoc(), 0, bitness); @@ -1690,6 +1696,7 @@ struct InitFuncPtrArgs : OpRewritePattern { } FatPointers &fatPtrs; + bool enableLargeTensorPtrCanon; }; /// No-op to make conversion framework happy. 
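The type converter above flattens triton::TensorDescType into a literal LLVM struct — base pointer first, then rank 32-bit shapes, then rank 64-bit strides — and MakeTensorDescOpConversion packs its operands in the same order. A standalone sketch of that field layout (illustration only, not the MLIR code):

// Standalone sketch of the flattened descriptor layout produced above:
// one base pointer, then rank i32 shapes, then rank i64 strides, in the
// same order MakeTensorDescOp packs its operands (1 + 2*rank fields total).
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> tensorDescFields(int rank) {
  std::vector<std::string> fields = {"base ptr (addrspace 1)"};
  for (int i = 0; i < rank; ++i)
    fields.push_back("shape[" + std::to_string(i) + "] : i32");
  for (int i = 0; i < rank; ++i)
    fields.push_back("stride[" + std::to_string(i) + "] : i64");
  return fields;
}

int main() {
  for (const auto &f : tensorDescFields(/*rank=*/2))
    std::cout << f << "\n"; // rank 2 gives 1 + 2*2 = 5 struct fields
}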
@@ -1816,6 +1823,8 @@ class ConvertUnimplementedOpUnrealizedCasts class TritonAMDGPUCanonicalizePointersPass : public impl::TritonAMDGPUCanonicalizePointersBase< TritonAMDGPUCanonicalizePointersPass> { + using Base::Base; + public: void runOnOperation() override; }; @@ -1905,18 +1914,29 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { FatPointers fatPrs; PatternRewriter rewriter(&getContext()); // Convert tt.func; %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr - InitFuncPtrArgs pat(&getContext(), fatPrs); + InitFuncPtrArgs pat(&getContext(), fatPrs, enableLargeTensorPtrCanon); if (failed(pat.matchAndRewrite(func, rewriter))) return signalPassFailure(); llvm::SetVector opsToRewrite; - for (auto arg : func.getArguments()) { - if (llvm::isa(arg.getType())) { - // NB: reusing the same SetVector invalidates the topo order implied by - // getForwardSlice - for (auto &use : arg.getUses()) - getForwardSliceImpl(&use, use.getOwner(), &opsToRewrite); + for (auto [idx, arg] : llvm::enumerate(func.getArguments())) { + if (!llvm::isa(arg.getType())) + continue; + + int64_t bitness = 64; + if (auto pointerRangeAttr = + func.getArgAttrOfType(idx, "tt.pointer_range")) + bitness = pointerRangeAttr.getInt(); + + if (!enableLargeTensorPtrCanon && (bitness == 64)) { + LDBG("ignore " << idx << "-th argument of large-tensor ptr: " << arg); + continue; } + + // NB: reusing the same SetVector invalidates the topo order implied by + // getForwardSlice + for (auto &use : arg.getUses()) + getForwardSliceImpl(&use, use.getOwner(), &opsToRewrite); } ConversionConfig config; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp new file mode 100644 index 0000000000..7109d11c6c --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp @@ -0,0 +1,754 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h" +#include "amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" +#include "amd/lib/TritonAMDGPUTransforms/PipelineUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "tritonamdgpu-pipeline-lower-loops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +//===----------------------------------------------------------------------===// +// This file will conditionally allocate lds memory, create local/async load +// operations, and create schedule for these operations. After lowerLoops, +// schedule will be passed to expandLoops and eventually to PipelineExpander. 
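Both InitFuncPtrArgs and the argument walk above read the tt.pointer_range argument attribute (defaulting to 64) and, when enableLargeTensorPtrCanon is off, skip 64-bit "large tensor" pointer arguments. A minimal standalone sketch of that gating decision, with the attribute lookup stubbed out as an optional:

// Standalone sketch (attribute lookup stubbed out): decide whether a pointer
// argument takes part in pointer canonicalization. 64 is the default bitness
// when no tt.pointer_range attribute is present.
#include <cassert>
#include <cstdint>
#include <optional>

bool shouldCanonicalize(std::optional<int64_t> pointerRangeAttr,
                        bool enableLargeTensorPtrCanon) {
  int64_t bitness = pointerRangeAttr.value_or(64);
  // 64-bit-range (large tensor) pointers are skipped unless explicitly
  // enabled; 32-bit-range pointers are always canonicalized.
  return enableLargeTensorPtrCanon || bitness != 64;
}

int main() {
  assert(!shouldCanonicalize(std::nullopt, /*enableLargeTensorPtrCanon=*/false));
  assert(shouldCanonicalize(32, /*enableLargeTensorPtrCanon=*/false));
  assert(shouldCanonicalize(64, /*enableLargeTensorPtrCanon=*/true));
}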
+//===----------------------------------------------------------------------===// + +using mlir::triton::AMD::AttrBypassLDS; + +namespace mlir { +struct StreamCopyChainOps { + tt::LoadOp loadOp; + ttg::MemDescIndexOp subviewOp; + ttg::LocalStoreOp localStoreOp; + ttg::LocalLoadOp maybeLocalLoadOp; +}; + +struct AsyncCopyChainOps { + ttg::AsyncCopyGlobalToLocalOp copyOp; + ttg::AsyncCommitGroupOp commitOp; + ttg::AsyncWaitOp waitOp; + ttg::LocalLoadOp maybeLocalLoadOp; +}; + +using StreamOpVariant = std::variant; +using LoadToStreamOpMap = llvm::MapVector; + +AsyncCopyChainOps createAsyncCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { + OpBuilder builder(loadOp); + Location loc = loadOp.getLoc(); + + // Extract local subview from shared allocation + auto viewLoad = triton::createSingleBufferView(builder, alloc, extractIdx) + .getDefiningOp(); + + auto copyOp = builder.create( + loc, loadOp.getPtr(), viewLoad, loadOp.getMask(), loadOp.getOther(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + auto commitOp = + builder.create(loc, copyOp->getResult(0)); + ttg::AsyncWaitOp waitOp = + builder.create(loc, commitOp->getResult(0), 0); + + auto maybeSharedLoad = tt::replaceUsesWithLocalLoad( + builder, loadOp->getResult(0), viewLoad, waitOp); + + return {copyOp, commitOp, waitOp, maybeSharedLoad}; +} + +void scheduleLocalLoad(ttg::LocalLoadOp localLoadOp, + tt::CoarseSchedule &schedule, int stage, + const tt::CoarseSchedule::Cluster &cluster) { + schedule.insert(localLoadOp, stage, cluster); + // If its only user is a ConvertLayout, we place it into the same stage so + // it can be folded by a later pass + if (localLoadOp->hasOneUse()) { + auto cvt = *localLoadOp->getUsers().begin(); + if (isa(cvt)) { + schedule.insert(cvt, stage, cluster); + } + } +} + +StreamCopyChainOps createStreamCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { + OpBuilder builder(loadOp); + Location loc = loadOp.getLoc(); + + // Extract local subview from shared allocation + auto viewLoad = triton::createSingleBufferView(builder, alloc, extractIdx) + .getDefiningOp(); + + tt::LoadOp newLoadOp = cast(builder.clone(*loadOp)); + auto storeOp = builder.create(loc, newLoadOp, viewLoad); + auto maybeLocalLoad = + tt::replaceUsesWithLocalLoad(builder, loadOp->getResult(0), viewLoad); + + return {newLoadOp, viewLoad, storeOp, maybeLocalLoad}; +} + +// Returns the given |inputValue|'s dot user result encoding and updates |opIdx| +// and |vecSize| with which dot operand |inputValue| is fed into if possible. +ttg::AMDMfmaEncodingAttr getDotEncoding(Value inputValue, unsigned *opIdx, + unsigned *vecSize) { + if (!inputValue.hasOneUse()) + return nullptr; + + Operation *user = *inputValue.getUsers().begin(); + if (user->getNumResults() != 1 || + user->getBlock() != inputValue.getParentBlock()) + return nullptr; + + LDBG("getDotEncoding user: " << *user); + if (auto dotOp = dyn_cast(user)) { + OpOperand &use = *inputValue.getUses().begin(); + *opIdx = use.getOperandNumber(); + auto operandType = cast(inputValue.getType()); + *vecSize = ttg::toLinearLayout(operandType).getNumConsecutiveInOut(); + auto dotType = cast(dotOp->getResult(0).getType()); + return dyn_cast(dotType.getEncoding()); + } + + return getDotEncoding(user->getResult(0), opIdx, vecSize); +} + +// Adapted from +// lib/Dialect/TritonGPU/Transforms/Utility.cpp::getSharedEncIfAllUsersAreDotEnc +// to support AMDMfmaEncodingAttr. 
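getDotEncoding above follows a chain of single-result, single-use ops until it reaches a dot op, recursing through intermediate conversions. A simplified standalone sketch of that def-use walk over a toy node type (the opIdx/vecSize bookkeeping and same-block check are omitted):

// Standalone sketch of the recursive single-use walk in getDotEncoding:
// follow the unique user of each value until a "dot" is found, or give up
// when the chain branches or ends.
#include <cassert>
#include <vector>

struct Node {
  bool isDot;
  std::vector<Node *> users; // users of this node's single result
};

Node *findDotThroughSingleUseChain(Node *n) {
  if (n->users.size() != 1)
    return nullptr;
  Node *user = n->users.front();
  if (user->isDot)
    return user;
  return findDotThroughSingleUseChain(user);
}

int main() {
  Node dot{/*isDot=*/true, {}};
  Node cvt{/*isDot=*/false, {&dot}};
  Node load{/*isDot=*/false, {&cvt}};
  assert(findDotThroughSingleUseChain(&load) == &dot);
}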
+// TODO(max): figure out how to refactor to use upstream +// +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +std::optional +getSharedEncIfAllUsersAreDotEnc(Value loadedValue) { + llvm::SmallVector sharedEncs; + for (Operation *user : loadedValue.getUsers()) { + LDBG(" getSharedEncIfAllUsersAreDotEnc current user: " << *user); + if (user->getNumResults() != 1) + return std::nullopt; + + ttg::SwizzledSharedEncodingAttr tempAttr; + Value userResult = user->getResult(0); + Type userResType = userResult.getType(); + if (auto memDesc = dyn_cast(userResType)) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + // If the immediate user is ttg::LocalAllocOp, likely it's created in + // TritonAMDGPUOptimizeDotOperands. We should just respect it. + if (!getSharedEncIfAllUsersAreDotEnc(userResult).has_value() && + !isa(user)) { + return std::nullopt; + } + LDBG("Deduced shared encoding candidate from memDesc: " << tempAttr); + sharedEncs.push_back(tempAttr); + } else { + if (!(isa(user) || + user->hasTrait())) + return std::nullopt; + + auto srcTy = cast(loadedValue.getType()); + auto ctaLayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = getOrderForMemory(srcTy); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + SmallVector sharedOrder; + int rank = order.size(); + // TODO rework this when shared -> dotOperand conversions support + // arbitrary shared memory ordering + if (rank == 3) { + // Move the batch dimension (dim #0) to be the last so that it will be + // the slowest varying dimension. + for (unsigned i = 0; i < rank; ++i) + if (order[i] != 0) + sharedOrder.emplace_back(order[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = order; + } + + auto userResEnc = cast(userResType).getEncoding(); + if (auto dotOpEnc = dyn_cast(userResEnc)) { + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + loadedValue.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, + ctaLayout, bitWidth, /*needTrans=*/false); + LDBG("Deduced shared encoding candidate from dot layout: " << tempAttr); + sharedEncs.push_back(tempAttr); + } else if (auto llEnc = dyn_cast(userResEnc)) { + // We use linear layout directly for scaled dot fp8 operands. For such + // cases, we need to look further down the def-use chain to find the dot + // op for the mfma layout to deduce operand index and other information. 
+ unsigned opIdx; + unsigned vecSize; + if (auto mfmaEnc = getDotEncoding(userResult, &opIdx, &vecSize)) { + LDBG("deduced opIdx: " << opIdx << "; deduced vecSize: " << vecSize); + tempAttr = mfmaEnc.composeSharedLayoutForOperand( + ctaLayout, opIdx, srcTy.getShape(), order, vecSize, bitWidth, + /*needTrans=*/false); + LDBG("Deduced shared encoding candidate from mfma layout: " + << tempAttr); + sharedEncs.push_back(tempAttr); + } + } + } + } + + auto equalSharedEncIgnoreVec = [](ttg::SwizzledSharedEncodingAttr a, + ttg::SwizzledSharedEncodingAttr b) { + if (!a || !b) + return false; + return (a.getPerPhase() == b.getPerPhase() && + a.getMaxPhase() == b.getMaxPhase() && + a.getOrder() == b.getOrder() && + a.getCTALayout() == b.getCTALayout()); + }; + if (sharedEncs.empty() || !sharedEncs.front()) + return std::nullopt; + auto maxVecSharedEnc = sharedEncs.front(); + + for (auto sharedEnc : sharedEncs) { + if (!equalSharedEncIgnoreVec(sharedEnc, maxVecSharedEnc)) { + LDBG("Incompatible shared encodings"); + return std::nullopt; + } + if (sharedEnc.getVec() > maxVecSharedEnc.getVec()) { + maxVecSharedEnc = sharedEnc; + } + } + + LDBG("Deduced shared encoding: " << maxVecSharedEnc); + + return maxVecSharedEnc; +} + +bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp, + Value alloc, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis, + const tt::AMD::TargetInfo &targetInfo) { + // If we have a single buffer we would require another barrier after the + // local_reads so instead we fall back to pipeline with registers + // Removing this check will create incorrect IR, see + // MembarUtility.h:membarFilter + if (numBuffers <= 1) + return false; + + // Compute the final vecSize we can use for the combination of sourceEncoding + // and sharedEncoding. We can only use AsyncCopy if the target supports the + // requested or a smaller vecSize because we cannot stride when loading + // directly to lds + auto srcTy = cast(loadOp.getPtr().getType()); + auto dstTy = cast(alloc.getType()); + auto regLayout = triton::gpu::toLinearLayout(srcTy); + // It's the allocation so we trim the multibuffer dimension + auto srcShape = dstTy.getShape().take_back(srcTy.getRank()); + auto sharedLayout = + triton::gpu::toLinearLayout(srcShape, dstTy.getEncoding()); + auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + + unsigned vecSize = regToSharedLayout.getNumConsecutiveInOut(); + unsigned elemBitWidth = dstTy.getElementTypeBitWidth(); + + if (fitToValidDirectToLdsVecSize(vecSize, elemBitWidth, targetInfo) == 0) + return false; + + // Checks whether the global pointer's contiguity and mask alignment allows + // for at least 32 bit wide loads + return triton::canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis); +} + +// Convert load ops into shared memory allocation loads and apply +// multi-buffering based on the required number of buffers. +LoadToStreamOpMap +createStreamOps(const LoadToInfoMap &loadToInfo, scf::ForOp &forOp, + const int &numBuffers, bool useAsyncCopy, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + IRRewriter builder(forOp); + Location loc = forOp.getLoc(); + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependency. 
+ forOp = addIterArgsToLoop(builder, forOp, {extractIdx}); + + // Create one counter for the extract indices to avoid creating long + // live range. + extractIdx = forOp.getBody()->getArgument(newOperandIndex); + + builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + // Patch the yield with the updated counter. + appendToForOpYield(forOp, {extractIdx}); + + LoadToStreamOpMap loadToStreamOp; + for (auto &[l, info] : loadToInfo) { + if (!info.sharedEncoding) + continue; + + auto loadOp = dyn_cast(l); + if (!loadOp) + continue; + + // Create an allocation that can hold distance number of loadOp shapes. + auto ty = cast(loadOp->getResultTypes()[0]); + Value alloc = triton::createAlloc(forOp, ty, loadOp->getLoc(), + info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + auto arch = getAMDArch(loadOp->getParentOfType()); + triton::AMD::TargetInfo targetInfo(arch ? arch->str() : ""); + + // Replace the old load with multi-buffered loads + if (useAsyncCopy && + canBeConvertedToAsyncLoad(numBuffers, loadOp, alloc, axisInfoAnalysis, + targetInfo)) { + loadToStreamOp[loadOp] = createAsyncCopy(loadOp, alloc, extractIdx); + } else { + loadToStreamOp[loadOp] = createStreamCopy(loadOp, alloc, extractIdx); + } + } + + return loadToStreamOp; +} + +static void dumpSchedule(tt::CoarseSchedule &schedule, llvm::StringRef msg) { + LLVM_DEBUG({ + llvm::dbgs() << "\n"; + LDBG(msg); + schedule.dump(); + }); +}; + +namespace SingleDotSchedule { +using namespace mlir::SingleDotSchedule; +using ClusterMap = DenseMap; + +ClusterMap createClusterMap(tt::CoarseSchedule &schedule) { + DenseMap clusterMap; + for (auto &[op, stageAndCluster] : schedule.opToStageAndCluster) { + auto [stage, cluster] = stageAndCluster; + tt::CoarseSchedule::ClusterHash clusterHash = + tt::CoarseSchedule::hashCluster(cluster); + clusterMap[clusterHash] = *cluster; + } + + return clusterMap; +} + +// Remap global and compute clusters to the right place +void remapClusters(tt::CoarseSchedule &schedule, ClusterMap clusterMap, + Clusters &clusters) { + for (auto &[op, stageAndCluster] : schedule.opToStageAndCluster) { + auto [stage, cluster] = stageAndCluster; + tt::CoarseSchedule::ClusterHash clusterHash = + tt::CoarseSchedule::hashCluster(stageAndCluster.second); + int oldClusterId = clusterMap[clusterHash]; + if (oldClusterId == 0) { + stageAndCluster.second = clusters[SCHED_GLOBAL_LOAD]; + } else { + assert(oldClusterId == 1); + stageAndCluster.second = clusters[SCHED_COMPUTE]; + } + } +} + +// Init Schedule Config based on settings and loop characteristics. +// Create clusters in order of ops in loop. This can interleave ops +// from different stages in the same cluster to achieve better backend +// scheduling. +// WARNING: Changing the order of schedule.clusters.newAtBack() calls +// can cause invalid schedules to be produced. 
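createStreamOps above threads a single loop-carried counter through the for-op to pick the shared-memory buffer used by each iteration: it starts at -1, is incremented at the top of the loop, and wraps back to 0 once it reaches numBuffers. A standalone sketch of that rotation:

// Standalone sketch of the loop-carried extract index above, cycling through
// the multi-buffered shared allocation.
#include <cstdio>

int main() {
  const int numBuffers = 3;
  int extractIdx = -1; // initial iter-arg value
  for (int iter = 0; iter < 7; ++iter) {
    extractIdx = extractIdx + 1;
    extractIdx = (extractIdx < numBuffers) ? extractIdx : 0;
    std::printf("iteration %d uses buffer %d\n", iter, extractIdx);
  }
  // Prints buffers 0 1 2 0 1 2 0.
}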
+LogicalResult initSchedule(int maxDist, Stages &stages, int numStages, + int &numBuffers, bool useAsyncCopy, bool waitAtTail, + Clusters &clusters, tt::CoarseSchedule &schedule) { + LDBG("Init SingleDotSchedule"); + int lastStage = numStages - 1; + stages[SCHED_GLOBAL_LOAD] = 0; + stages[SCHED_LOCAL_STORE] = 0; + stages[SCHED_LOCAL_LOAD] = lastStage; + stages[SCHED_COMPUTE] = lastStage; + stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; + + bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; + stages[SCHED_LOCAL_STORE] += maxDist; + if (waitAtTail) { + stages[SCHED_ASYNC_WAIT] = std::max(0, stages[SCHED_LOCAL_LOAD] - 1); + } + + LDBG( + "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] + << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] + << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] + << ", COMPUTE stage = " << stages[SCHED_COMPUTE] + << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] + << "; total = " << numStages); + + if (stages[SCHED_LOCAL_STORE] >= numStages || + stages[SCHED_LOCAL_STORE] > stages[SCHED_LOCAL_LOAD]) { + LDBG("Invalid stage schedule"); + return failure(); + } + + // Calculate the number of buffers needed for each load. + // TODO: Use the precise number of buffers needed by the particular load. + numBuffers = + std::max(1, stages[SCHED_LOCAL_LOAD] - stages[SCHED_LOCAL_STORE]); + // If we use AsyncCopy we need one more buffer since we are not using a + // register buffer + if (useAsyncCopy) { + numBuffers += 1; + } + + LDBG("deduced max shared memory buffer number = " << numBuffers); + + // We place async wait as the first cluster because we want to have it being + // the first in the main loop after pipelining. + // In case we use async_copy with pingpong, we need to place async_wait at + // the end of the previous iteration, so it can guarantee the correct + // dependency when warp0 and warp1 are pipelined. + int asyncWaitCluster = waitAtTail ? 
4 : 0; + // If tt.load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else + // Initiate ttg.local_store before tt.load + int globalLoadCluster = 1; + int localStoreCluster = 3; + if (!pairedGlobalLoadLocalStore) { + globalLoadCluster = 3; + localStoreCluster = 2; + } + + // If ttg.local_load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else if they share the buffer + // ttg.local_load must come first + // else + // schedule ttg.local_load in the middle + int localLoadCluster = globalLoadCluster; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_LOCAL_STORE]) { + localLoadCluster = std::max(3, localStoreCluster + 1); + } else if (numBuffers == 1 && localLoadCluster >= localStoreCluster) { + // For 1 buffer, ttg.local_load must occur before ttg.local_store + localLoadCluster = localStoreCluster - 1; + } + + // Schedule compute with ttg.local_load if paired + // otherwise, schedule in the middle + int computeCluster = 2; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_COMPUTE]) { + computeCluster = localLoadCluster; + } + + // Create a hash map to associate cluster hash in old schedule with its + // clusterID + ClusterMap clusterMap = createClusterMap(schedule); + + // Make assignments + Clusters clusterVec; + schedule.clusters.clear(); + std::generate(clusterVec.begin(), clusterVec.end(), + [&]() { return schedule.clusters.newAtBack(); }); + + clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; + clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; + clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; + clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; + clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; + + remapClusters(schedule, clusterMap, clusters); + + LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster + << ", LOCAL_STORE cluster = " << localStoreCluster + << ", LOCAL_LOAD cluster = " << localLoadCluster + << ", COMPUTE cluster = " << computeCluster + << ", ASYNC_WAIT cluster = " << asyncWaitCluster + << "; total = " << SCHED_SIZE); + + return success(); +} + +void scheduleAsyncCopy(const AsyncCopyChainOps &asyncOps, tt::LoadOp loadOp, + tt::CoarseSchedule &schedule, const Stages &stages, + const Clusters &clusters) { + auto [copyOp, commitOp, waitOp, maybeLocalLoadOp] = asyncOps; + auto [loadStage, loadCluster] = schedule[loadOp]; + schedule.insert(copyOp, loadStage, loadCluster); + // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the + // later UpdateAsyncWaitCount pass can deduce better waitcnts + schedule.insert(commitOp, loadStage, loadCluster); + // If the LocalLoads are scheduled to a later stage than AsyncCopy we need to + // place the AsyncCopy prefetches after the AsyncWaits which create a barrier + // to ensure all warps are finished reading the shared buffer we will write + // into. This is done by scheduling AsyncWait as the first cluster. 
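initSchedule above derives the stage assignment and the shared-buffer count from numStages, the maximum load-to-use distance, and whether async copies are used. A standalone sketch of just that arithmetic (cluster assignment omitted), assuming numStages = 3 and maxDist = 0:

// Standalone sketch of the stage/buffer computation in initSchedule above.
#include <algorithm>
#include <cstdio>

int main() {
  const int numStages = 3, maxDist = 0;
  const bool useAsyncCopy = true;

  const int lastStage = numStages - 1;
  int globalLoadStage = 0;
  int localStoreStage = 0 + maxDist;
  int localLoadStage = lastStage;

  // Shared buffers must cover the distance between write and read stages.
  int numBuffers = std::max(1, localLoadStage - localStoreStage);
  if (useAsyncCopy)
    numBuffers += 1; // async copy writes directly to LDS, no register buffer

  std::printf("global_load=%d local_store=%d local_load=%d buffers=%d\n",
              globalLoadStage, localStoreStage, localLoadStage, numBuffers);
  // With numStages=3, maxDist=0 and async copy: stages 0/0/2 and 3 buffers.
}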
+ // If AsyncCopy and LocalLoads are in the same stage we do not assign a + // schdule so they are placed before the LocalLoads + if (loadStage != stages[SCHED_LOCAL_LOAD]) + schedule.insert(waitOp, stages[SCHED_ASYNC_WAIT], + clusters[SCHED_ASYNC_WAIT]); + + if (maybeLocalLoadOp && stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) { + scheduleLocalLoad(maybeLocalLoadOp, schedule, stages[SCHED_LOCAL_LOAD], + clusters[SCHED_LOCAL_LOAD]); + } +} + +void scheduleStreamCopy(const StreamCopyChainOps &streamOps, + tt::LoadOp oldLoadOp, tt::CoarseSchedule &schedule, + const Stages &stages, const Clusters &clusters) { + auto [newLoadOp, subviewOp, localStoreOp, maybeLocalLoadOp] = streamOps; + auto [loadStage, loadCluster] = schedule[oldLoadOp]; + + schedule.insert(newLoadOp, loadStage, loadCluster); + schedule.insert(subviewOp, stages[SCHED_LOCAL_STORE], + clusters[SCHED_LOCAL_STORE]); + schedule.insert(localStoreOp, stages[SCHED_LOCAL_STORE], + clusters[SCHED_LOCAL_STORE]); + if (maybeLocalLoadOp && stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) { + scheduleLocalLoad(maybeLocalLoadOp, schedule, stages[SCHED_LOCAL_LOAD], + clusters[SCHED_LOCAL_LOAD]); + } +} + +void scheduleStreamOps(const LoadToStreamOpMap &loadToStreamOp, + tt::CoarseSchedule &schedule, const Stages &stages, + const Clusters &clusters) { + for (auto [l, streamOps] : loadToStreamOp) { + auto loadOp = dyn_cast(l); + if (!loadOp) + continue; + + if (auto asyncOps = std::get_if(&streamOps)) { + scheduleAsyncCopy(*asyncOps, loadOp, schedule, stages, clusters); + } else if (auto sOps = std::get_if(&streamOps)) { + scheduleStreamCopy(*sOps, loadOp, schedule, stages, clusters); + } + } +} + +void updateSchedule(scf::ForOp &forOp, const LoadToInfoMap &loadToInfo, + tt::CoarseSchedule &schedule, + triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis, + int numStages, bool useAsyncCopy, bool waitAtTail) { + LDBG("SingleDotSchedule::updateSchedule"); + Stages stages; + Clusters clusters; + + int maxDist = 0; + for (auto &[l, info] : loadToInfo) { + maxDist = std::max(maxDist, info.distToUse); + } + + int numBuffers = 1; + if (failed(initSchedule(maxDist, stages, numStages, numBuffers, useAsyncCopy, + waitAtTail, clusters, schedule))) + return; + + // Convert the loads into shared memory allocations and loads from them. 
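scheduleStreamOps above dispatches on a std::variant of the two copy-chain structs via std::get_if. A self-contained sketch of that dispatch pattern with simplified payload types:

// Standalone sketch of the variant dispatch used by scheduleStreamOps above:
// each load maps to either an async-copy chain or a stream-copy chain, and
// the scheduler branches with std::get_if. Payload types are simplified.
#include <cstdio>
#include <variant>
#include <vector>

struct AsyncCopyChain { int loadId; };
struct StreamCopyChain { int loadId; };
using StreamOp = std::variant<AsyncCopyChain, StreamCopyChain>;

void scheduleOne(const StreamOp &op) {
  if (auto *async = std::get_if<AsyncCopyChain>(&op))
    std::printf("schedule async copy for load %d\n", async->loadId);
  else if (auto *stream = std::get_if<StreamCopyChain>(&op))
    std::printf("schedule stream copy for load %d\n", stream->loadId);
}

int main() {
  std::vector<StreamOp> ops = {AsyncCopyChain{0}, StreamCopyChain{1}};
  for (const auto &op : ops)
    scheduleOne(op);
}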
+ auto loadToStreamOps = createStreamOps(loadToInfo, forOp, numBuffers, + useAsyncCopy, axisInfoAnalysis); + + scheduleStreamOps(loadToStreamOps, schedule, stages, clusters); + dumpSchedule(schedule, "Coarse schedule stream ops:"); + + scheduleDependencies(forOp, schedule); + dumpSchedule(schedule, "Coarse schedule with dependencies:"); + ttg::scheduleDistanceOneDependencies(forOp, schedule); + dumpSchedule(schedule, "Coarse schedule with dist 1:"); + tt::CoarseSchedule::Cluster computeCluster = clusters[SCHED_COMPUTE]; + ttg::scheduleRemainingToLastStage(forOp, schedule, computeCluster); + dumpSchedule(schedule, "Final coarse schedule:"); +} +} // namespace SingleDotSchedule + +namespace ChainedDotSchedule { +using namespace mlir::ChainedDotSchedule; + +void scheduleAsyncCopy(const AsyncCopyChainOps &asyncOps, tt::LoadOp loadOp, + tt::CoarseSchedule &schedule, + const ChainedDotClusters &clusters) { + auto [loadStage, loadCluster] = schedule[loadOp]; + auto [copyOp, commitOp, waitOp, maybeLocalLoadOp] = asyncOps; + + schedule.insert(copyOp, loadStage, loadCluster); + // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the + // later UpdateAsyncWaitCount pass can deduce better waitcnts + schedule.insert(commitOp, loadStage, loadCluster); + + if (loadStage == STAGE_GLOBAL_LOAD_1) { + schedule.insert(waitOp, STAGE_LOCAL_LOAD_1, clusters[CLUSTER_ASYNC_WAIT_1]); + if (maybeLocalLoadOp) + scheduleLocalLoad(maybeLocalLoadOp, schedule, STAGE_LOCAL_LOAD_1, + clusters[CLUSTER_LOCAL_LOAD_1]); + } else { + schedule.insert(waitOp, STAGE_LOCAL_LOAD_2, clusters[CLUSTER_ASYNC_WAIT_2]); + if (maybeLocalLoadOp) + scheduleLocalLoad(maybeLocalLoadOp, schedule, STAGE_LOCAL_LOAD_2, + clusters[CLUSTER_LOCAL_LOAD_2]); + } +} + +void scheduleStreamCopy(const StreamCopyChainOps &streamOps, tt::LoadOp loadOp, + tt::CoarseSchedule &schedule, + const ChainedDotClusters &clusters) { + auto [loadStage, loadCluster] = schedule[loadOp]; + auto [copyOp, subviewOp, localStoreOp, maybeLocalLoadOp] = streamOps; + schedule.insert(copyOp, loadStage, loadCluster); + + if (loadStage == STAGE_GLOBAL_LOAD_1) { + schedule.insert(subviewOp, STAGE_LOCAL_WRITE_1, + clusters[CLUSTER_LOCAL_WRITE_1]); + schedule.insert(localStoreOp, STAGE_LOCAL_WRITE_1, + clusters[CLUSTER_LOCAL_WRITE_1]); + + if (maybeLocalLoadOp) + schedule.insert(maybeLocalLoadOp, STAGE_LOCAL_LOAD_1, + clusters[CLUSTER_LOCAL_LOAD_1]); + } else { + schedule.insert(subviewOp, STAGE_LOCAL_WRITE_2, + clusters[CLUSTER_LOCAL_WRITE_2]); + schedule.insert(localStoreOp, STAGE_LOCAL_WRITE_2, + clusters[CLUSTER_LOCAL_WRITE_2]); + if (maybeLocalLoadOp) + schedule.insert(maybeLocalLoadOp, STAGE_LOCAL_LOAD_2, + clusters[CLUSTER_LOCAL_LOAD_2]); + } + + if (maybeLocalLoadOp) { + if (auto cvt = dyn_cast( + *maybeLocalLoadOp->getUsers().begin())) { + auto [localLoadStage, localLoadCluster] = schedule[maybeLocalLoadOp]; + schedule.insert(cvt, localLoadStage, localLoadCluster); + } + } +} + +void scheduleStreamOps(const LoadToStreamOpMap &loadToStreamOp, + tt::CoarseSchedule &schedule, + const ChainedDotClusters &clusters) { + for (auto [l, streamOps] : loadToStreamOp) { + auto loadOp = dyn_cast(l); + if (!loadOp) + continue; + + if (auto asyncOps = std::get_if(&streamOps)) { + scheduleAsyncCopy(*asyncOps, loadOp, schedule, clusters); + } else if (auto sOps = std::get_if(&streamOps)) { + scheduleStreamCopy(*sOps, loadOp, schedule, clusters); + } + } +} + +void updateSchedule(scf::ForOp &forOp, const LoadToInfoMap &loadToInfo, + tt::CoarseSchedule &schedule, + 
triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool useAsyncCopy) { + LDBG("ChainedDotSchedule::updateSchedule"); + ChainedDotClusters clusters; + int cnt = clusters.size() - schedule.clusters.size(); + for (int i = 0; i < cnt; i++) { + schedule.clusters.newAtBack(); + } + auto it = schedule.clusters.begin(); + for (int i = 0; i < clusters.size(); i++, it++) { + clusters[i] = it; + } + + // Convert the loads into shared memory allocations and loads from them. + // TODO support different numBuffers + int numBuffers = useAsyncCopy ? 2 : 1; + auto loadToStreamOps = createStreamOps(loadToInfo, forOp, numBuffers, + useAsyncCopy, axisInfoAnalysis); + scheduleStreamOps(loadToStreamOps, schedule, clusters); + + for (auto [l, _] : loadToInfo) { + schedule.erase(l); + l->erase(); + } + + scheduleDependencies(forOp, schedule); + dumpSchedule(schedule, "Coarse schedule with dependencies:"); + + triton::gpu::scheduleDistanceOneDependencies(forOp, schedule); + dumpSchedule(schedule, "Coarse schedule with dist 1:"); + + tt::CoarseSchedule::Cluster lastCluster = clusters.back(); + triton::gpu::scheduleRemainingToLastStage(forOp, schedule, lastCluster); + dumpSchedule(schedule, "Final coarse schedule:"); +} +} // namespace ChainedDotSchedule + +void lowerLoop(scf::ForOp forOp, + triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool useAsyncCopy, bool usePingpong) { + tt::CoarseSchedule schedule; + if (failed(schedule.deSerialize(forOp, /*normalizeClusterId=*/false))) { + return; + } + + dumpSchedule(schedule, "[lowerLoops]deserialized schedule:"); + + int numStages = schedule.getNumStages(); + + // i.e., we can still disable `waitAtTail` by explicitly disabling + // pingpong, which is the only use case of this scheduling variant. + bool waitAtTail = usePingpong && (numStages == 3) && useAsyncCopy; + + llvm::MapVector> loadOpToIndLevel = + getIndirectLevel(axisInfoAnalysis, forOp, numStages); + + LoadToInfoMap loadToInfo; + for (const auto &[load, info] : loadOpToIndLevel) { + auto [distance, use] = info; + if (load->hasAttrOfType(AttrBypassLDS)) { + load->removeAttr(AttrBypassLDS); + loadToInfo[load] = {nullptr, distance, use}; + } else { + LDBG("Deduce shared encoding for: " << *load); + auto sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(load->getResult(0)).value_or(nullptr); + loadToInfo[load] = {sharedEncoding, distance, use}; + LDBG("Populate loadInfo with shared encoding: " << sharedEncoding); + } + } + + if (succeeded(mlir::ChainedDotSchedule::checkPreconditions(forOp, numStages, + loadToInfo))) { + ChainedDotSchedule::updateSchedule(forOp, loadToInfo, schedule, + axisInfoAnalysis, useAsyncCopy); + } else { + SingleDotSchedule::updateSchedule(forOp, loadToInfo, schedule, + axisInfoAnalysis, numStages, useAsyncCopy, + waitAtTail); + } + + dumpSchedule(schedule, "[lowerLoops]updated schedule:"); + + schedule.serialize(forOp); +} + +void lowerLoops(ModuleOp moduleOp, bool useAsyncCopy, bool usePingpong) { + triton::AMD::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + if (loops.empty()) + return; + for (auto forOp : loops) { + lowerLoop(forOp, axisInfoAnalysis, useAsyncCopy, usePingpong); + } +} + +} // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Pipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Pipeline.cpp index 7e856577f7..14ad8d3e22 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Pipeline.cpp +++ 
b/third_party/amd/lib/TritonAMDGPUTransforms/Pipeline.cpp @@ -2,7 +2,7 @@ #include "amd/lib/TritonAMDGPUTransforms/PipelineUtility.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" -#define DEBUG_TYPE "tritonamdgpu-pipeline" +#define DEBUG_TYPE "tritonamdgpu-pipeline-expand-loops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -33,7 +33,7 @@ Operation *streamPredication(RewriterBase &rewriter, Operation *op, return tt::wrapInMaskOp(rewriter, op, pred); } -void expandLoops(ModuleOp moduleOp, bool useAsyncCopy) { +void expandLoops(ModuleOp moduleOp) { SmallVector loops; moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); for (scf::ForOp forOp : loops) { @@ -91,19 +91,11 @@ void expandLoops(ModuleOp moduleOp, bool useAsyncCopy) { if (failed(newForOp)) continue; - - forOp = *newForOp; } // NOTE: Leave empty for now, until we utilize customEpiloguePeeling DenseSet peeledMaskOps; tt::resolveMaskOp(moduleOp, peeledMaskOps); - - if (useAsyncCopy) { - llvm::SmallSetVector waitOps; - moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); - tt::combineRedundantWaitOps(waitOps); - } } } // namespace @@ -112,7 +104,14 @@ struct PipelinePass : impl::TritonAMDGPUPipelineBase { void runOnOperation() override { ModuleOp moduleOp = getOperation(); - expandLoops(moduleOp, useAsyncCopy); + lowerLoops(moduleOp, useAsyncCopy, usePingpong); + expandLoops(moduleOp); + + if (useAsyncCopy) { + llvm::SmallSetVector waitOps; + moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); + tt::combineRedundantWaitOps(waitOps); + } tt::removePipeliningAttributes(moduleOp); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/PipelineUtility.h b/third_party/amd/lib/TritonAMDGPUTransforms/PipelineUtility.h index c03a643f0c..19ff6398f3 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/PipelineUtility.h +++ b/third_party/amd/lib/TritonAMDGPUTransforms/PipelineUtility.h @@ -6,6 +6,21 @@ #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" namespace mlir { + +namespace triton::AMD { +constexpr char AttrBypassLDS[] = "amdgpu.bypass_lds_load"; +} + +// This function will +// - deserialize schedule and numStages from IR. +// - calculate stages and clusters taking all factors into account, and remap +// symbolic clusters of global load and compute ops to their real clusters. +// - create lds alloc/dealloc/load/store or async load/commit/wait ops if +// possible. +// - schedule these new ops. +// - serialize schedule to IR for the next expandLoops function. +void lowerLoops(ModuleOp moduleOp, bool useAsyncCopy, bool usePingpong); + struct LoadInfo { // Shared layout is used for loads feeding into dot ops. triton::gpu::SwizzledSharedEncodingAttr sharedEncoding = nullptr; @@ -15,6 +30,13 @@ struct LoadInfo { }; using LoadToInfoMap = llvm::MapVector; +// A slim wrapper of ttg::loadOpsToIndirectionLevel, to get the indirection +// levels and final users of load ops. For details you can check the comment of +// ttg::loadOpsToIndirectionLevel. +llvm::MapVector> +getIndirectLevel(triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis, + scf::ForOp &forOp, int numStages); + namespace SingleDotSchedule { // Define categories of scheduling details per Operation types. 
// The SingleDotSchedule schedules 5 types of operations: diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ScheduleLoops.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ScheduleLoops.cpp index 44676bb606..b694925b45 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ScheduleLoops.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ScheduleLoops.cpp @@ -15,29 +15,6 @@ #include "llvm/Support/Debug.h" #include -//===----------------------------------------------------------------------===// -// This file will create a schedule that will be handed over to the pipeline -// expander. -// Software pipeliners are usually separated into two pieces, one that create a -// modulo schedule and an expander that rewrites the loop and emits a prologue -// and epilogue. This pass first calls a helper that will pre-process the IR -// to create stream operations and create a modulo schedule. Then we call the -// expander to generate the prologue and new loop and epilogue. -//===----------------------------------------------------------------------===// - -#define DEBUG_TYPE "tritonamdgpu-schedule-loops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - -namespace tt = mlir::triton; -namespace ttg = mlir::triton::gpu; - -namespace mlir { - -#define GEN_PASS_DEF_TRITONAMDGPUSCHEDULELOOPS -#include "TritonAMDGPUTransforms/Passes.h.inc" - -namespace { //===----------------------------------------------------------------------===// // Software pipelining generally works by anchoring on global load ops in the // main loop and rotating the loop to schedule global load ops for future loop @@ -51,26 +28,27 @@ namespace { // consists of multiple stages, where ops from different stages can overlap // executions because the dependencies are loop carried. // -// The general flow of this process is: +// The general flow of this process is(This is an overview. Some passes or +// functions are in other files): // // 1. The user provides a `num_stages` that specifies how many stages the // pipeline will have. The number of stages must be larger than the distance // from the first independent load to the compute in order to pipeline. -// 2. A schedule is created based on the distance between the global loads -// in the first stages and the compute that uses the loaded values in the -// last stage (num_stages - 1). Each operation will be clustered in the -// order to best overlap with other operations (see details below in the -// initSchedule methods). -// 3. When the compute is a tt.dot, the scheduler will insert a shared -// memory allocation between the global load and tt.dot. The global load -// value will be saved to shared memory, via ttg.local_store or via +// 2. In this pass, a schedule is created based on the distance between the +// global loads in the first stages and the compute that uses the loaded +// values in the last stage (num_stages - 1). Each operation will be +// clustered in the order to best overlap with other operations. +// 3. In lowerLoops, when the compute is a tt.dot, the scheduler will insert a +// shared memory allocation between the global load and tt.dot. The global +// load value will be saved to shared memory, via ttg.local_store or via // ttg.async_copy_global_to_local writing directly to shared memory, and the // ttg.local_load will load the relevant tiles for the tt.dot. These // operations will be scheduled according to various scheduling schemes -// outlined below in the initSchedule methods (see details there). -// 4. 
Finally the schedule will be passed to the PipelineExpander to rewrite -// accordingly. The new implementation will consist of: -// a. Prologue: containing the ramp-up of num_stages-1 stages for +// outlined in the initSchedule methods in LowerLoops.cpp (see details +// there). +// 4. Finally in TritonAMDGPUPipeline pass, the schedule will be passed to the +// PipelineExpander to rewrite accordingly. The new implementation will +// consist of: a. Prologue: containing the ramp-up of num_stages-1 stages for // iteratorions i=[0, num_stages-1). // b. New loop: ordered by cluster and iterated on each operation by // `i + (num_stages-op_stage)`. @@ -79,309 +57,82 @@ namespace { // bounds may be shorter than num_stages. In this case, the epilogue // iterations must align with the prologue. // +// +// This file implements the first stage of software pipelining. It builds a +// symbolic schedule for global memory access and compute operations. Certain +// optimizations (e.g. bypassLDS) are applied conditionally. +// +// Two additional stages follow: +// 1. lowerLoops in LowerLoops.cpp creates LDS alloc/load/store or async +// load/commit/await ops as needed and produces a schedule for them. +// 2. expandLoops in Pipeline.cpp invokes PipelineExpander to apply the schedule +// to the loops and then performs post-processing. +// +// These stages are connected via the schedule serialized in the IR. +//===----------------------------------------------------------------------===// -struct StreamCopyChainOps { - tt::LoadOp loadOp; - ttg::MemDescIndexOp subviewOp; - ttg::LocalStoreOp localStoreOp; - ttg::LocalLoadOp maybeLocalLoadOp; -}; - -struct AsyncCopyChainOps { - ttg::AsyncCopyGlobalToLocalOp copyOp; - ttg::AsyncCommitGroupOp commitOp; - ttg::AsyncWaitOp waitOp; - ttg::LocalLoadOp maybeLocalLoadOp; -}; - -using StreamOpVariant = std::variant; -using LoadToStreamOpMap = llvm::MapVector; - -AsyncCopyChainOps createAsyncCopy(tt::LoadOp loadOp, Value alloc, - Value extractIdx) { - OpBuilder builder(loadOp); - Location loc = loadOp.getLoc(); - - // Extract local subview from shared allocation - auto viewLoad = triton::createSingleBufferView(builder, alloc, extractIdx) - .getDefiningOp(); - - auto copyOp = builder.create( - loc, loadOp.getPtr(), viewLoad, loadOp.getMask(), loadOp.getOther(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - auto commitOp = - builder.create(loc, copyOp->getResult(0)); - ttg::AsyncWaitOp waitOp = - builder.create(loc, commitOp->getResult(0), 0); +#define DEBUG_TYPE "tritonamdgpu-schedule-loops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - auto maybeSharedLoad = tt::replaceUsesWithLocalLoad( - builder, loadOp->getResult(0), viewLoad, waitOp); +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; - return {copyOp, commitOp, waitOp, maybeSharedLoad}; -} +using mlir::triton::AMD::AttrBypassLDS; -void scheduleLocalLoad(ttg::LocalLoadOp localLoadOp, - tt::CoarseSchedule &schedule, int stage, - const tt::CoarseSchedule::Cluster &cluster) { - schedule.insert(localLoadOp, stage, cluster); - // If its only user is a ConvertLayout, we place it into the same stage so - // it can be folded by a later pass - if (localLoadOp->hasOneUse()) { - auto cvt = *localLoadOp->getUsers().begin(); - if (isa(cvt)) { - schedule.insert(cvt, stage, cluster); - } - } -} +namespace mlir { -StreamCopyChainOps createStreamCopy(tt::LoadOp loadOp, Value alloc, - Value extractIdx) { - OpBuilder builder(loadOp); - Location 
loc = loadOp.getLoc(); +#define GEN_PASS_DEF_TRITONAMDGPUSCHEDULELOOPS +#include "TritonAMDGPUTransforms/Passes.h.inc" - // Extract local subview from shared allocation - auto viewLoad = triton::createSingleBufferView(builder, alloc, extractIdx) - .getDefiningOp(); +llvm::MapVector> +getIndirectLevel(triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis, + scf::ForOp &forOp, int numStages) { + auto arch = getAMDArch(forOp->getParentOfType()); + triton::AMD::ISAFamily isaFamily = triton::AMD::ISAFamily::Unknown; + if (arch) + isaFamily = triton::AMD::deduceISAFamily(*arch); - tt::LoadOp newLoadOp = cast(builder.clone(*loadOp)); - auto storeOp = builder.create(loc, newLoadOp, viewLoad); - auto maybeLocalLoad = - tt::replaceUsesWithLocalLoad(builder, loadOp->getResult(0), viewLoad); + bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName); + bool filterSmallVectors = + isaFamily != triton::AMD::ISAFamily::CDNA4 && !isRDNA(isaFamily); + llvm::MapVector> loadOpToIndLevel = + triton::gpu::loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, + axisInfoAnalysis, numStages, + filterSmallVectors); - return {newLoadOp, viewLoad, storeOp, maybeLocalLoad}; + return loadOpToIndLevel; } -// Returns the given |inputValue|'s dot user result encoding and updates |opIdx| -// and |vecSize| with which dot operand |inputValue| is fed into if possible. -ttg::AMDMfmaEncodingAttr getDotEncoding(Value inputValue, unsigned *opIdx, - unsigned *vecSize) { - if (!inputValue.hasOneUse()) - return nullptr; - - Operation *user = *inputValue.getUsers().begin(); - if (user->getNumResults() != 1 || - user->getBlock() != inputValue.getParentBlock()) - return nullptr; - - LDBG("getDotEncoding user: " << *user); - if (auto dotOp = dyn_cast(user)) { - OpOperand &use = *inputValue.getUses().begin(); - *opIdx = use.getOperandNumber(); - auto operandType = cast(inputValue.getType()); - *vecSize = ttg::toLinearLayout(operandType).getNumConsecutiveInOut(); - auto dotType = cast(dotOp->getResult(0).getType()); - return dyn_cast(dotType.getEncoding()); - } - - return getDotEncoding(user->getResult(0), opIdx, vecSize); -} +LogicalResult +mlir::ChainedDotSchedule::checkPreconditions(scf::ForOp forOp, int numStages, + LoadToInfoMap loadToInfo) { + if (numStages != 4) + return failure(); -// Adapted from -// lib/Dialect/TritonGPU/Transforms/Utility.cpp::getSharedEncIfAllUsersAreDotEnc -// to support AMDMfmaEncodingAttr. -// TODO(max): figure out how to refactor to use upstream -// -// If all the transitive uses of the given value have are used by a convert to -// the same dot operand encoding, return true and get the shared encoding that -// needs to be used to be compatible with users' layouts. -std::optional -getSharedEncIfAllUsersAreDotEnc(Value loadedValue) { - llvm::SmallVector sharedEncs; - for (Operation *user : loadedValue.getUsers()) { - LDBG(" getSharedEncIfAllUsersAreDotEnc current user: " << *user); - if (user->getNumResults() != 1) - return std::nullopt; - - ttg::SwizzledSharedEncodingAttr tempAttr; - Value userResult = user->getResult(0); - Type userResType = userResult.getType(); - if (auto memDesc = dyn_cast(userResType)) { - // First time we find a shared encoding in the chain, save it and try to - // use it if it is compatible with the other users. - tempAttr = cast(memDesc.getEncoding()); - // If the immediate user is ttg::LocalAllocOp, likely it's created in - // TritonAMDGPUOptimizeDotOperands. We should just respect it. 
- if (!getSharedEncIfAllUsersAreDotEnc(userResult).has_value() && - !isa(user)) { - return std::nullopt; - } - LDBG("Deduced shared encoding candidate from memDesc: " << tempAttr); - sharedEncs.push_back(tempAttr); - } else { - if (!(isa(user) || - user->hasTrait())) - return std::nullopt; - - auto srcTy = cast(loadedValue.getType()); - auto ctaLayout = ttg::getCTALayout(srcTy.getEncoding()); - auto order = getOrderForMemory(srcTy); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - SmallVector sharedOrder; - int rank = order.size(); - // TODO rework this when shared -> dotOperand conversions support - // arbitrary shared memory ordering - if (rank == 3) { - // Move the batch dimension (dim #0) to be the last so that it will be - // the slowest varying dimension. - for (unsigned i = 0; i < rank; ++i) - if (order[i] != 0) - sharedOrder.emplace_back(order[i]); - sharedOrder.emplace_back(0); - } else { - sharedOrder = order; - } + auto dotOps = llvm::to_vector(forOp.getBody()->getOps()); - auto userResEnc = cast(userResType).getEncoding(); - if (auto dotOpEnc = dyn_cast(userResEnc)) { - tempAttr = ttg::SwizzledSharedEncodingAttr::get( - loadedValue.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, - ctaLayout, bitWidth, /*needTrans=*/false); - LDBG("Deduced shared encoding candidate from dot layout: " << tempAttr); - sharedEncs.push_back(tempAttr); - } else if (auto llEnc = dyn_cast(userResEnc)) { - // We use linear layout directly for scaled dot fp8 operands. For such - // cases, we need to look further down the def-use chain to find the dot - // op for the mfma layout to deduce operand index and other information. - unsigned opIdx; - unsigned vecSize; - if (auto mfmaEnc = getDotEncoding(userResult, &opIdx, &vecSize)) { - LDBG("deduced opIdx: " << opIdx << "; deduced vecSize: " << vecSize); - tempAttr = mfmaEnc.composeSharedLayoutForOperand( - ctaLayout, opIdx, srcTy.getShape(), order, vecSize, bitWidth, - /*needTrans=*/false); - LDBG("Deduced shared encoding candidate from mfma layout: " - << tempAttr); - sharedEncs.push_back(tempAttr); - } - } - } - } + if (dotOps.size() != 2) + return failure(); - auto equalSharedEncIgnoreVec = [](ttg::SwizzledSharedEncodingAttr a, - ttg::SwizzledSharedEncodingAttr b) { - if (!a || !b) - return false; - return (a.getPerPhase() == b.getPerPhase() && - a.getMaxPhase() == b.getMaxPhase() && - a.getOrder() == b.getOrder() && - a.getCTALayout() == b.getCTALayout()); - }; - if (sharedEncs.empty() || !sharedEncs.front()) - return std::nullopt; - auto maxVecSharedEnc = sharedEncs.front(); - - for (auto sharedEnc : sharedEncs) { - if (!equalSharedEncIgnoreVec(sharedEnc, maxVecSharedEnc)) { - LDBG("Incompatible shared encodings"); - return std::nullopt; - } - if (sharedEnc.getVec() > maxVecSharedEnc.getVec()) { - maxVecSharedEnc = sharedEnc; - } + // Check that the first dot feeds into the second + SetVector slice; + getForwardSlice(dotOps[0]->getResult(0), &slice); + if (!slice.contains(dotOps[1])) { + return failure(); } - LDBG("Deduced shared encoding: " << maxVecSharedEnc); - - return maxVecSharedEnc; -} - -bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp, - Value alloc, - tt::ModuleAxisInfoAnalysis &axisInfoAnalysis, - const tt::AMD::TargetInfo &targetInfo) { - // If we have a single buffer we would require another barrier after the - // local_reads so instead we fall back to pipeline with registers - // Removing this check will create incorrect IR, see - // MembarUtility.h:membarFilter - if (numBuffers <= 1) - 
return false; - - // Compute the final vecSize we can use for the combination of sourceEncoding - // and sharedEncoding. We can only use AsyncCopy if the target supports the - // requested or a smaller vecSize because we cannot stride when loading - // directly to lds - auto srcTy = cast(loadOp.getPtr().getType()); - auto dstTy = cast(alloc.getType()); - auto regLayout = triton::gpu::toLinearLayout(srcTy); - // It's the allocation so we trim the multibuffer dimension - auto srcShape = dstTy.getShape().take_back(srcTy.getRank()); - auto sharedLayout = - triton::gpu::toLinearLayout(srcShape, dstTy.getEncoding()); - auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout); - - unsigned vecSize = regToSharedLayout.getNumConsecutiveInOut(); - unsigned elemBitWidth = dstTy.getElementTypeBitWidth(); - - if (fitToValidDirectToLdsVecSize(vecSize, elemBitWidth, targetInfo) == 0) - return false; - - // Checks whether the global pointer's contiguity and mask alignment allows - // for at least 32 bit wide loads - return triton::canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis); -} - -// Convert load ops into shared memory allocation loads and apply -// multi-buffering based on the required number of buffers. -LoadToStreamOpMap -createStreamOps(const LoadToInfoMap &loadToInfo, scf::ForOp &forOp, - const int &numBuffers, bool useAsyncCopy, - tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { - IRRewriter builder(forOp); - Location loc = forOp.getLoc(); - Value minusOne = builder.create(loc, -1, 32); - Value zero = builder.create(loc, 0, 32); - Value one = builder.create(loc, 1, 32); - Value extractIdx = minusOne; - Value numBuffersVal = - builder.create(loc, numBuffers, 32); - - unsigned newOperandIndex = forOp.getBody()->getNumArguments(); - // Patch the loop to add the new loop carried dependency. - forOp = addIterArgsToLoop(builder, forOp, {extractIdx}); - - // Create one counter for the extract indices to avoid creating long - // live range. - extractIdx = forOp.getBody()->getArgument(newOperandIndex); - - builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); - extractIdx = builder.create(loc, extractIdx, one); - Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, - extractIdx, numBuffersVal); - extractIdx = builder.create(loc, cndExt, extractIdx, zero); - - // Patch the yield with the updated counter. - appendToForOpYield(forOp, {extractIdx}); - - LoadToStreamOpMap loadToStreamOp; - for (auto &[l, info] : loadToInfo) { - if (!info.sharedEncoding) - continue; - - auto loadOp = dyn_cast(l); - if (!loadOp) - continue; - - // Create an allocation that can hold distance number of loadOp shapes. - auto ty = cast(loadOp->getResultTypes()[0]); - Value alloc = triton::createAlloc(forOp, ty, loadOp->getLoc(), - info.sharedEncoding, numBuffers); - assert(alloc && "Failed to create alloc for the async load."); - auto arch = getAMDArch(loadOp->getParentOfType()); - triton::AMD::TargetInfo targetInfo(arch ? 
arch->str() : ""); - - // Replace the old load with multi-buffered loads - if (useAsyncCopy && - canBeConvertedToAsyncLoad(numBuffers, loadOp, alloc, axisInfoAnalysis, - targetInfo)) { - loadToStreamOp[loadOp] = createAsyncCopy(loadOp, alloc, extractIdx); - } else { - loadToStreamOp[loadOp] = createStreamCopy(loadOp, alloc, extractIdx); - } + // Reject loops with indirect loads + // TODO support indirect loads + if (llvm::any_of(loadToInfo, + [](auto it) { return it.second.distToUse != 0; })) { + return failure(); } - return loadToStreamOp; + return success(); } +namespace { /// Returns true if for a given global load with loadType, loading instead with /// targetLLAttr maintains at least the same level of coalescing/vectorization /// with same amount of load ops. @@ -549,205 +300,13 @@ static Operation *bypassLDS(Operation *load, Operation *use) { // Finally, rewrite the load to use the inferred (better) encoding. auto newOp = convertDistributedOpEncoding(srcEnc, load); + newOp->setAttr(AttrBypassLDS, BoolAttr::get(newOp->getContext(), true)); return newOp; }; -LoadToInfoMap -preprocessLoop(triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis, - scf::ForOp &forOp, int numStages) { - auto arch = getAMDArch(forOp->getParentOfType()); - triton::AMD::ISAFamily isaFamily = triton::AMD::ISAFamily::Unknown; - if (arch) - isaFamily = triton::AMD::deduceISAFamily(*arch); - - bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName); - bool filterSmallVectors = - isaFamily != triton::AMD::ISAFamily::CDNA4 && !isRDNA(isaFamily); - llvm::MapVector> loadOpToIndLevel = - triton::gpu::loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, - axisInfoAnalysis, numStages, - filterSmallVectors); - - LLVM_DEBUG({ - LDBG("Found " << loadOpToIndLevel.size() << " loads to pipeline:"); - for (const auto &[l, i] : loadOpToIndLevel) { - LDBG(" - load: " << *l); - LDBG(" at distance: " << i.first); - LDBG(" used by op: " << *i.second); - } - }); - - LoadToInfoMap loadToInfo; - for (const auto &[load, info] : loadOpToIndLevel) { - auto [distance, use] = info; - auto newLoad = bypassLDS(load, use); - if (newLoad) { - loadToInfo[newLoad] = {nullptr, distance, use}; - } else { - LDBG("Deduce shared encoding for: " << *load); - auto sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(load->getResult(0)).value_or(nullptr); - loadToInfo[load] = {sharedEncoding, distance, use}; - LDBG("Populate loadInfo with shared encoding: " << sharedEncoding); - } - } - - return loadToInfo; -} - namespace SingleDotSchedule { using namespace mlir::SingleDotSchedule; -// Init Schedule Config based on settings and loop characteristics. -// Create clusters in order of ops in loop. This can interleave ops -// from different stages in the same cluster to achieve better backend -// scheduling. -// WARNING: Changing the order of schedule.clusters.newAtBack() calls -// can cause invalid schedules to be produced. 
-LogicalResult initSchedule(int maxDist, Stages &stages, int numStages, - int &numBuffers, bool useAsyncCopy, bool waitAtTail, - Clusters &clusters, tt::CoarseSchedule &schedule) { - LDBG("Init SingleDotSchedule"); - int lastStage = numStages - 1; - stages[SCHED_GLOBAL_LOAD] = 0; - stages[SCHED_LOCAL_STORE] = 0; - stages[SCHED_LOCAL_LOAD] = lastStage; - stages[SCHED_COMPUTE] = lastStage; - stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; - - bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; - stages[SCHED_LOCAL_STORE] += maxDist; - if (waitAtTail) { - stages[SCHED_ASYNC_WAIT] = std::max(0, stages[SCHED_LOCAL_LOAD] - 1); - } - - LDBG( - "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] - << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] - << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] - << ", COMPUTE stage = " << stages[SCHED_COMPUTE] - << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] - << "; total = " << numStages); - - if (stages[SCHED_LOCAL_STORE] >= numStages || - stages[SCHED_LOCAL_STORE] > stages[SCHED_LOCAL_LOAD]) { - LDBG("Invalid stage schedule"); - return failure(); - } - - // Calculate the number of buffers needed for each load. - // TODO: Use the precise number of buffers needed by the particular load. - numBuffers = - std::max(1, stages[SCHED_LOCAL_LOAD] - stages[SCHED_LOCAL_STORE]); - // If we use AsyncCopy we need one more buffer since we are not using a - // register buffer - if (useAsyncCopy) { - numBuffers += 1; - } - - LDBG("deduced max shared memory buffer number = " << numBuffers); - - // We place async wait as the first cluster because we want to have it being - // the first in the main loop after pipelining. - // In case we use async_copy with pingpong, we need to place async_wait at - // the end of the previous iteration, so it can guarantee the correct - // dependency when warp0 and warp1 are pipelined. - int asyncWaitCluster = waitAtTail ? 
4 : 0; - // If tt.load and ttg.local_store are in the same stage - // spread them apart to allow overlap with compute - // else - // Initiate ttg.local_store before tt.load - int globalLoadCluster = 1; - int localStoreCluster = 3; - if (!pairedGlobalLoadLocalStore) { - globalLoadCluster = 3; - localStoreCluster = 2; - } - - // If ttg.local_load and ttg.local_store are in the same stage - // spread them apart to allow overlap with compute - // else if they share the buffer - // ttg.local_load must come first - // else - // schedule ttg.local_load in the middle - int localLoadCluster = globalLoadCluster; - if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_LOCAL_STORE]) { - localLoadCluster = std::max(3, localStoreCluster + 1); - } else if (numBuffers == 1 && localLoadCluster >= localStoreCluster) { - // For 1 buffer, ttg.local_load must occur before ttg.local_store - localLoadCluster = localStoreCluster - 1; - } - - // Schedule compute with ttg.local_load if paired - // otherwise, schedule in the middle - int computeCluster = 2; - if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_COMPUTE]) { - computeCluster = localLoadCluster; - } - - // Make assignments - Clusters clusterVec; - std::generate(clusterVec.begin(), clusterVec.end(), - [&]() { return schedule.clusters.newAtBack(); }); - - clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; - clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; - clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; - clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; - clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; - - LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster - << ", LOCAL_STORE cluster = " << localStoreCluster - << ", LOCAL_LOAD cluster = " << localLoadCluster - << ", COMPUTE cluster = " << computeCluster - << ", ASYNC_WAIT cluster = " << asyncWaitCluster - << "; total = " << SCHED_SIZE); - - return success(); -} - -void scheduleAsyncCopy(const AsyncCopyChainOps &asyncOps, tt::LoadOp loadOp, - tt::CoarseSchedule &schedule, const Stages &stages, - const Clusters &clusters) { - auto [copyOp, commitOp, waitOp, maybeLocalLoadOp] = asyncOps; - auto [loadStage, loadCluster] = schedule[loadOp]; - schedule.insert(copyOp, loadStage, loadCluster); - // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the - // later UpdateAsyncWaitCount pass can deduce better waitcnts - schedule.insert(commitOp, loadStage, loadCluster); - // If the LocalLoads are scheduled to a later stage than AsyncCopy we need to - // place the AsyncCopy prefetches after the AsyncWaits which create a barrier - // to ensure all warps are finished reading the shared buffer we will write - // into. This is done by scheduling AsyncWait as the first cluster. 
- // If AsyncCopy and LocalLoads are in the same stage we do not assign a - // schdule so they are placed before the LocalLoads - if (loadStage != stages[SCHED_LOCAL_LOAD]) - schedule.insert(waitOp, stages[SCHED_ASYNC_WAIT], - clusters[SCHED_ASYNC_WAIT]); - - if (maybeLocalLoadOp && stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) { - scheduleLocalLoad(maybeLocalLoadOp, schedule, stages[SCHED_LOCAL_LOAD], - clusters[SCHED_LOCAL_LOAD]); - } -} - -void scheduleStreamCopy(const StreamCopyChainOps &streamOps, - tt::LoadOp oldLoadOp, tt::CoarseSchedule &schedule, - const Stages &stages, const Clusters &clusters) { - auto [newLoadOp, subviewOp, localStoreOp, maybeLocalLoadOp] = streamOps; - auto [loadStage, loadCluster] = schedule[oldLoadOp]; - - schedule.insert(newLoadOp, loadStage, loadCluster); - schedule.insert(subviewOp, stages[SCHED_LOCAL_STORE], - clusters[SCHED_LOCAL_STORE]); - schedule.insert(localStoreOp, stages[SCHED_LOCAL_STORE], - clusters[SCHED_LOCAL_STORE]); - if (maybeLocalLoadOp && stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) { - scheduleLocalLoad(maybeLocalLoadOp, schedule, stages[SCHED_LOCAL_LOAD], - clusters[SCHED_LOCAL_LOAD]); - } -} - LogicalResult scheduleLoads(const LoadToInfoMap &loadToInfo, int maxDist, int numStages, const Stages &stages, const Clusters &clusters, @@ -775,25 +334,30 @@ LogicalResult scheduleLoads(const LoadToInfoMap &loadToInfo, int maxDist, return success(); } -void scheduleStreamOps(const LoadToStreamOpMap &loadToStreamOp, - tt::CoarseSchedule &schedule, const Stages &stages, - const Clusters &clusters) { - for (auto [l, streamOps] : loadToStreamOp) { - auto loadOp = dyn_cast(l); - if (!loadOp) - continue; +void initSymbolicSchedule(int maxDist, Stages &stages, int numStages, + Clusters &clusters, tt::CoarseSchedule &schedule) { + int lastStage = numStages - 1; + stages[SCHED_GLOBAL_LOAD] = 0; + stages[SCHED_LOCAL_STORE] = maxDist; + stages[SCHED_LOCAL_LOAD] = lastStage; + stages[SCHED_COMPUTE] = lastStage; + stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; - if (auto asyncOps = std::get_if(&streamOps)) { - scheduleAsyncCopy(*asyncOps, loadOp, schedule, stages, clusters); - } else if (auto sOps = std::get_if(&streamOps)) { - scheduleStreamCopy(*sOps, loadOp, schedule, stages, clusters); - } - } + Clusters clusterVec; + std::generate(clusterVec.begin(), clusterVec.end(), + [&]() { return schedule.clusters.newAtBack(); }); + + // This is a symbolic cluster assignment. In this stage, we only focus on + // global load and compute ops. + int globalLoadCluster = 0; + int computeCluster = 1; + + clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; + clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; } tt::CoarseSchedule buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo, - bool useAsyncCopy, bool waitAtTail, triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis) { LDBG("Build SingleDotSchedule"); tt::CoarseSchedule schedule(numStages); @@ -814,31 +378,13 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo, } int numBuffers = 1; - if (failed(initSchedule(maxDist, stages, numStages, numBuffers, useAsyncCopy, - waitAtTail, clusters, schedule))) - return {}; + initSymbolicSchedule(maxDist, stages, numStages, clusters, schedule); if (failed(scheduleLoads(loadToInfo, maxDist, numStages, stages, clusters, schedule))) return {}; dumpSchedule("Coarse schedule loads only:"); - // Convert the loads into shared memory allocations and loads from them. 
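Not part of the patch itself: a small plain-Python model of the stage and cluster assignment that the new initSymbolicSchedule performs. The dictionary keys stand in for the SCHED_* indices used by the C++ code; only the global-load and compute clusters are pinned at this symbolic stage.

def symbolic_stages(num_stages: int, max_dist: int) -> dict:
    last_stage = num_stages - 1
    stages = {
        "GLOBAL_LOAD": 0,          # loads are issued in the first stage
        "LOCAL_STORE": max_dist,   # local stores trail by the max load-to-use distance
        "LOCAL_LOAD": last_stage,  # consumers live in the last stage
        "COMPUTE": last_stage,
    }
    stages["ASYNC_WAIT"] = stages["LOCAL_LOAD"]
    # Symbolic clusters: only global loads (cluster 0) and compute (cluster 1)
    # are assigned here; the remaining ops are placed by later passes.
    clusters = {"GLOBAL_LOAD": 0, "COMPUTE": 1}
    return {"stages": stages, "clusters": clusters}

# Example: a 3-stage pipeline where every load is used in the same iteration.
print(symbolic_stages(3, 0))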
- auto loadToStreamOp = createStreamOps(loadToInfo, forOp, numBuffers, - useAsyncCopy, axisInfoAnalysis); - scheduleStreamOps(loadToStreamOp, schedule, stages, clusters); - dumpSchedule("Coarse schedule stream ops:"); - - scheduleDependencies(forOp, schedule); - dumpSchedule("Coarse schedule with dependencies:"); - - triton::gpu::scheduleDistanceOneDependencies(forOp, schedule); - dumpSchedule("Coarse schedule with dist 1:"); - - tt::CoarseSchedule::Cluster computeCluster = clusters[SCHED_COMPUTE]; - triton::gpu::scheduleRemainingToLastStage(forOp, schedule, computeCluster); - dumpSchedule("Final coarse schedule:"); - return schedule; } } // namespace SingleDotSchedule @@ -855,34 +401,6 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo, // pipeliner is meant to be used in combination with pingpong namespace ChainedDotSchedule { using namespace mlir::ChainedDotSchedule; - -LogicalResult checkPreconditions(scf::ForOp forOp, int numStages, - LoadToInfoMap loadToInfo) { - if (numStages != 4) - return failure(); - - auto dotOps = llvm::to_vector(forOp.getBody()->getOps()); - - if (dotOps.size() != 2) - return failure(); - - // Check that the first dot feeds into the second - SetVector slice; - getForwardSlice(dotOps[0]->getResult(0), &slice); - if (!slice.contains(dotOps[1])) { - return failure(); - } - - // Reject loops with indirect loads - // TODO support indirect loads - if (llvm::any_of(loadToInfo, - [](auto it) { return it.second.distToUse != 0; })) { - return failure(); - } - - return success(); -} - // We schedule loads one stage in front of their dots LogicalResult scheduleLoads(std::array dotOps, @@ -968,84 +486,8 @@ LogicalResult scheduleOpsBetweenDots(scf::ForOp forOp, return success(); } -void scheduleAsyncCopy(const AsyncCopyChainOps &asyncOps, tt::LoadOp loadOp, - tt::CoarseSchedule &schedule, - const ChainedDotClusters &clusters) { - auto [loadStage, loadCluster] = schedule[loadOp]; - auto [copyOp, commitOp, waitOp, maybeLocalLoadOp] = asyncOps; - - schedule.insert(copyOp, loadStage, loadCluster); - // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the - // later UpdateAsyncWaitCount pass can deduce better waitcnts - schedule.insert(commitOp, loadStage, loadCluster); - - if (loadStage == STAGE_GLOBAL_LOAD_1) { - schedule.insert(waitOp, STAGE_LOCAL_LOAD_1, clusters[CLUSTER_ASYNC_WAIT_1]); - if (maybeLocalLoadOp) - scheduleLocalLoad(maybeLocalLoadOp, schedule, STAGE_LOCAL_LOAD_1, - clusters[CLUSTER_LOCAL_LOAD_1]); - } else { - schedule.insert(waitOp, STAGE_LOCAL_LOAD_2, clusters[CLUSTER_ASYNC_WAIT_2]); - if (maybeLocalLoadOp) - scheduleLocalLoad(maybeLocalLoadOp, schedule, STAGE_LOCAL_LOAD_2, - clusters[CLUSTER_LOCAL_LOAD_2]); - } -} - -void scheduleStreamCopy(const StreamCopyChainOps &streamOps, tt::LoadOp loadOp, - tt::CoarseSchedule &schedule, - const ChainedDotClusters &clusters) { - auto [loadStage, loadCluster] = schedule[loadOp]; - auto [copyOp, subviewOp, localStoreOp, maybeLocalLoadOp] = streamOps; - schedule.insert(copyOp, loadStage, loadCluster); - - if (loadStage == STAGE_GLOBAL_LOAD_1) { - schedule.insert(subviewOp, STAGE_LOCAL_WRITE_1, - clusters[CLUSTER_LOCAL_WRITE_1]); - schedule.insert(localStoreOp, STAGE_LOCAL_WRITE_1, - clusters[CLUSTER_LOCAL_WRITE_1]); - - if (maybeLocalLoadOp) - schedule.insert(maybeLocalLoadOp, STAGE_LOCAL_LOAD_1, - clusters[CLUSTER_LOCAL_LOAD_1]); - } else { - schedule.insert(subviewOp, STAGE_LOCAL_WRITE_2, - clusters[CLUSTER_LOCAL_WRITE_2]); - schedule.insert(localStoreOp, 
STAGE_LOCAL_WRITE_2, - clusters[CLUSTER_LOCAL_WRITE_2]); - if (maybeLocalLoadOp) - schedule.insert(maybeLocalLoadOp, STAGE_LOCAL_LOAD_2, - clusters[CLUSTER_LOCAL_LOAD_2]); - } - - if (maybeLocalLoadOp) { - if (auto cvt = dyn_cast( - *maybeLocalLoadOp->getUsers().begin())) { - auto [localLoadStage, localLoadCluster] = schedule[maybeLocalLoadOp]; - schedule.insert(cvt, localLoadStage, localLoadCluster); - } - } -} - -void scheduleStreamOps(const LoadToStreamOpMap &loadToStreamOp, - tt::CoarseSchedule &schedule, - const ChainedDotClusters &clusters) { - for (auto [l, streamOps] : loadToStreamOp) { - auto loadOp = dyn_cast(l); - if (!loadOp) - continue; - - if (auto asyncOps = std::get_if(&streamOps)) { - scheduleAsyncCopy(*asyncOps, loadOp, schedule, clusters); - } else if (auto sOps = std::get_if(&streamOps)) { - scheduleStreamCopy(*sOps, loadOp, schedule, clusters); - } - } -} - tt::CoarseSchedule buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo, - bool useAsyncCopy, triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis) { LDBG("Build ChainedDotSchedule"); tt::CoarseSchedule schedule(numStages); @@ -1077,40 +519,36 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo, } dumpSchedule("Coarse schedule after schedule ops between dots:"); - // Convert the loads into shared memory allocations and loads from them. - // TODO support different numBuffers - int numBuffers = useAsyncCopy ? 2 : 1; - auto loadToStreamOps = - createStreamOps(loadToInfo, forOp, /*numBuffers=*/numBuffers, - useAsyncCopy, axisInfoAnalysis); - scheduleStreamOps(loadToStreamOps, schedule, clusters); - dumpSchedule("Coarse schedule stream ops:"); - - for (auto [l, _] : loadToInfo) { - schedule.erase(l); - l->erase(); - } - - scheduleDependencies(forOp, schedule); - dumpSchedule("Coarse schedule with dependencies:"); - - triton::gpu::scheduleDistanceOneDependencies(forOp, schedule); - dumpSchedule("Coarse schedule with dist 1:"); - - tt::CoarseSchedule::Cluster lastCluster = clusters.back(); - triton::gpu::scheduleRemainingToLastStage(forOp, schedule, lastCluster); - dumpSchedule("Final coarse schedule:"); - return schedule; } } // namespace ChainedDotSchedule -void pipelineLoop(scf::ForOp forOp, int numStages, bool useAsyncCopy, - bool waitAtTail) { +void pipelineLoop(scf::ForOp forOp, int numStages) { triton::AMD::ModuleAxisInfoAnalysis axisInfoAnalysis( forOp->getParentOfType()); - LoadToInfoMap loadToInfo = preprocessLoop(axisInfoAnalysis, forOp, numStages); + llvm::MapVector> loadOpToIndLevel = + getIndirectLevel(axisInfoAnalysis, forOp, numStages); + + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevel.size() << " loads to pipeline:"); + for (const auto &[l, i] : loadOpToIndLevel) { + LDBG(" - load: " << *l); + LDBG(" at distance: " << i.first); + LDBG(" used by op: " << *i.second); + } + }); + + LoadToInfoMap loadToInfo; + for (const auto &[load, info] : loadOpToIndLevel) { + auto [distance, use] = info; + auto newLoad = bypassLDS(load, use); + if (newLoad) { + loadToInfo[newLoad] = {nullptr, distance, use}; + } else { + loadToInfo[load] = {nullptr, distance, use}; + } + } if (loadToInfo.empty()) { LDBG("couldn't find any pipeline-able loads:\n" << *forOp); @@ -1119,13 +557,12 @@ void pipelineLoop(scf::ForOp forOp, int numStages, bool useAsyncCopy, tt::CoarseSchedule schedule; - if (succeeded(ChainedDotSchedule::checkPreconditions(forOp, numStages, - loadToInfo))) { - schedule = ChainedDotSchedule::buildSchedule( - forOp, numStages, loadToInfo, useAsyncCopy, 
axisInfoAnalysis); + if (succeeded(mlir::ChainedDotSchedule::checkPreconditions(forOp, numStages, + loadToInfo))) { + schedule = ChainedDotSchedule::buildSchedule(forOp, numStages, loadToInfo, + axisInfoAnalysis); } else { schedule = SingleDotSchedule::buildSchedule(forOp, numStages, loadToInfo, - useAsyncCopy, waitAtTail, axisInfoAnalysis); } @@ -1155,11 +592,8 @@ struct ScheduleLoops : impl::TritonAMDGPUScheduleLoopsBase { LDBG("Loop not safe to pipeline:\n" << *forOp); continue; } - // i.e., we can still disable `waitAtTail` by explicitly disabling - // pingpong, which is the only use case of this scheduling variant. int numStagesThis = tt::getNumStagesOrDefault(forOp, numStages); - bool waitAtTail = usePingpong && (numStagesThis == 3) && useAsyncCopy; - pipelineLoop(forOp, numStagesThis, useAsyncCopy, waitAtTail); + pipelineLoop(forOp, numStagesThis); } } }; diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index ac3ff6f714..674ff8fc8f 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -13,6 +13,7 @@ from triton.backends.compiler import GPUTarget from triton._internal_testing import str_to_triton_dtype from triton._internal_testing import is_hip_gfx1250 +from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor from triton.experimental import gluon import triton.experimental.gluon.language as ttgl @@ -309,45 +310,304 @@ def dot_mxfp_triton_kernel(a_base, stride_am, stride_ak, a_scale, b_base, stride out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] tl.store(out_ptr, c) + def torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K): + a_scale_f32 = a_scale.to(torch.float32).repeat_interleave(scale_block, dim=1)[:M, :K] + b_scale_f32 = b_scale.to(torch.float32).repeat_interleave(scale_block, dim=1).T.contiguous()[:K, :N] + + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + + return torch.matmul(a_f32 * a_scale_f32, b_f32 * b_scale_f32).to(torch.float32) + torch.manual_seed(0) type_a = mxfp_type type_b = mxfp_type - DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 - DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 + a_mxfp4 = MXFP4Tensor(size=(BLOCK_M, BLOCK_K)).random() + b_mxfp4 = MXFP4Tensor(size=(BLOCK_K, BLOCK_N)).random() + + scale_a_size = (BLOCK_M, (BLOCK_K + 32 - 1) // 32) + scale_b_size = (BLOCK_N, (BLOCK_K + 32 - 1) // 32) + + if hasScale: + scale_a_mxfp4 = MXScaleTensor(size=scale_a_size).random(high=32.0) + scale_b_mxfp4 = MXScaleTensor(size=scale_b_size).random(high=32.0) + else: + scale_a_mxfp4 = torch.ones(scale_a_size, dtype=torch.float32) + scale_b_mxfp4 = torch.ones(scale_b_size, dtype=torch.float32) + + c_torch = torch_gemm_mxfp(a_mxfp4, b_mxfp4, scale_a_mxfp4, scale_b_mxfp4, 32, BLOCK_M, BLOCK_N, BLOCK_K) - x = torch.randint(20, 40, (BLOCK_M, BLOCK_K // DIV_FACTOR_A), dtype=torch.uint8).cuda() - y = torch.randint(20, 40, (BLOCK_K // DIV_FACTOR_B, BLOCK_N), dtype=torch.uint8).cuda() + a = a_mxfp4.to_packed_tensor(dim=1).data.contiguous().cuda() + b = b_mxfp4.to_packed_tensor(dim=0).data.contiguous().cuda() if hasScale: - min_scale, max_scale = (0, 142) - scale_x = torch.randint(min_scale, max_scale + 1, (BLOCK_M, BLOCK_K // 32), dtype=torch.uint8).cuda() - scale_y = torch.randint(min_scale, max_scale + 1, (BLOCK_N, BLOCK_K // 32), dtype=torch.uint8).cuda() + scale_a = scale_a_mxfp4.data.cuda() + scale_b = scale_b_mxfp4.data.cuda() else: - scale_x = None - scale_y = None - - def 
make_finite(x, dtype): - if dtype not in ("e5m2", "e4m3"): - return x - mask = 0x7C if dtype == "e5m2" else 0x7F - finite = torch.arange(x.numel(), dtype=torch.uint8).cuda().reshape_as(x) % mask - x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x) - x.copy_(x_finite) - return x - - x = make_finite(x, type_a) - y = make_finite(y, type_b) - - z = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float32).cuda() - pgm = dot_mxfp_gluon_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, BLOCK_M, BLOCK_N, BLOCK_K, + scale_a = None + scale_b = None + + c = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float32).cuda() + pgm = dot_mxfp_gluon_kernel[(1, )](a, *a.stride(), scale_a, b, *b.stride(), scale_b, c, BLOCK_M, BLOCK_N, BLOCK_K, type_a, type_b) assert "v_wmma_scale_f32_16x16x128_f8f6f4" in pgm.asm[ "amdgcn"], "The AMDGCN assembly does not contain the expected scaled WMMA instruction." - z_ref = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float32).cuda() - dot_mxfp_triton_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z_ref, BLOCK_M, BLOCK_N, BLOCK_K, + c_ref = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float32).cuda() + dot_mxfp_triton_kernel[(1, )](a, *a.stride(), scale_a, b, *b.stride(), scale_b, c_ref, BLOCK_M, BLOCK_N, BLOCK_K, type_a, type_b) - torch.testing.assert_close(z.cpu(), z_ref.cpu(), rtol=1e-5, atol=1e-5) + torch.testing.assert_close(c.cpu(), c_ref.cpu(), rtol=1e-5, atol=1e-5) + torch.testing.assert_close(c.cpu(), c_torch, rtol=1e-5, atol=1e-5) + + +@gluon.jit +def tensor_copy_kernel(a_ptr, b_ptr, # + M, N, # + BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr): + SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [BLOCK_M, BLOCK_N], [1, 0]) + BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) + + pid = ttgl.program_id(axis=0) + num_pid_m = ttgl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + a_desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=a_ptr, shape=(M, N), strides=(N, 1), + block_shape=(BLOCK_M, BLOCK_N), layout=SHARED_LAYOUT) + + a_buffer = ttgl.allocate_shared_memory(a_desc.dtype, shape=a_desc.block_shape, layout=a_desc.layout) + ttgl.amd.gfx1250.tdm.async_load(a_desc, [pid_m * BLOCK_M, pid_n * BLOCK_N], a_buffer) + + ttgl.amd.gfx1250.tdm.async_wait(0) + a = a_buffer.load(layout=BLOCKED_LAYOUT) + + b_offsets = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT)))[:, None] * N + \ + (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT)))[None, :] + ttgl.store(b_ptr + b_offsets, a) + + +@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)]) +def test_compile_tensor_copy(BLOCK_M, BLOCK_N): + k = triton.compile( + gluon._runtime.GluonASTSource( + fn=tensor_copy_kernel, signature={ + "a_ptr": "*bf16", "b_ptr": "*bf16", "M": "i32", "N": "i32", "BLOCK_M": "constexpr", "BLOCK_N": + "constexpr" + }, constexprs={"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N}), target=GPUTarget("hip", 'gfx1250', 32)) + + amdgcn = k.asm["amdgcn"] + + tensor_pattern = r"tensor_load_to_lds" + assert re.search(tensor_pattern, amdgcn) + + wait_pattern = r"s_wait_tensorcnt 0x0" + assert re.search(wait_pattern, amdgcn) + + +@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)]) +def test_runtime_tensor_copy(BLOCK_M, BLOCK_N): + M, N = 1024, 1024 + + torch.manual_seed(42) + a = torch.randint(0x0, 0xFFFF, (M, N), dtype=torch.uint16) + b = torch.zeros_like(a) + + a_device 
= a.cuda() + b_device = b.cuda() + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + tensor_copy_kernel[grid](a_device, b_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + + b_triton = b_device.cpu() + assert torch.equal(b_triton, a) + + +@gluon.jit +def mxgemm_kernel(a_ptr, b_ptr, c_ptr, a_scale, b_scale, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, + stride_cn, stride_scale, DTYPE_A: ttgl.constexpr, DTYPE_B: ttgl.constexpr, + SCALE_BLOCK: ttgl.constexpr, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, + BLOCK_K: ttgl.constexpr, GROUP_SIZE_M: ttgl.constexpr): + DIV_FACTOR_A: ttgl.constexpr = 2 if DTYPE_A == "e2m1" else 1 + DIV_FACTOR_B: ttgl.constexpr = 2 if DTYPE_B == "e2m1" else 1 + BLOCK_K_SCALE: ttgl.constexpr = BLOCK_K // SCALE_BLOCK + BLOCK_K_PACKED_A: ttgl.constexpr = BLOCK_K // DIV_FACTOR_A + BLOCK_K_PACKED_B: ttgl.constexpr = BLOCK_K // DIV_FACTOR_B + + BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [8, 4], [4, 1], [1, 0]) + A_BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0]) + B_BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [16, 2], [4, 1], [1, 0]) + + WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, transposed=True, warps_per_cta=[2, 2], + instr_shape=[16, 16, 128]) + WMMA_LAYOUT_PACKED: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, transposed=True, warps_per_cta=[2, 2], + instr_shape=[16, 16, 64]) + A_SCALE_LINEAR_LAYOUT: ttgl.constexpr = ttgl.DistributedLinearLayout( + reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp_bases=[[0, 0], [16, 0]], + block_bases=[], shape=[32, 4]) + B_SCALE_LINEAR_LAYOUT: ttgl.constexpr = ttgl.DistributedLinearLayout( + reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp_bases=[[16, 0], [0, 0]], + block_bases=[], shape=[32, 4]) + + DOT_LAYOUT_A: ttgl.constexpr = ttgl.DotOperandLayout( + operand_index=0, parent=WMMA_LAYOUT_PACKED if DTYPE_A == "e2m1" else WMMA_LAYOUT, k_width=16) + DOT_LAYOUT_B: ttgl.constexpr = ttgl.DotOperandLayout( + operand_index=1, parent=WMMA_LAYOUT_PACKED if DTYPE_B == "e2m1" else WMMA_LAYOUT, k_width=16) + + pid = ttgl.program_id(axis=0) + num_pid_m = ttgl.cdiv(M, BLOCK_M) + num_pid_n = ttgl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, A_BLOCKED_LAYOUT))) % M + offs_ak = ttgl.arange(0, BLOCK_K_PACKED_A, layout=ttgl.SliceLayout(0, A_BLOCKED_LAYOUT)) + offs_bk = ttgl.arange(0, BLOCK_K_PACKED_B, layout=ttgl.SliceLayout(1, B_BLOCKED_LAYOUT)) + offs_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, B_BLOCKED_LAYOUT))) % N + + offs_scale_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % M + offs_scale_ak = ttgl.arange(0, BLOCK_K_SCALE, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT)) + offs_scale_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % N + offs_scale_bk = ttgl.arange(0, BLOCK_K_SCALE, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT)) + + a_scale_ptr = a_scale + offs_scale_am[:, None] * stride_scale + offs_scale_ak[None, :] + b_scale_ptr = b_scale + offs_scale_bn[:, None] * stride_scale + offs_scale_bk[None, :] + a_ptrs = 
a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = ttgl.zeros((BLOCK_M, BLOCK_N), dtype=ttgl.float32, layout=WMMA_LAYOUT) + for k in range(0, ttgl.cdiv(K, BLOCK_K)): + k_remaining_a = K - k * BLOCK_K_PACKED_A + k_remaining_b = K - k * BLOCK_K_PACKED_B + valid_k_a = offs_ak < k_remaining_a + valid_k_b = offs_bk < k_remaining_b + + scale_a = ttgl.load(a_scale_ptr) + scale_b = ttgl.load(b_scale_ptr) + scale_a = ttgl.convert_layout(scale_a, A_SCALE_LINEAR_LAYOUT) + scale_b = ttgl.convert_layout(scale_b, B_SCALE_LINEAR_LAYOUT) + + a = ttgl.load(a_ptrs, mask=valid_k_a[None, :], other=0.0) + b = ttgl.load(b_ptrs, mask=valid_k_b[:, None], other=0.0) + a = ttgl.convert_layout(a, DOT_LAYOUT_A) + b = ttgl.convert_layout(b, DOT_LAYOUT_B) + + accumulator = ttgl.amd.gfx1250.wmma_scaled(a, scale_a, DTYPE_A, b, scale_b, DTYPE_B, accumulator) + + a_ptrs += BLOCK_K_PACKED_A * stride_ak + b_ptrs += BLOCK_K_PACKED_B * stride_bk + + a_scale_ptr += BLOCK_K_SCALE + b_scale_ptr += BLOCK_K_SCALE + + offs_cm = pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, WMMA_LAYOUT)) + offs_cn = pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, WMMA_LAYOUT)) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + ttgl.store(c_ptrs, accumulator, mask=c_mask) + + +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 64), (32, 32, 128)]) +@pytest.mark.parametrize("DTYPE_A", ["float8_e5m2", "float8_e4m3", "float4"]) +@pytest.mark.parametrize("DTYPE_B", ["float8_e5m2", "float8_e4m3", "float4"]) +def test_compile_mxgemm(BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B): + scale_block = 32 + + if BLOCK_K < 128: + pytest.skip("NYI: don't support block shape smaller than instr shape") + + triton_dtype_converter = {'float8_e5m2': "fp8e5", "float8_e4m3": "fp8e4nv", "float4": "u8"} + dot_scaled_dtype_converter = {'float8_e5m2': "e5m2", "float8_e4m3": "e4m3", "float4": "e2m1"} + + k = triton.compile( + gluon._runtime.GluonASTSource( + fn=mxgemm_kernel, signature={ + "a_ptr": f"*{triton_dtype_converter[DTYPE_A]}", "b_ptr": f"*{triton_dtype_converter[DTYPE_B]}", "c_ptr": + "*fp32", "a_scale": "*u8", "b_scale": "*u8", "M": "i32", "N": "i32", "K": "i32", "stride_am": "i32", + "stride_ak": "i32", "stride_bk": "i32", "stride_bn": "i32", "stride_cm": "i32", "stride_cn": "i32", + "stride_scale": "i32", "DTYPE_A": "constexpr", "DTYPE_B": "constexpr", "SCALE_BLOCK": "constexpr", + "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "BLOCK_K": "constexpr", "GROUP_SIZE_M": "constexpr" + }, constexprs={ + "DTYPE_A": dot_scaled_dtype_converter[DTYPE_A], "DTYPE_B": dot_scaled_dtype_converter[DTYPE_B], + "SCALE_BLOCK": scale_block, "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, "GROUP_SIZE_M": + 1 + }), target=GPUTarget("hip", 'gfx1250', 32)) + + amdgcn = k.asm["amdgcn"] + pattern = "v_wmma_scale_f32_16x16x128_f8f6f4" + assert re.search(pattern, amdgcn), f"Can't find instruction {pattern} in AMDGCN assembly" + + +@pytest.mark.parametrize("M, N, K", [(32, 32, 128), (128, 128, 512), (1, 8192, 512)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 128), (64, 64, 128)]) +@pytest.mark.parametrize("DTYPE_A", ["float8_e5m2", "float8_e4m3", "float4"]) +@pytest.mark.parametrize("DTYPE_B", ["float8_e5m2", "float8_e4m3", "float4"]) +def test_runtime_mxgemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, 
DTYPE_A, DTYPE_B): + scale_block = 32 + + torch.manual_seed(0) + + def torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K): + a_scale_f32 = a_scale.to(torch.float32).repeat_interleave(scale_block, dim=1)[:M, :K] + b_scale_f32 = b_scale.to(torch.float32).repeat_interleave(scale_block, dim=1).T.contiguous()[:K, :N] + + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + + return torch.matmul(a_f32 * a_scale_f32, b_f32 * b_scale_f32).to(torch.float32) + + def init_data(dtype, d0: int, d1: int): + if dtype == 'float4': + return MXFP4Tensor(size=(d0, d1)).random() + elif dtype == "float8_e5m2": + return torch.randint(20, 40, (d0, d1), dtype=torch.uint8).view(torch.float8_e5m2) + elif dtype == "float8_e4m3": + return torch.randint(20, 40, (d0, d1), dtype=torch.uint8).view(torch.float8_e4m3fn) + else: + raise NotImplementedError(f"NYI: unsupported dtype: {dtype}") + + a = init_data(DTYPE_A, M, K) + b = init_data(DTYPE_B, K, N) + a_size = (M, (K + scale_block - 1) // scale_block) + b_size = (N, (K + scale_block - 1) // scale_block) + a_scale = MXScaleTensor(size=a_size).random(low=1.0, high=32.0) + b_scale = MXScaleTensor(size=b_size).random(low=1.0, high=32.0) + + c_ref = torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K) + + a_scale = a_scale.data + b_scale = b_scale.data + + # mxfp4 input needs packed along the k dim, i.e., two mxfp4 are packed in one uint8 + if DTYPE_A in ['float4', 'float6_e2m3', 'float6_e3m2']: + a = a.to_packed_tensor(dim=1) + if DTYPE_B in ['float4', 'float6_e2m3', 'float6_e3m2']: + b = b.to_packed_tensor(dim=0) + + c_d = torch.zeros(M, N, dtype=torch.float32).cuda() + a_d = a.data.contiguous().cuda() + b_d = b.data.contiguous().cuda() + a_scale_d = a_scale.cuda() + b_scale_d = b_scale.cuda() + + stride_am, stride_ak = a_d.stride(0), a_d.stride(1) + stride_bk, stride_bn = b_d.stride(0), b_d.stride(1) + stride_cm, stride_cn = c_d.stride(0), c_d.stride(1) + stride_scale = a_scale_d.stride(0) + + numBlocks = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N) + grid = [numBlocks, 1, 1] + group_size_m = 1 + + dtype_converter = {'float8_e5m2': "e5m2", "float8_e4m3": "e4m3", "float4": "e2m1"} + + mxgemm_kernel[grid](a_d, b_d, c_d, a_scale_d, b_scale_d, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, stride_scale, dtype_converter[DTYPE_A], dtype_converter[DTYPE_B], + scale_block, BLOCK_M, BLOCK_N, BLOCK_K, group_size_m, num_warps=4, num_ctas=1) + + torch.testing.assert_close(c_d.cpu(), c_ref.cpu(), rtol=1e-5, atol=1e-8) diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index d963fba4d1..fb3880c571 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -91,11 +91,10 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { ADD_PASS_WRAPPER_0("add_fold_true_cmpi", mlir::createTritonAMDFoldTrueCmpI); ADD_PASS_OPTION_WRAPPER_1("add_block_pingpong", mlir::createTritonAMDGPUBlockPingpong, int32_t); - ADD_PASS_OPTION_WRAPPER_3("add_schedule_loops", - mlir::createTritonAMDGPUScheduleLoops, int, bool, - bool); - ADD_PASS_OPTION_WRAPPER_1("add_pipeline", mlir::createTritonAMDGPUPipeline, - bool); + ADD_PASS_OPTION_WRAPPER_1("add_schedule_loops", + mlir::createTritonAMDGPUScheduleLoops, int); + ADD_PASS_OPTION_WRAPPER_2("add_pipeline", mlir::createTritonAMDGPUPipeline, + bool, bool); ADD_PASS_OPTION_WRAPPER_1("add_coalesce_async_copy", mlir::createTritonAMDGPUCoalesceAsyncCopy, std::string); diff --git a/third_party/nvidia/backend/compiler.py 
b/third_party/nvidia/backend/compiler.py index 38dad8d3c3..b2c1d77a1e 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -170,6 +170,10 @@ def __init__(self, target: GPUTarget) -> None: self.binary_ext = "cubin" def parse_options(self, opts) -> Any: + # Enable debug mode for ConSan, so device-side assertions are not optimized out + if "instrumentation_mode" in opts and opts["instrumentation_mode"] == "consan": + opts["debug"] = True + args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"} args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None}) capability = int(self._parse_arch(args["arch"])) @@ -353,7 +357,7 @@ def make_llir(self, src, metadata, options, capability): passes.gluon.add_inliner(pm) nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version) nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) - if knobs.compilation.enable_experimental_consan: + if knobs.compilation.instrumentation_mode == "consan": # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared passes.ttgpuir.add_concurrency_sanitizer(pm) passes.ttgpuir.add_allocate_global_scratch_memory(pm) diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp index d0e48e97e9..2d01d55858 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp @@ -240,12 +240,8 @@ template struct AssignStagePhase { auto createInto = [&](auto opTy, auto... args) { using ty = decltype(opTy); - auto ids = partitionIds; - if (ids) { - ids->insert(0); - } return triton::gpu::createInto( - builder, builder.getLoc(), ids, stageCluster, + builder, builder.getLoc(), partitionIds, stageCluster, std::forward(args)...); }; @@ -356,7 +352,46 @@ template struct AssignStagePhase { } }; -static LogicalResult assignStagePhase(triton::FuncOp funcOp) { +void visitBackwardSlice(scf::ForOp wsLoop, Value value, + std::function callback, + DenseSet &visited) { + if (!visited.insert(value).second) + return; + + if (auto blockArg = dyn_cast(value)) { + if (auto forOp = dyn_cast(blockArg.getOwner()->getParentOp())) { + if (forOp->hasAttr(kWarpSpecializeAttrName)) + return; + auto pos = findValuePosInRange(forOp.getRegionIterArgs(), value); + assert(pos); + visitBackwardSlice(wsLoop, forOp.getInitArgs()[*pos], callback, visited); + } + } else if (auto defOp = value.getDefiningOp(); + isa(defOp)) { + auto pos = findValuePosInRange(defOp->getResults(), value); + assert(pos); + if (auto ifOp = dyn_cast(defOp)) { + visitBackwardSlice(wsLoop, ifOp.thenYield()->getOperand(*pos), callback, + visited); + if (ifOp.elseBlock()) + visitBackwardSlice(wsLoop, ifOp.elseYield()->getOperand(*pos), callback, + visited); + visitBackwardSlice(wsLoop, ifOp.getCondition(), callback, visited); + } else { + auto forOp = cast(defOp); + visitBackwardSlice(wsLoop, + forOp.getBody()->getTerminator()->getOperand(*pos), + callback, visited); + } + } else if (wsLoop.getBody()->findAncestorOpInBlock(*defOp)) { + callback(defOp); + for (auto operand : defOp->getOperands()) { + visitBackwardSlice(wsLoop, operand, callback, visited); + } + } +} + +LogicalResult assignStagePhase(triton::FuncOp funcOp) { SmallVector arefOps; funcOp.walk([&](ArefCreateOp arefOp) { arefOps.push_back(arefOp); }); for (auto arefOp : 
arefOps) { @@ -365,6 +400,31 @@ static LogicalResult assignStagePhase(triton::FuncOp funcOp) { if (failed(AssignStagePhase::run(arefOp))) return failure(); } + + auto callback = [&](Operation *op) { + if (!isa(op)) { + auto partitionIds = getPartitionIds(op); + assert(partitionIds); + partitionIds->insert(0); + setPartition(op, *partitionIds); + } + }; + + funcOp.walk([&](scf::ForOp forOp) { + DenseSet visited; + if (forOp->hasAttr(kWarpSpecializeAttrName)) { + for (auto result : forOp.getResults()) { + // if result is of scalar type and is used outside of for-op, visit + // all dependencies and assign default partition to them + if (isa(result.getType()) && + !result.use_empty()) { + auto arg = forOp.getBody()->getTerminator()->getOperand( + result.getResultNumber()); + visitBackwardSlice(forOp, arg, callback, visited); + } + } + } + }); return success(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h index 7e0dd3af0d..e8d08e6dae 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h @@ -1,5 +1,6 @@ #include "Utility.h" #include "mlir/Support/LLVM.h" +#include "triton/Tools/LayoutUtils.h" namespace mlir { namespace triton { @@ -7,7 +8,7 @@ namespace NVIDIA { // The descriptor format is described in the spec: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor -// Unnamed fieids are not used +// Unnamed fields are not used union SMEMDescriptor { uint64_t descriptor; struct { @@ -23,6 +24,14 @@ union SMEMDescriptor { }; }; +struct MMASMEMDescriptor { + SMEMDescriptor descriptor; + int32_t swizzlingByteWidth; + int32_t bitwidth; + bool transposed; + bool fp4Padded; +}; + struct MemDescOperand { Value base; std::optional offset; @@ -37,56 +46,299 @@ class DotOpMmaMemLoader { Location loc) const = 0; }; -// Helper class to load shared memory slices following MMAv3 layout. 
-class DotOpMmaV3SmemLoader : public DotOpMmaMemLoader {
 public:
- DotOpMmaV3SmemLoader() {}
- DotOpMmaV3SmemLoader(Value tensor, Value base, SmallVector shape,
- ArrayRef allocSwizzleShape, Value warpId,
- unsigned int dimWpt, bool trans,
- SmallVector instrShape,
- int64_t elementBitwidth,
- ConversionPatternRewriter &rewriter, Location loc);
- // Return a descriptor pointing to the shared memory slice at coordinates (a,
- // b)
- Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter,
- Location loc) const;
+ DotOpMmaSmemLoader() = default;
+
+ DotOpMmaSmemLoader(MMASMEMDescriptor desc, Value baseb128, LinearLayout llInv,
+ ArrayRef instrShape)
+ : desc(desc), baseb128(baseb128), ll(std::move(llInv)),
+ instrShape(instrShape) {}
+
+ static DotOpMmaSmemLoader
+ build(Location loc, RewriterBase &rewriter, gpu::MemDescType memTy,
+ Value smemBase, ArrayRef instrShape, int mmaVersion,
+ bool isFp4 = false,
+ std::optional mmaTy = std::nullopt,
+ std::optional MNdim = std::nullopt) {
+ auto ctx = rewriter.getContext();
+ auto kOffset = str_attr("offset");
+ // The handling of subviews is not as fine as it could be
+ // We could compose with the identity of the memTy.getShape()
+ // (at the moment llInv will be of allocShape), but then
+ // we would need to handle the getReps part more carefully
+ // That way we could support more subview cases than we do today
+ // We can implement this generalisation in the future if needed
+ auto llInv = toLinearLayout(memTy).pseudoinvert();
+ auto bitwidth = memTy.getElementType().getIntOrFloatBitWidth();
+ if (isFp4) {
+ // Hacky, but it does the job for now
+ auto dims = to_vector(llInv.getInDimNames());
+ auto trans = llInv.getBasis(dims[0], 0, kOffset) == 1;
+ llInv = LinearLayout::identity1D(2, dims[trans ?
0 : 1], kOffset) * llInv; + bitwidth /= 2; + // The instr_shape comes in number of elements already + } + return build(loc, rewriter, llInv, bitwidth, smemBase, instrShape, + mmaVersion, mmaTy, MNdim); + } + + static DotOpMmaSmemLoader + build(Location loc, RewriterBase &rewriter, const LinearLayout &ll, + int bitwidth, Value smemBase, ArrayRef instrShapeArray, + int mmaVersion, std::optional mmaTy = std::nullopt, + std::optional MNdim = std::nullopt) { + // ll is a map from two dimensions (dim0, dim1) or (row, col) into offsets + // and blocks + auto ctx = rewriter.getContext(); + auto kOffset = str_attr("offset"); + auto kBlock = str_attr("block"); + assert(ll.getNumOutDims() == 2); + assert(ll.hasOutDim(kOffset) && ll.hasOutDim(kBlock)); + + assert(mmaVersion == 3 || mmaVersion == 5); + // Just needed for MMAv3 + assert(mmaTy.has_value() == (mmaVersion == 3)); + assert(MNdim.has_value() == (mmaVersion == 3)); + if (mmaVersion == 3) { + assert(MNdim.value() < 2); + } + auto instrShape = to_vector(instrShapeArray); + assert(instrShape.size() == 2); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // Due to having a 16B alignment, we can compute the offsets in 128b + // elements + // TODO We should assert in the verifier that the alignment is at least 16B + smemBase = b.ptrtoint(i32_ty, smemBase); + Value baseSrcb128 = b.lshr(smemBase, b.i32_val(4)); + if (mmaVersion == 3) { + auto mndim = MNdim.value(); + auto mmaLl = gpu::toLinearLayout(mmaTy.value()); + auto outDims = to_vector(mmaLl.getOutDimNames()); + auto kWarp = str_attr("warp"); + // Map from warps into the MN dimension + auto mmaWarps = mmaLl.sublayout({kWarp}, {outDims[mndim]}) * + LinearLayout::identity1D(1, kWarp, outDims[1 - mndim]); + // Map from warps to offsets in bitwidth elements + auto warpToOffset = mmaWarps.compose(ll); + // Map from warps to offsets in 128b elements + auto maybeWarpToOffsetb128 = + divideLeft(warpToOffset, + LinearLayout::zeros1D(1, kWarp, kOffset, 128 / bitwidth)); + assert(maybeWarpToOffsetb128.has_value()); + // zero out the first two warp bases to have a warpgroup to offset map + auto bases = maybeWarpToOffsetb128->getBases(); + assert(maybeWarpToOffsetb128->getNumOutDims() == 2); + bases[kWarp][0] = {0, 0}; + bases[kWarp][1] = {0, 0}; + auto warpGroupToOffsetb128 = LinearLayout( + bases, warpToOffset.getOutDims(), /*requireSurjective=*/false); + Value warpId = rewriter.create(loc); + Value warpStrideb128 = + applyLinearLayout(loc, rewriter, warpGroupToOffsetb128, + {{kWarp, warpId}})[0] + .second; + baseSrcb128 = b.add(baseSrcb128, warpStrideb128); + // Increase the instruction shape to describe the size at a block level + // as the input just describes it at a warp level + int logwgAlongMN = 0; + for (int i = 0; i < warpGroupToOffsetb128.getInDimSizeLog2(kWarp); i++) { + if (warpGroupToOffsetb128.getBasis(kWarp, i, kOffset) != 0) { + logwgAlongMN++; + } + } + instrShape[mndim] *= (1 << logwgAlongMN); + } + + for (auto [dim, instrSize] : llvm::zip(ll.getInDimNames(), instrShape)) { + assert(instrSize <= ll.getInDimSize(dim) && + "Instruction shape is too large for the layout"); + } + + auto desc = getDescriptor(ll, instrShape, bitwidth, mmaVersion); + + Value baseb128 = b.zext(i64_ty, b.and_(baseSrcb128, b.i32_val(0x3FFF))); + return {desc, baseb128, ll, instrShape}; + } + + Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc) const { + auto *ctx = loc.getContext(); + auto tb = TritonLLVMOpBuilder(loc, rewriter); + auto dims = to_vector(ll.getInDimNames()); + 
assert((a + 1) * instrShape[0] <= ll.getInDimSize(dims[0])); + assert((b + 1) * instrShape[1] <= ll.getInDimSize(dims[1])); + assert(to_vector(ll.getOutDimNames()) == + llvm::to_vector( + ArrayRef{str_attr("offset"), str_attr("block")})); + int32_t totalOffElems = ll.apply({{dims[0], a * instrShape[0]}, + {dims[1], b * instrShape[1]}})[0] + .second; + int32_t smemByteOffsetb8 = totalOffElems * desc.bitwidth / 8; + auto currDesc = desc.descriptor; + // Take the next 0/1/2/3 bits after the 128b tile + uint32_t mask = (desc.swizzlingByteWidth >> 4) - 1; + currDesc.matrixBaseOffset = (smemByteOffsetb8 / 128) & mask; + int32_t smemByteOffsetb128 = smemByteOffsetb8 >> 4; + Value descValBase = + tb.int_val(64, currDesc.descriptor + smemByteOffsetb128); + // Add the base address to the descriptor + Value descVal = tb.add(descValBase, baseb128); + return descVal; + } MemDescOperand memLoad(int a, int b, ConversionPatternRewriter &rewriter, Location loc) const override { return {smemLoad(a, b, rewriter, loc), std::nullopt}; } + MMASMEMDescriptor &getDescriptor() { return desc; } + private: - Value base; - SmallVector shape; - SmallVector allocSwizzleShape; - Value warpId; - int dimWpt; - bool trans; - int fastMovingDim; - Value elemsPerSwizzlingRowVal; - SmallVector instrShape; - int elemsPerSwizzlingRow; - int64_t elemBits; - Value descriptor; -}; + MMASMEMDescriptor desc; + Value baseb128; + LinearLayout ll; + SmallVector instrShape; -// Helper class to load shared memory slices following MMAv5 layout. -class DotOpMmaV5SmemLoader : public DotOpMmaV3SmemLoader { -public: - using DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader; + static MMASMEMDescriptor getDescriptor(const LinearLayout &ll, + ArrayRef instrShape, + int bitwidth, int mmaVersion) { + // ll is a map from allocShape into offsets and blocks + auto dims = to_vector(ll.getInDimNames()); + auto ctx = dims[0].getContext(); + auto kOffset = str_attr("offset"); - // Return a descriptor pointing to the shared memory slice at coordinates (a, - // b), with bit 46 set. - Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter, - Location loc) const { - auto tb = TritonLLVMOpBuilder(loc, rewriter); - Value desc = DotOpMmaV3SmemLoader::smemLoad(a, b, rewriter, loc); - // Set bit 46 as per - // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor - Value mask = tb.int_val(64, 1ULL << 46); - return tb.or_(desc, mask, /*disjoint*/ true); + // Any CTALayout, it's not really used within getCoreMatrixLinearLayout + auto CTALayout = triton::gpu::CTALayoutAttr::getDefault(ctx, 2); + + for (bool fp4Padded : (bitwidth == 4 ? 
SmallVector({false, true})
+ : SmallVector({false}))) {
+ for (auto transposed : {false, true}) {
+ for (int swizzling : {0, 32, 64, 128}) {
+ // FIXME: getCoreMatrixLinearLayout does not accept bitwidth < 8
+ auto shmemEnc = triton::gpu::NVMMASharedEncodingAttr::get(
+ ctx, swizzling, transposed, std::max(8, bitwidth), fp4Padded,
+ CTALayout);
+ auto shmemTile =
+ getCoreMatrixLinearLayout(shmemEnc, /*disableSwizzle=*/false);
+ // Rename out dims to match the original layout (in case the dims were
+ // (row, col))
+ auto outDims = to_vector(shmemTile.getOutDims());
+ outDims[0].first = dims[0];
+ outDims[1].first = dims[1];
+ shmemTile = LinearLayout(shmemTile.getBases(), outDims,
+ /*requireSurjective=*/false);
+ // unpack the fp4 layout
+ if (bitwidth == 4) {
+ shmemTile =
+ LinearLayout::identity1D(2, kOffset, dims[1]) * shmemTile;
+ }
+
+ // getCoreMatrixLinearLayout gives the k-contiguous tile
+ // shmemTile is a layout onto a matrix with shape
+ // If swizzling != 0: 8 x (8 * swizzling / bitwidth)
+ // If swizzling == 0: 8 x (8 * 16 / bitwidth)
+ assert(shmemTile.getOutDimSize(dims[0]) == 8);
+ // Multiply by 2 if fp4Padded, as the padded core matrix has
+ // half the number of elements
+ assert(shmemTile.getOutDimSize(dims[1]) * (fp4Padded ? 2 : 1) ==
+ 8 * std::max(16, swizzling) / bitwidth);
+
+ if (transposed) {
+ shmemTile = transposeLinearLayout(shmemTile, {1, 0});
+ }
+ // Pseudoinvert as fp4 may have padding
+ auto shmemTileInv = shmemTile.pseudoinvert();
+
+ // The PTX docs are wrong in a number of ways:
+ // 1) LBO can be specified for !transposed && swizzled != 0
+ // PTX says it's assumed to be 1, but we can in fact use it
+ // 2) LBO / SBO are swapped also for !transposed && swizzled == 0
+ // PTX just reports this for the transposed case
+ // Moreover, the computation we do here is conceptually correct
+ // and it agrees with the tensor descriptors for wgmma or
+ // tcgen05.mma but not for tcgen05.cp! tcgen05.cp follows the PTX
+ // docs!
+ int lbo = 0, sbo = 0;
+ int leadingDim = transposed ? 0 : 1;
+ int stridedDim = transposed ? 1 : 0;
+ // The lbo / sbo is defined wrt.
the 128 tile, so this makes their + // definition change for swizzling == 0 lol + if (swizzling == 0) { + std::swap(leadingDim, stridedDim); + } + auto log2RowsTile = shmemTileInv.getInDimSizeLog2(dims[leadingDim]); + if (llvm::Log2_32(instrShape[leadingDim]) > log2RowsTile) { + lbo = ll.getBasis(dims[leadingDim], log2RowsTile, kOffset); + } + + auto log2ColsTile = shmemTileInv.getInDimSizeLog2(dims[stridedDim]); + if (llvm::Log2_32(instrShape[stridedDim]) > log2ColsTile) { + sbo = ll.getBasis(dims[stridedDim], log2ColsTile, kOffset); + } + + // Pad the tile up to the full instruction shape with the relevant + // stride if the instruction shape is larger than the tile + auto bases = shmemTileInv.getBases(); + for (int d : {0, 1}) { + // 'tile' with the atom tile according to the lbo/sbo rules + for (int i = 1; + i < instrShape[d] / shmemTileInv.getInDimSize(dims[d]); + i *= 2) { + auto stride = ll.getBasis( + dims[d], shmemTileInv.getInDimSizeLog2(dims[d]), kOffset); + bases[dims[d]].push_back({stride * i}); + } + } + auto maxBasis = 0; + for (auto dimBases : llvm::make_second_range(bases)) { + for (auto basis : dimBases) { + maxBasis = std::max(maxBasis, basis[0]); + } + } + // Multiply by 2 or round up to the next power of 2 + shmemTileInv = + LinearLayout(bases, {{kOffset, llvm::NextPowerOf2(maxBasis)}}, + /*requireSurjective=*/false); + // Add a trivial block dimension as getReps expects both layouts to + // have the same outdims + shmemTileInv *= + LinearLayout::identity1D(1, dims[0], str_attr("block")); + + auto reps = getReps(ll, shmemTileInv); + if (reps.has_value()) { + SMEMDescriptor desc; + desc.descriptor = mmaVersion == 5 ? 1ULL << 46 : 0ULL; + // The lbo / sbo is defined wrt. the 128b elements + desc.leadDimensionBaseOffset = (lbo * bitwidth / 8) >> 4; + desc.strideDimensionBaseOffset = (sbo * bitwidth / 8) >> 4; + switch (swizzling) { + case 0: + desc.swizzlingMode = 0; + break; + case 32: + desc.swizzlingMode = 3; + break; + case 64: + desc.swizzlingMode = 2; + break; + case 128: + desc.swizzlingMode = 1; + break; + default: + llvm_unreachable("Unsupported swizzling size."); + } + return {/* .descriptor = */ desc, + /* .swizzlingByteWidth = */ swizzling, + /* .bitwidth = */ bitwidth, + /* .transposed = */ transposed, + /* .fp4Padded = */ fp4Padded}; + } + } + } + } + llvm::report_fatal_error("Failed to find a valid layout"); } }; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp index 3a5d932fd2..fa95db4ed4 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -453,16 +453,24 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter, aLoader = std::make_unique(a, baseA, aOperandShape, interleaved, transA); } else { - auto allocShapeA = getAllocShape(aTensorTy, 1); - aLoader = std::make_unique( - a, baseA, shapeA, allocShapeA, zero, 1, transA, aOperandShape, - op.numBitsPerElementA, rewriter, loc); + auto isFp4a = op.numBitsPerElementA == 4; + aLoader = std::make_unique(DotOpMmaSmemLoader::build( + loc, rewriter, aTensorTy, baseA, aOperandShape, 5, isFp4a)); } + auto isFp4b = op.numBitsPerElementB == 4; auto allocShapeB = getAllocShape(bTensorTy, 0); - DotOpMmaV5SmemLoader bLoader = DotOpMmaV5SmemLoader( - b, baseB, shapeB, allocShapeB, zero, 1, transB, {mmaSizeN, mmaSizeK}, - op.numBitsPerElementB, rewriter, loc); + // [Instr shape twoCTAs] + // This division by 2 in 2CTA 
mode is a bit subtle:
+ // The issue here is that in 2CTA you multiply in one instruction a tensor
+ // of shape MNK = 256, K, N, and you put it into TMEM of shape 128, K, N*2.
+ // So to compute the shapePerCTA, on the lhs we can look at the TMEM shape,
+ // but to compute the shapePerCTA on the rhs, we need to divide by 2.
+ // Something similar happens when we multiply mmaSizeM by 2 when creating
+ // the instruction descriptor in convertDot.
+ // It's a massive code smell, though.
+ DotOpMmaSmemLoader bLoader = DotOpMmaSmemLoader::build(
+ loc, rewriter, bTensorTy, baseB, {mmaSizeK, mmaSizeN / (twoCTAs ? 2 : 1)},
+ 5, isFp4b);
 DotConversion::InstDesc desc{mmaSizeM, mmaSizeN, {numRepM, numRepN, numRepK}, transA, transB, interleaved,
@@ -473,7 +481,7 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
 MemDescOperand accAddress = op.getAccAddress(rewriter, loc, m, n, desc);
 for (int k = 0; k < numRepK; k++) {
 MemDescOperand a = aLoader->memLoad(m, k, rewriter, loc);
- Value b = bLoader.smemLoad(n, k, rewriter, loc);
+ Value b = bLoader.smemLoad(k, n, rewriter, loc);
 op.createMMAInst(rewriter, loc, accAddress, a, b, elect, useInitAcc, desc, m, n, k);
 useInitAcc = tb.i1_val(1);
@@ -524,6 +532,8 @@ void convertDot(const LLVMTypeConverter &typeConverter,
 Value pred, Value useInitAcc, const DotConversion::InstDesc &desc, int m, int n, int k) {
+ // To understand this multiplication by 2, see the comment
+ // [Instr shape twoCTAs]
 Value instDescriptor = createInstDescriptor(
 rewriter, op, twoCTAs ? desc.mmaSizeM * 2 : desc.mmaSizeM, desc.mmaSizeN, desc.transA, desc.transB);
@@ -596,10 +606,8 @@ void convertScaledDot(const LLVMTypeConverter &typeConverter,
 dot.shapeB[0] *= 2;
 }
- dot.numBitsPerElementA = opKindIsMXFP4 ? getFormatBitSize(op.getAType())
- : aTensorTy.getElementTypeBitWidth();
- dot.numBitsPerElementB = opKindIsMXFP4 ? getFormatBitSize(op.getBType())
- : bTensorTy.getElementTypeBitWidth();
+ dot.numBitsPerElementA = getFormatBitSize(op.getAType());
+ dot.numBitsPerElementB = getFormatBitSize(op.getBType());
 TritonLLVMOpBuilder tb(loc, rewriter);
 Value baseD = tb.ptrtoint(i32_ty, adaptor.getD());
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index eafc404375..e6f2e0da7b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -68,180 +68,6 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
 }
 }
-int64_t getSwizzlingFromLayout(const NVMMASharedEncodingAttr &layout,
- uint32_t widthInByte) {
- uint32_t swizzlingByteWidth = layout.getSwizzlingByteWidth();
- // TODO[biaow]: remove it once we support swizzling size larger than matrix
- // width, which requires padding the matrix width to the swizzling size when
- assert(swizzlingByteWidth <= widthInByte && - "swizzling size larger than matrix width is not supported."); - return swizzlingByteWidth; -} - -static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, - int64_t swizzling, uint32_t stride) { - auto b = TritonLLVMOpBuilder(loc, rewriter); - static_assert(sizeof(SMEMDescriptor) == 8, - "Descriptor size should be 64 bits."); - SMEMDescriptor desc; - desc.descriptor = 0; - switch (swizzling) { - case 0: - desc.swizzlingMode = 0; - break; - case 32: - desc.swizzlingMode = 3; - break; - case 64: - desc.swizzlingMode = 2; - break; - case 128: - desc.swizzlingMode = 1; - break; - default: - llvm::report_fatal_error("Unsupported swizzling size."); - } - if (swizzling == 0) { - // Because the descriptor normalizes spacing to 128-bit units, the - // normalized per-element stride is 16 bytes and LBO is defined as 8×that, - // i.e. 128 bytes. - desc.leadDimensionBaseOffset = 128 >> 4; - // Offset from first row to second row 16x16 bytes. - desc.strideDimensionBaseOffset = 256 >> 4; - } else { - desc.leadDimensionBaseOffset = (swizzling * stride) >> 4; - desc.strideDimensionBaseOffset = swizzling >> 1; - } - return b.int_val(64, desc.descriptor); -} - -mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader( - Value tensor, Value base, SmallVector shape, - ArrayRef allocSwizzleShape, Value warpId, unsigned int dimWpt, - bool trans, SmallVector instrShape, int64_t elementBitwidth, - ConversionPatternRewriter &rewriter, Location loc) - : base(base), shape(shape), allocSwizzleShape(allocSwizzleShape), - warpId(warpId), dimWpt(dimWpt), trans(trans), instrShape(instrShape), - elemBits(elementBitwidth) { - auto b = TritonLLVMOpBuilder(loc, rewriter); - auto ty = cast(tensor.getType()); - auto sharedLayout = cast(ty.getEncoding()); - fastMovingDim = sharedLayout.getTransposed() ? 
0 : 1; - const int swizzlingByteWidth = sharedLayout.getSwizzlingByteWidth(); - elemsPerSwizzlingRow = (swizzlingByteWidth * 8) / elemBits; - elemsPerSwizzlingRowVal = b.i32_val(elemsPerSwizzlingRow); - - uint32_t widthInByte = allocSwizzleShape[fastMovingDim] * elemBits / 8; - int64_t swizzling = getSwizzlingFromLayout(sharedLayout, widthInByte); - - descriptor = createDescriptor(rewriter, loc, swizzling, - allocSwizzleShape[1 - fastMovingDim]); -} - -Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad( - int a, int b, ConversionPatternRewriter &rewriter, Location loc) const { - auto tb = TritonLLVMOpBuilder(loc, rewriter); - Value k = tb.i32_val(b * instrShape[1]); - Value m = tb.add(tb.i32_val(a * dimWpt * instrShape[0]), - tb.mul(warpId, tb.i32_val(instrShape[0]))); - if (trans) { - std::swap(k, m); - } - Value off1; - if (elemsPerSwizzlingRow > 0) { - Value leading_offset = - tb.mul(tb.udiv(k, elemsPerSwizzlingRowVal), - tb.i32_val(shape[1 - fastMovingDim] * elemsPerSwizzlingRow)); - Value stride_offset = tb.mul(m, elemsPerSwizzlingRowVal); - Value offset = tb.add(tb.add(leading_offset, stride_offset), - tb.urem(k, elemsPerSwizzlingRowVal)); - // Avoid the runtime udiv if we know the elements are byte multiples - if (elemBits % 8) { - off1 = tb.udiv(tb.mul(tb.i32_val(elemBits), offset), tb.i32_val(8)); - } else { - off1 = tb.mul(tb.i32_val(elemBits / 8), offset); - } - } else { - assert(a == 0 && instrShape[0] * elemBits == 16 * 8 && - "Currently expect that unswizzled case only happens for " - "rhs cases and the inner dimension is 16bytes."); - off1 = tb.i32_val(512 * b); - } - Value smemBase = tb.ptrtoint(i32_ty, base); - smemBase = tb.add(smemBase, off1); - smemBase = tb.lshr(tb.and_(smemBase, tb.i32_val(0x3FFFF)), tb.i32_val(4)); - Value loadDesc = tb.add(descriptor, tb.zext(i64_ty, smemBase)); - return loadDesc; -} - -DotOpMmaV3SmemLoader loadA(const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, Location loc, - const NvidiaMmaEncodingAttr &mmaEncoding, - Value tensor, Value smemObjBase, Value thread) { - auto b = TritonLLVMOpBuilder(loc, rewriter); - auto aTy = cast(tensor.getType()); - auto aSharedLayout = dyn_cast(aTy.getEncoding()); - assert(aSharedLayout && "only support load dot operand from shared."); - auto instrShape = mmaEncoding.getInstrShape(); - auto wpt = mmaEncoding.getWarpsPerCTA(); - bool transA = aSharedLayout.getTransposed(); - auto shapePerCTA = getShapePerCTA(aTy); - auto allocSwizzleShape = aTy.getAllocShape().take_back(shapePerCTA.size()); - - // The descriptor should be calculated based on the first warp of the - // warpgroup. 
- Value warp = - b.and_(rewriter.create(loc), b.i32_val(0xFFFFFFFC)); - Value warpM = b.urem(warp, b.i32_val(wpt[0])); - Value warpId = b.urem(warpM, b.i32_val(shapePerCTA[0] / instrShape[0])); - - return {tensor, - smemObjBase, - shapePerCTA, - allocSwizzleShape, - warpId, - wpt[0], - transA, - {instrShape[0], instrShape[2]}, - aTy.getElementTypeBitWidth(), - rewriter, - loc}; -} - -DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, Location loc, - NvidiaMmaEncodingAttr &mmaEncoding, Value tensor, - Value base, Value thread) { - auto b = TritonLLVMOpBuilder(loc, rewriter); - auto bTy = cast(tensor.getType()); - auto bSharedLayout = cast(bTy.getEncoding()); - assert(bSharedLayout && "only support load B from shared."); - auto instrShape = mmaEncoding.getInstrShape(); - auto wpt = mmaEncoding.getWarpsPerCTA(); - bool transB = !bSharedLayout.getTransposed(); - auto shapePerCTA = triton::gpu::getShapePerCTA(bTy); - auto allocSwizzleShape = bTy.getAllocShape().take_back(shapePerCTA.size()); - - Value warp = - b.and_(rewriter.create(loc), b.i32_val(0xFFFFFFFC)); - Value warpMN = b.udiv(warp, b.i32_val(wpt[0])); - Value warpN = b.urem(warpMN, b.i32_val(wpt[1])); - Value warpId = b.urem(warpN, b.i32_val(shapePerCTA[1] / instrShape[1])); - - return {tensor, - base, - shapePerCTA, - allocSwizzleShape, - warpId, - wpt[1], - transB, - {instrShape[1], instrShape[2]}, - bTy.getElementTypeBitWidth(), - rewriter, - loc}; -} - // Return a vector of Value of the accumulator start at startIndex and pack the // values into 32bits in case the accumulator is fp16. // @@ -385,9 +211,9 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto dShapePerCTA = getShapePerCTA(dTensorTy); auto instrShape = mmaEncoding.getInstrShape(); auto accSize = 2 * (instrShape[1] / 4); - int M = 4 * instrShape[0]; - int N = instrShape[1]; - int K = instrShape[2]; + unsigned M = 4 * instrShape[0]; + unsigned N = instrShape[1]; + unsigned K = instrShape[2]; bool zeroAcc = isZeroConst(c); auto instrMNK = mmaEncoding.getInstrShape(); auto warpSize = mmaEncoding.getWarpsPerCTA(); @@ -396,16 +222,18 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); int numRepK = ceil(aTensorTy.getShape()[1], instrShape[2]); - DotOpMmaV3SmemLoader aLoader; + DotOpMmaSmemLoader aLoader; SmallVector structA; + auto warpGroups = {warpSize[0] / 4, warpSize[1]}; if (aSharedLayout) { aLoader = - loadA(typeConverter, rewriter, loc, mmaEncoding, a, baseA, thread); + DotOpMmaSmemLoader::build(loc, rewriter, cast(aTensorTy), + baseA, {M, K}, 3, false, dTensorTy, 0); } else { structA = unpackLLElements(loc, loadedA, rewriter); } - DotOpMmaV3SmemLoader bLoader = - loadB(typeConverter, rewriter, loc, mmaEncoding, b, baseB, thread); + DotOpMmaSmemLoader bLoader = DotOpMmaSmemLoader::build( + loc, rewriter, bTensorTy, baseB, {K, N}, 3, false, dTensorTy, 1); auto fc = unpackLLElements(loc, loadedC, rewriter); @@ -460,7 +288,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, SmallVector(regA.size(), regA[0].getType())); a = packLLElements(loc, typeConverter, regA, rewriter, regATy); } - auto b = bLoader.smemLoad(n, k, rewriter, loc); + auto b = bLoader.smemLoad(k, n, rewriter, loc); numLowPrecisionAcc += K; // If using native accumulation would cause use to do more low precion // accumulation than allowed do a separate allocation. 
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp index 6cbd036d22..3485f153d3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp @@ -75,10 +75,11 @@ TMemCopyAtom getTMemCopyAtom(const LinearLayout &cvt, int bitwidth) { auto S = [&](StringRef str) { return StringAttr::get(ctx, str); }; auto kRow = S("row"); auto kCol = S("col"); + auto kOffset = S("offset"); assert(cvt.getInDimSize(kRow) == 128); auto multicastBit = [&](int i) { assert(i == 0 || i == 1); - return cvt.getBasis(kRow, llvm::Log2_32(32) + i) == ArrayRef{0}; + return cvt.getBasis(kRow, llvm::Log2_32(32) + i, kOffset) == 0; }; auto multicast = multicastBit(0) | multicastBit(1) << 1; if (multicast == 0) { @@ -100,121 +101,6 @@ TMemCopyAtom getTMemCopyAtom(const LinearLayout &cvt, int bitwidth) { } } -std::optional getReps(const LinearLayout &cvt, - const LinearLayout &tile) { - // Close cousin of doing zerosLike(tile) * divideLeft(cvt, tile) - // This one is a tad more general in the sense that it allows to divide - // cvt: - // - register=1 -> (0, 1) - // register=2 -> (8, 0) - // register=4 -> (0, 8) - // register=8 -> (0, 16) - // register=16 -> (0, 32) - // register=32 -> (0, 64) - // register=64 -> (16, 0) - // - lane=1 -> (0, 2) - // lane=2 -> (0, 4) - // lane=4 -> (1, 0) - // lane=8 -> (2, 0) - // lane=16 -> (4, 0) - // - warp=1 -> (32, 0) - // warp=2 -> (64, 0) - // - block is a size 1 dimension - // where out dims are: [row (size 128), col (size 128)] - // tile: - // - register=1 -> (0, 1) - // register=2 -> (8, 0) - // - lane=1 -> (0, 2) - // lane=2 -> (0, 4) - // lane=4 -> (1, 0) - // lane=8 -> (2, 0) - // lane=16 -> (4, 0) - // - warp=1 -> (32, 0) - // warp=2 -> (64, 0) - // where out dims are: [row (size 128), col (size 8)] - // which would not be possible to lower via the divideLeft approach as we - // cannot divide by the tile given the `register=64 -> (16, 0)` basis. - - // Ensure tile out-dims are subset of cvt out-dims. - for (auto od : tile.getOutDimNames()) - assert(cvt.hasOutDim(od) && "tile out-dims must be contained in cvt"); - - // Precompute tile out-dim bit-widths. - llvm::SmallDenseMap outBLog2; - for (StringAttr od : cvt.getOutDimNames()) - outBLog2[od] = tile.hasOutDim(od) ? tile.getOutDimSizeLog2(od) : 0; - - // Build a per-out-dimension mask by OR-ing all tile bases that touch it. - llvm::SmallDenseMap tileMaskPerOutDim; - for (StringAttr od : cvt.getOutDimNames()) - tileMaskPerOutDim[od] = 0; - for (auto &[inDim, inBases] : tile.getBases()) { - (void)inDim; - for (auto &basis : inBases) { - int idx = 0; - for (StringAttr od : tile.getOutDimNames()) { - tileMaskPerOutDim[od] |= basis[idx++]; - } - } - } - - // Build reps with the same in/out dims as cvt, but zeroing out the leading - // inB bases (per in-dim) and keeping the remainder bases unchanged from cvt. - LinearLayout::BasesT repsBases; - for (StringAttr id : cvt.getInDimNames()) { - int inA = cvt.getInDimSizeLog2(id); - int inB = tile.hasInDim(id) ? tile.getInDimSizeLog2(id) : 0; - assert(inB <= inA && "tile has more in-bits than cvt for a given in-dim"); - - std::vector> basesForDim; - basesForDim.reserve(inA); - - // 1) Validate the starting bases match exactly. 
- for (int i = 0; i < inB; ++i) { - for (StringAttr od : cvt.getOutDimNames()) { - int a = cvt.getBasis(id, i, od); - int b = tile.getBasis(id, i, od); - if (a != b) { - return std::nullopt; - } - } - } - - // 2) Validate no overlap: the remaining cvt bases must have zeros in all - // tile-bit positions (computed as OR of all tile bases) for each - // out-dim. - for (int i = inB; i < inA; ++i) { - for (StringAttr od : cvt.getOutDimNames()) { - int32_t mask = tileMaskPerOutDim.lookup(od); - if (mask == 0) - continue; - int v = cvt.getBasis(id, i, od); - if ((v & mask) != 0) { - return std::nullopt; - } - } - } - - // 3) Emit reps bases: first inB as all-zeros; remainder copied from cvt. - for (int i = 0; i < inB; ++i) { - std::vector zero(cvt.getNumOutDims(), 0); - basesForDim.push_back(std::move(zero)); - } - for (int i = inB; i < inA; ++i) { - std::vector keep; - keep.reserve(cvt.getNumOutDims()); - for (StringAttr od : cvt.getOutDimNames()) - keep.push_back(cvt.getBasis(id, i, od)); - basesForDim.push_back(std::move(keep)); - } - - repsBases[id] = std::move(basesForDim); - } - - return LinearLayout(std::move(repsBases), cvt.getOutDims(), - /*requireSurjective=*/false); -} - // Similar to largestVectorisation in TritonGPUToLLVM/Utility.cpp std::optional> getVec(const LinearLayout &cvt, const LinearLayout &tile, int maxnreg) { @@ -960,116 +846,6 @@ static void createTcgen05Cp(ConversionPatternRewriter &rewriter, Location loc, ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); } -static std::optional, int32_t, int32_t>> -getSwizzling(MemDescType shmemTy, MemDescType tmemTy, TMemCopyAtom atom) { - // cvt is a map from Tmem to Shmem - auto tmemLl = toLinearLayout(tmemTy); - auto shmemLl = toLinearLayout(shmemTy); - auto inDimNames = to_vector(tmemLl.getInDimNames()); - auto *ctx = inDimNames[0].getContext(); - assert(shmemLl.getInDimSize(str_attr("block")) == 1 && "NYI"); - auto kOffset = str_attr("offset"); - auto kRow = str_attr("row"); - auto kCol = str_attr("col"); - shmemLl = shmemLl.sublayout({kOffset}, to_vector(shmemLl.getOutDimNames())); - auto cvt = tmemLl.invertAndCompose(shmemLl); - - int32_t bitwidth = tmemTy.getElementType().getIntOrFloatBitWidth(); - - // Check if the layout is large enough as to check SBO - // TODO Move to the verifier - if (shmemLl.getOutDimSizeLog2(str_attr("dim0")) < 4) { - return std::nullopt; - } - // TODO We may need to be careful here if we ever want to support fp4 padded - // layouts - if (!shmemLl.isInvertible()) { - return std::nullopt; - } - - // This will be SBO for k-Contiguous layouts (like the ones used in - // tcgen05.cp) - auto sbo = - shmemLl.invert().getBasis(str_attr("dim0"), /*log2(8)=*/3, kOffset); - - const SmallVector instrShape = {atom.nRow, atom.bCol / bitwidth}; - // TODO Move to the verifier perhaps - // Can we move the tile? 
- for (auto [inDimName, instrSize] : llvm::zip(inDimNames, instrShape)) { - if (cvt.getInDimSize(inDimName) < instrSize) { - return std::nullopt; - } - } - - auto CTALayout = getCTALayout(shmemTy.getEncoding()); - - for (int swizzling : {0, 32, 64, 128}) { - // r = 0, 1, 2, 3 - auto shmemEnc = - NVMMASharedEncodingAttr::get(ctx, swizzling, /*transposed=*/false, - bitwidth, /*fp4Padded=*/false, CTALayout); - auto shmemTile = - getCoreMatrixLinearLayout(shmemEnc, /*disableSwizzle=*/false); - // getCoreMatrixLinearLayout gives the k-contiguous tile - // shmemTile is a layout onto a matrix with shape - // If swizzling != 0: 8 x (8 * swizzling / bitwidth) - // If swizzling == 0: 8 x (8 * 16 / bitwidth) - assert(shmemTile.getOutDimSize(str_attr("dim0")) == 8); - assert(shmemTile.getOutDimSize(str_attr("dim1")) == - 8 * std::max(16, swizzling) / bitwidth); - // The shmemTile is mapped identically into the tmem, so we just need to - // rename the outDims in shmemTile from dim0, dim1 to row, col - auto cvtTileInverted = - LinearLayout(shmemTile.getBases(), {str_attr("row"), str_attr("col")}); - // The tile should be invertible, so we consider it as a map from row, col - // to offset - // nb. Working with the map from row, col to offset is important to handle - // the tcgen05.cp instructions that do broadcasting - auto cvtTile = cvtTileInverted.invert(); - // The sbo stride shall not touch the core tile - if (sbo < cvtTile.getOutDimSize(kOffset)) - continue; - - // As we are copying instrShape[0] columns in one go, to be able to - // represent this in the descriptor, we need to have a constant "stride" - // along the row dimension from row=8 until the last row. - auto bases = cvtTile.getBases(); - for (int i = 1; i < instrShape[0] / 8; i *= 2) { - bases[kRow].push_back({sbo * i}); - } - // Broadcast - for (int i = instrShape[0]; i < 128; i *= 2) { - bases[kRow].push_back({0}); - } - // If we multicast as warpx2::02_13, we need to swap the last two bases - if (atom.multicast == 1) { - auto n = bases[kRow].size(); - std::swap(bases[kRow][n - 1], bases[kRow][n - 2]); - } - cvtTile = LinearLayout(bases, {{kOffset, sbo * (instrShape[0] / 8)}}, - /*requireSurjective=*/false); - - auto quot = divideLeft(cvt, cvtTile); - if (quot.has_value()) { - if (auto nvmma = dyn_cast(shmemEnc)) { - assert(nvmma.getSwizzlingByteWidth() == swizzling); - } - auto lbo = 0; - if (swizzling == 0) { - auto dim1 = str_attr("dim1"); - auto endTile = shmemTile.getOutDimSizeLog2(dim1); - auto shmemInv = shmemLl.invert(); - if (shmemInv.getInDimSizeLog2(dim1) > endTile) { - lbo = shmemInv.getBasis(dim1, endTile, kOffset); - } - } - return std::make_tuple(swizzling, *quot, cvtTile, instrShape, lbo, sbo); - } - } - return std::nullopt; -} - static void copySharedToTmem(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, triton::nvidia_gpu::TMEMCopyOp op, Value src, @@ -1082,84 +858,60 @@ static void copySharedToTmem(ConversionPatternRewriter &rewriter, Location loc, MemDescType srcTy = op.getSrc().getType(); MemDescType dstTy = op.getDst().getType(); - - auto sharedLl = toLinearLayout(srcTy); - sharedLl = - sharedLl.sublayout({kOffset}, to_vector(sharedLl.getOutDimNames())); + auto shmemLl = toLinearLayout(srcTy); auto tmemLl = toLinearLayout(dstTy); - auto cvt = tmemLl.invertAndCompose(sharedLl); - auto bitwidth = srcTy.getElementType().getIntOrFloatBitWidth(); - auto atom = getTMemCopyAtom(cvt, bitwidth); - // Need to find the shmem tile that matches - auto maybeSwizzling = getSwizzling(srcTy, 
dstTy, atom); - assert(maybeSwizzling.has_value()); - auto [swizzling, quot, tile, tileShape, lbo, sbo] = - std::move(*maybeSwizzling); - - auto reps = zerosLike(tile) * quot; + // This subtlely handles subviews + auto cvt = tmemLl.invertAndCompose(shmemLl); + auto bitwidth = srcTy.getElementType().getIntOrFloatBitWidth(); + auto atom = getTMemCopyAtom(cvt, bitwidth); // Get shmem ptr - // TODO We should not allow splitting along the swizzling pattern Type elemTy = typeConverter->convertType(srcTy.getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, src, elemTy, rewriter); - Value baseSrcInt = - b.ptrtoint(i32_ty, smemObj.getShmemAffineBase(loc, rewriter, srcTy)); - // We checked in the verifier that the alignment is at least 16 - Value baseSrcIntShr4 = b.lshr(baseSrcInt, b.i32_val(4)); - Value baseSrcDesc = b.zext(i64_ty, b.and_(baseSrcIntShr4, b.i32_val(0x3FFF))); - - // Set common fields in the SMEMDescriptor - SMEMDescriptor desc; - // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor - desc.descriptor = 1ULL << 46; - desc.baseAddress = 0; - desc.leadDimensionBaseOffset = lbo != 0 ? (lbo * (bitwidth / 8)) >> 4 : 1; - // SBO is in elements and we have to pass it to bits and right shift by 4 - desc.strideDimensionBaseOffset = ((sbo * (bitwidth / 8)) >> 4); - desc.matrixBaseOffset = 0; - switch (swizzling) { - case 0: - desc.swizzlingMode = 0; - break; - case 32: - desc.swizzlingMode = 3; - break; - case 64: - desc.swizzlingMode = 2; - break; - case 128: - desc.swizzlingMode = 1; - break; - default: - llvm::report_fatal_error("Unsupported swizzling size."); + auto smemBase = smemObj.getShmemAffineBase(loc, rewriter, srcTy); + + // We handle the multicast (the last 2 bits) after the descriptor + // once we have access to the lbo/sbo + const SmallVector instrShape = {32, atom.bCol / bitwidth}; + auto kWarp = str_attr("warp"); + auto cvtWarp = + cvt.reshapeIns({{kRow, 32}, {kWarp, 4}, {kCol, cvt.getInDimSize(kCol)}}) + .sublayout({kRow, kCol}, to_vector(cvt.getOutDimNames())); + + auto loader = DotOpMmaSmemLoader::build(loc, rewriter, cvtWarp, bitwidth, + smemBase, instrShape, 5); + assert(!loader.getDescriptor().transposed); + + // The lbo/sbo are swapped for swizzling == 0 when passing a descriptor to + // tcgen05.cp vs passing it to wgmma/tcgen05.mma!! 
+ auto &descData = loader.getDescriptor(); + if (descData.swizzlingByteWidth == 0) { + auto lbo = descData.descriptor.leadDimensionBaseOffset; + auto sbo = descData.descriptor.strideDimensionBaseOffset; + descData.descriptor.leadDimensionBaseOffset = sbo; + descData.descriptor.strideDimensionBaseOffset = lbo; + } + // Check correct lbo/sbo along the multicast + auto strideRow = cvt.getBasis(kRow, llvm::Log2_32(8), kOffset); + if ((atom.multicast & 1) == 0) { + assert(cvt.getBasis(kRow, llvm::Log2_32(32), kOffset) == + strideRow * (32 / 8)); + } + if ((atom.multicast & 2) == 0) { + assert(cvt.getBasis(kRow, llvm::Log2_32(64), kOffset) == + strideRow * (64 / 8)); } - // Make sure we don't have to iterate along the rows - assert(tile.getInDimSize(kRow) == cvt.getInDimSize(kRow) && "NYI"); - assert(tileShape[1] <= tile.getInDimSize(kCol) && "NYI"); - int elementBytes = bitwidth / 8; - for (int col = 0; col < reps.getInDimSize(kCol); - col += tile.getInDimSize(kCol)) { - // Compute base offset for the swizzling pattern - int32_t off = reps.apply({{kRow, 0}, {kCol, col}})[0].second; - desc.matrixBaseOffset = (off * elementBytes / 128) & 0x7; - for (int offset = 0; offset < tile.getInDimSize(kCol); - offset += tileShape[1]) { - // Compute total offset of the current message - int32_t totalOffElems = - cvt.apply({{kRow, 0}, {kCol, col + offset}})[0].second; - int32_t smemByteOffset = totalOffElems * elementBytes; - int32_t smemByteOffsetShr4 = smemByteOffset >> 4; - Value descValBase = b.int_val(64, desc.descriptor + smemByteOffsetShr4); - // Add the base address to the descriptor - Value descVal = b.or_(descValBase, baseSrcDesc, /*disjoint=*/true); - auto tmemAddr = b.or_(b.ptrtoint(i32_ty, baseDst), - b.i32_val((col + offset) * elementBytes / 4), - /*disjoint=*/true); - createTcgen05Cp(rewriter, loc, tmemAddr, descVal, pred, atom); - } + for (int col = 0; col < cvt.getInDimSize(kCol); col += instrShape[1]) { + // smemLoad takes the colRep. It'd be nice to change this but we would need + // to change the wgmma and mmav5 lowering + auto desc = loader.smemLoad(0, col / instrShape[1], rewriter, loc); + auto tmemAddr = + b.or_(b.ptrtoint(i32_ty, baseDst), b.i32_val(col * bitwidth / 32), + /*disjoint=*/true); + createTcgen05Cp(rewriter, loc, tmemAddr, desc, pred, atom); } } diff --git a/third_party/proton/README.md b/third_party/proton/README.md index 6ad5821084..d0b5cd8a40 100644 --- a/third_party/proton/README.md +++ b/third_party/proton/README.md @@ -118,7 +118,7 @@ proton.start(name="profile_name", context="shadow", backend="cupti", mode="pcsam #### Instrumentation The instrumentation backend allows for detailed, fine-grained profiling of intra-kernel behavior, generating trace or tree views similar to those produced by coarse-grained profiling. -By default, if no `mode` is specified, Proton profiles kernel cycles, which may require shared memory. If there is insufficient shared memory, profiling will abort and a warning will be displayed. Future releases will introduce additional instrumentation modes. +By default, if no `mode` is specified, Proton profiles kernel cycles, which may require shared memory. If there is insufficient shared memory, profiling will abort and a warning will be displayed. Future releases will introduce additional instrumentation modes. See the [tutorial](tutorials/intra_kernel) for more detailed information and examples. 
**Host-side usage:** @@ -138,7 +138,7 @@ import triton.profiler.mode as pmode proton.start( name="profile_name", backend="instrumentation", - mode=pmode.Default(granularity="warp_2") # collect metrics from every 2 warps + mode=pmode.Default() # collect metrics from every warp ) ``` @@ -167,7 +167,7 @@ def kernel(...): gl.load(...) ``` -Advanced users can instrument either the `ttir` or `ttgir` intermediate representations for even finer-grained measurement. The relevant IR instructions are `proton.record start` and `proton.record end`. This can be combined with the environment variable `TRITON_KERNEL_OVERRIDE=1` for custom kernel overrides. For detailed steps, refer to the Triton [documentation](https://github.com/triton-lang/triton?tab=readme-ov-file#tips-for-hacking) under the **Kernel Override Steps** section. We have also assembled a [tutorial](tutorials/ttgir_override) that demonstrates how to use the IR-based instrumentation. +Advanced users can instrument either the `ttir` or `ttgir` intermediate representations for even finer-grained measurement. The relevant IR instructions are `proton.record start` and `proton.record end`. This can be combined with the environment variable `TRITON_KERNEL_OVERRIDE=1` for custom kernel overrides. For detailed steps, refer to the Triton [documentation](https://github.com/triton-lang/triton?tab=readme-ov-file#tips-for-hacking) under the **Kernel Override Steps** section. We have also assembled a [tutorial](tutorials/intra_kernel) that demonstrates how to use the IR-based instrumentation approach and the proton DSL approach. ### Hook diff --git a/third_party/proton/proton/flags.py b/third_party/proton/proton/flags.py index 4f24f479e5..bef7621014 100644 --- a/third_party/proton/proton/flags.py +++ b/third_party/proton/proton/flags.py @@ -17,7 +17,7 @@ @dataclass class ProfilerFlags: - # Whether to enable profiling. Default is False. + # Whether profiling is enabled. Default is False. profiling_on: bool = False # Whether instrumentation is enabled. Default is False. 
instrumentation_on: bool = False diff --git a/third_party/proton/proton/hooks/instrumentation.py b/third_party/proton/proton/hooks/instrumentation.py index 186152f150..b03026150c 100644 --- a/third_party/proton/proton/hooks/instrumentation.py +++ b/third_party/proton/proton/hooks/instrumentation.py @@ -1,4 +1,3 @@ -import functools from typing import Dict, Optional, Union, Any import triton @@ -9,7 +8,6 @@ from triton._C.libtriton import passes as triton_passes from triton._C.libproton import proton as libproton from triton.compiler import LazyDict -from triton.runtime.jit import JITFunction from triton.runtime._allocation import set_profile_allocator, NullAllocator from triton.backends import backends @@ -195,16 +193,8 @@ def to_llvm_passes(pm): # Set up the profiling allocator set_profile_allocator(self.allocator) - original_run = JITFunction.run - - original_mode = self.mode - - @functools.wraps(original_run) - def instrumented_run(self, *args, **kwargs): - kwargs["instrumentation_mode"] = str(original_mode) - return original_run(self, *args, **kwargs) - - JITFunction.run = instrumented_run + # Set the instrumentation mode + triton.knobs.compilation.instrumentation_mode = str(self.mode) def deactivate(self): if InstrumentationHook.active_count == 0: @@ -220,16 +210,14 @@ def deactivate(self): # No runtime instrumentation hook is active anymore flags.instrumentation_on = False - # Restore original JIT function run method - if hasattr(JITFunction.run, "__wrapped__"): - JITFunction.run = JITFunction.run.__wrapped__ + # Restore the instrumentation mode + triton.knobs.compilation.instrumentation_mode = "" # Reset profile allocator set_profile_allocator(NullAllocator()) # Reset host memory for external processing - if InstrumentationHook.enable_host_buffer: - InstrumentationHook.host_buffer = None + InstrumentationHook.host_buffer = None # Reset the buffer reference self.buffer = None diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 499c151468..3e1f54e4aa 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -99,8 +99,8 @@ def start( Returns: session (int): The session ID of the profiling session. """ - if flags.command_line: - # Ignore the start() call if the script is run from the command line. + if flags.command_line or triton.knobs.proton.disable: + # Ignore the start() call if the script is run from the command line or profiling is disabled. 
return flags.profiling_on = True diff --git a/third_party/proton/proton/proton.py b/third_party/proton/proton/proton.py index 76ba1fae99..3dd02b8206 100644 --- a/third_party/proton/proton/proton.py +++ b/third_party/proton/proton/proton.py @@ -19,7 +19,7 @@ def parse_arguments(): choices=["shadow", "python"]) parser.add_argument("-m", "--mode", type=str, help="Profiling mode", default=None) parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree", "trace"]) - parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "launch"]) + parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"]) parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments') args = parser.parse_args() return args, args.target_args diff --git a/third_party/proton/test/test_api.py b/third_party/proton/test/test_api.py index 6f1bb77a65..1b8060f962 100644 --- a/third_party/proton/test/test_api.py +++ b/third_party/proton/test/test_api.py @@ -4,6 +4,7 @@ Profile correctness tests involving GPU kernels should be placed in `test_profile.py`. """ +import pytest import json import triton.profiler as proton import pathlib @@ -392,3 +393,17 @@ def test_throw(tmp_path: pathlib.Path): finally: proton.finalize() assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error + + +@pytest.mark.parametrize("disable", [True, False]) +def test_profile_disable(disable, fresh_knobs, tmp_path: pathlib.Path): + fresh_knobs.proton.disable = disable + temp_file = tmp_path / "test_profile_disable.hatchet" + proton.start(str(temp_file.with_suffix(""))) + proton.enter_scope("test0") + proton.exit_scope() + proton.finalize() + if disable: + assert not temp_file.exists() + else: + assert temp_file.exists() diff --git a/third_party/proton/tutorials/intra_kernel/README.md b/third_party/proton/tutorials/intra_kernel/README.md new file mode 100644 index 0000000000..de5395137b --- /dev/null +++ b/third_party/proton/tutorials/intra_kernel/README.md @@ -0,0 +1,118 @@ +# Proton Intra-Kernel Profiler Tutorial + +A comprehensive tutorial demonstrating how to use the Proton intra-kernel profiler for detailed performance analysis of GPU kernels written in Triton DSL and Gluon DSL. + +## Overview + +The Proton intra-kernel profiler captures fine-grained timing information within GPU kernels, enabling performance bottleneck identification and optimization opportunities. This tutorial provides two distinct profiling approaches: + +- **TTGIR Override Approach** - For profiling existing Triton DSL kernels by injecting instrumentation +- **Proton DSL Approach** - For native integration with Triton and Gluon DSL kernels using embedded profiling scopes + +## Examples + +### 1. TTGIR Override Approach (`example_override.py`) + +**Use Case**: Profile existing Triton DSL kernels without modifying source code + +**Example**: Vector addition kernel with external instrumentation injection + +**Workflow**: +1. **Generate TTGIR dump files**: + ```bash + ../../scripts/dump_ttgir.sh python3 example_override.py --increase-accuracy + ``` + Creates original TTGIR files in `ttgir_dump/` directory + +2. **Insert profiling instrumentation**: + ```bash + ./insert_proton_records + ``` + Modifies TTGIR files by adding `proton.record` operators at profiling points + +3. 
**Execute with TTGIR override**:
+   ```bash
+   TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_OVERRIDE=1 TRITON_OVERRIDE_DIR=ttgir_dump python3 example_override.py --increase-accuracy
+   ```
+   - `TRITON_ALWAYS_COMPILE=1`: Forces recompilation on each run
+   - `TRITON_KERNEL_OVERRIDE=1`: Enables the TTGIR override mechanism
+   - `TRITON_OVERRIDE_DIR=ttgir_dump`: Specifies the directory with modified TTGIR files
+
+### 2. Proton DSL Approach (`example_dsl.py`)
+
+**Use Case**: Native integration of the profiling DSL into Triton and Gluon DSL kernels
+
+**Example**: Triton vector-add and Gluon matrix multiplication using NVIDIA Hopper architecture features (WGMMA, TMA)
+
+**Command Line Options**:
+```bash
+# Timeline trace mode (default)
+python3 example_dsl.py
+
+# Operation measurement mode
+python3 example_dsl.py --op-measure
+
+# Enable warp sampling with specific warp IDs
+python3 example_dsl.py --warp-sampling --warp-ids "0,1,2,3"
+
+# High accuracy profiling
+python3 example_dsl.py --increase-accuracy
+```
+
+## Understanding Timeline Traces
+
+### Time Representation
+- **Scope Duration**: Displayed in cycles for precise measurement
+- **Threadblock Start Times**: Measured in nanoseconds using global timing
+- **Chrome Trace Format**: Assumes a 1 GHz GPU frequency for consistent time units (ns)
+
+### Circular Buffer System
+- **Backend Storage**: Uses a circular buffer for runtime profiling on each CTA
+- **Buffer Overflow**: When the buffer is full, earlier events are dropped and a warning is emitted during trace generation
+- **Event Window**: Displays the most recent window of recorded events in the timeline
+
+### Finalize Time Measurement
+- **Definition**: Captures `Finalize Time` when kernel execution completes
+- **Meaning**: Shows the overhead of dumping profiling data from the buffer to global memory (appears as a field in the Chrome trace viewer tab)
+
+## Configuration Options
+
+### Profiling Accuracy
+
+| Option | Description | Use Case |
+|--------|-------------|----------|
+| `clock32` | Records events in 32-bit clock format for lower overhead | Normal kernels (< 4 seconds at 1 GHz) |
+| `time_shift` | Deducts constant profiling overhead from the timeline trace | Mitigate Proton runtime overhead for cleaner traces |
+| `sched_stores` | Provides more cycle-accurate operation latency measurement | Accurate single-operation latency measurement |
+| `sched_barriers` | Constrains AMD instruction scheduling within proton scopes | AMD GPU profiling |
+
+### Buffer Configuration
+
+| Parameter | Options | Description |
+|-----------|---------|-------------|
+| `buffer_type` | `shared_mem` | Storage location for the profiling buffer |
+| `buffer_size` | `N` | Byte size of the profiling buffer (default: inferred as a small fraction of shared memory) |
+
+### Sampling Configuration
+
+| Parameter | Options | Description |
+|-----------|---------|-------------|
+| `sampling_strategy` | `selective`, `none` | Sampling approach for profiling data collection |
+| `sampling_options` | Comma-separated warp IDs | Specific warps to profile (e.g., "0,1,2,3") |
+
+**Sampling Benefits**: Warp sampling captures more events within the same buffer-size constraint by focusing on specific warps of interest.
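+
+Putting these options together, the sketch below shows how they might map onto `proton.mode.Default`, mirroring the usage in `example_dsl.py`. It is illustrative only: `optimizations`, `sampling_strategy`, and `sampling_options` appear in the tutorial code, while `buffer_type`/`buffer_size` are listed in the tables above and are assumed to be accepted the same way.
+
+```python
+import triton.profiler as proton
+
+# Illustrative combination of the options documented above.
+mode = proton.mode.Default(
+    optimizations="clock32,time_shift",  # lower-overhead 32-bit clocks + overhead compensation
+    sampling_strategy="selective",       # only profile the warps listed below
+    sampling_options="0,1,2,3",          # comma-separated warp IDs
+    # buffer_type="shared_mem", buffer_size=4096,  # assumed keywords; see Buffer Configuration
+)
+
+# Timeline trace mode, as in the tutorial examples.
+proton.start("gemm", data="trace", backend="instrumentation", mode=mode)
+```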
+ +## Output Formats + +### Timeline Traces +- **Format**: Chrome trace format (`.chrome_trace` files) +- **Viewer**: Chrome browser at `chrome://tracing` or [`Perfetto`](https://ui.perfetto.dev/) +- **Content**: Detailed timeline with scope durations + +### Operation Measurements +- **Format**: Hatchet format (`.hatchet` files) +- **Viewer**: `proton-viewer -m normalized_cycles .hatchet` +(with `-m cycles` showing sum of all cycles across the GPU, `normalized_cycles` for per-warp averaged cycles) +- **Content**: Scope-level performance metrics and statistics +- **Note**: Cycle counts are averaged across warps/CTAs diff --git a/third_party/proton/tutorials/intra_kernel/example_dsl.py b/third_party/proton/tutorials/intra_kernel/example_dsl.py new file mode 100644 index 0000000000..f8ba0a3ebb --- /dev/null +++ b/third_party/proton/tutorials/intra_kernel/example_dsl.py @@ -0,0 +1,305 @@ +""" +Intra-Kernel Profiling Examples using Proton DSL for Triton and Gluon Kernels +""" + +import argparse + +import torch +import triton +import triton.language as tl +import triton.profiler as proton +import triton.profiler.language as pl +from triton.experimental import gluon +from triton.experimental.gluon import language as gl +from triton.experimental.gluon.language.nvidia.hopper import ( + fence_async_shared, + mbarrier, + tma, + warpgroup_mma, + warpgroup_mma_init, + warpgroup_mma_wait, +) + +from triton.experimental.gluon.nvidia.hopper import TensorDescriptor + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +NUM_WARPS = 8 + + +def is_hopper(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 9 + + +def config_helper(description: str): + # Configure command line arguments for profiling options + parser = argparse.ArgumentParser(description=description) + parser.add_argument( + "--op-measure", + action="store_true", + default=False, + help="Enable operation measurement. Otherwise, we profile timeline trace. (default: False)", + ) + parser.add_argument( + "--warp-sampling", + action="store_true", + default=False, + help="Enable warp sampling during profiling (default: False)", + ) + parser.add_argument( + "--increase-accuracy", + action="store_true", + default=False, + help="Enable increased-accuracy during profiling (default: False).", + ) + parser.add_argument( + "--warp-ids", + type=str, + default="0, 2", + help="Comma-separated list of warp IDs for warp sampling (default: '0, 2')", + ) + + args = parser.parse_args() + + # Configure profiling options based on accuracy requirements + # Default uses clock_64 for long-running kernels with higher overhead + opts = "" + # `clock_32` provides lower overhead per record, `time_shift`` post-processes to reduce noise + if args.increase_accuracy: + opts = "clock32,time_shift" + + # Set up profiling mode based on warp sampling preferences + if args.warp_sampling: + # Selective warp sampling allows capturing more events within buffer constraints + # by only profiling specified warps (e.g. "0,1,2,3") + mode = proton.mode.Default( + optimizations=opts, + sampling_strategy="selective", + sampling_options=args.warp_ids, + ) + else: + # Profile all warps - provides complete picture but uses more buffer space + mode = proton.mode.Default(optimizations=opts) + + return args.op_measure, mode + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. 
+ n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pl.enter_scope("kernel") + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_and_add"): + with pl.scope("load_x_issue"): + x = tl.load(x_ptr + offsets, mask=mask) + with pl.scope("load_y_issue"): + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + pl.exit_scope("kernel") + + +def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=NUM_WARPS) + return output + + +if __name__ == "__main__": + description = "Triton Vector Add with Proton Intra-Kernel Profiling" + print(description) + + # Explicit Proton DSL enablement for Triton kernels. + # Be careful NOT to insert proton ops in loops (use the ttgir override approach instead). + pl.enable_semantic("triton") + + op_measure, mode = config_helper(description) + + # Start profiling with appropriate backend and output format + if op_measure: + # Operation measurement mode generates scope-level metrics + # View results with: proton-viewer -m normalized_cycles vector-add.hatchet + # Note: cycles are averaged across all warps/CTAs - adjust for warp specialization + proton.start("vector-add", backend="instrumentation", mode=mode) + else: + # Timeline trace mode generates Chrome trace format for visualization + # Output file: vector-add.chrome_trace + proton.start("vector-add", data="trace", backend="instrumentation", mode=mode) + + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device=DEVICE) + y = torch.rand(size, device=DEVICE) + output_torch = x + y + output_triton = add(x, y) + torch.testing.assert_close(output_torch, output_triton, rtol=1e-3, atol=1e-1) + proton.finalize() + + +# This decorator allows us to invoke the function from a Gluon constexpr. +@gluon.constexpr_function +def get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps): + warps_per_cta = [4, 1] + m = 16 + # Tile the atom until we have enough warps. + while warps_per_cta[0] * warps_per_cta[1] != num_warps: + # Tile along M only if it would not cause broadcasting. 
+ if BLOCK_M > m * warps_per_cta[0]: + warps_per_cta[0] *= 2 + else: + warps_per_cta[1] *= 2 + return warps_per_cta + + +@gluon.constexpr_function +def get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps): + m = 16 + mReps = triton.cdiv(BLOCK_M, m) + nReps = triton.cdiv(num_warps, mReps) + maxN = max(BLOCK_N // nReps, 8) + n = 256 + while n > maxN or BLOCK_N % n != 0: + n -= 8 + assert n >= 8, "expected to find a valid n" + return n + + +@gluon.constexpr_function +def pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps): + m = 16 + k = 256 // dtype.primitive_bitwidth + n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps) + warps_per_cta = get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps) + return gl.NVMMADistributedLayout( + version=[3, 0], + warps_per_cta=warps_per_cta, + instr_shape=[m, n, k], + ) + + +@gluon.jit +def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.constexpr): + BLOCK_M: gl.constexpr = c_desc.block_type.shape[0] + BLOCK_N: gl.constexpr = c_desc.block_type.shape[1] + BLOCK_K: gl.constexpr = a_desc.block_type.shape[1] + dtype: gl.constexpr = a_desc.dtype + K = a_desc.shape[1] + + pl.enter_scope("blocked_matmul_pipelined_kernel") + + # Allocate 2 buffers for each A and B. + a_smem = gl.allocate_shared_memory(dtype, [2] + a_desc.block_type.shape, a_desc.layout) + b_smem = gl.allocate_shared_memory(dtype, [2] + b_desc.block_type.shape, b_desc.layout) + index = 0 + + pid_m = gl.program_id(axis=0) + pid_n = gl.program_id(axis=1) + off_m = pid_m * BLOCK_M + off_n = pid_n * BLOCK_N + + mma_layout: gl.constexpr = pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps) + acc = warpgroup_mma_init(gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=mma_layout)) + + bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + phase = 0 + + for k in range(0, K, BLOCK_K): + a = a_smem.index(index) + b = b_smem.index(index) + + mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) + + with pl.scope("tma_loads_issue"): + tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a) + tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b) + + with pl.scope("tma_loads_wait"): + mbarrier.wait(bar, phase=phase) + phase ^= 1 + + # Since `warpgroup_mma_wait` is a no-op when there are no WGMMAs in + # flight, we can overlap the WGMMA by waiting first, then issuing the + # async WGMMA. + with pl.scope("wgmma_wait"): + acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, )) + + with pl.scope("wgmma_issue"): + acc = warpgroup_mma(a, b, acc, is_async=True) + + # Move to the next buffer. The TMA load will start while the WGMMA is + # still running. + index ^= 1 + + # Wait for the last WGMMA to complete. 
+ with pl.scope("wgmma_last_wait"): + acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, )) + + mbarrier.invalidate(bar) + + c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout) + c_smem.store(acc.to(dtype)) + fence_async_shared() + tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem) + tma.store_wait(pendings=0) + + pl.exit_scope("blocked_matmul_pipelined_kernel") + + +def blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps): + M, N = C.shape + + a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16) + b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16) + c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16) + a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout) + b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout) + c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout) + + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + blocked_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, num_warps=num_warps) + + +if __name__ == "__main__": + if not is_hopper(): + raise RuntimeError("This tutorial requires a Hopper NVIDIA GPU") + + description = "Gluon Matrix Multiplication with Proton Intra-Kernel Profiling" + print(description) + + M, N, K = 512, 512, 1024 + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty(M, N, device="cuda", dtype=torch.float16) + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + + op_measure, mode = config_helper(description) + + # Start profiling with appropriate backend and output format + if op_measure: + # Operation measurement mode generates scope-level metrics + # View results with: proton-viewer -m normalized_cycles gemm.hatchet + # Note: cycles are averaged across all warps/CTAs - adjust for warp specialization + proton.start("gemm", backend="instrumentation", mode=mode) + else: + # Timeline trace mode generates Chrome trace format for visualization + # Output file: gemm.chrome_trace + proton.start("gemm", data="trace", backend="instrumentation", mode=mode) + + blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS) + torch.testing.assert_close(A @ B, C, rtol=1e-3, atol=1e-1) + + # Complete profiling and write output files + proton.finalize() diff --git a/third_party/proton/tutorials/intra_kernel/example_override.py b/third_party/proton/tutorials/intra_kernel/example_override.py new file mode 100644 index 0000000000..a87e8e8868 --- /dev/null +++ b/third_party/proton/tutorials/intra_kernel/example_override.py @@ -0,0 +1,98 @@ +""" +Vector Addition with Triton Intra-Kernel Profiling using TTGIR Override + +This tutorial demonstrates how to use Triton's TTGIR override mechanism +to enable intra-kernel profiling with Proton. The workflow involves generating, +modifying, and overriding the kernel's intermediate representation to insert +profiling hooks. + +Workflow: +1. Generate TTGIR dump files: + + This creates the original TTGIR files in the `ttgir_dump/` directory: + + ../../scripts/dump_ttgir.sh python3 example_override.py --increase-accuracy + +2. Insert profiling instrumentation: + + Modify the generated TTGIR files by adding proton.record operators at desired + profiling points. Example script that adds proton ops in the above ttgir: + + ./insert_proton_records + +3. 
Execute with TTGIR override: + + TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_OVERRIDE=1 TRITON_OVERRIDE_DIR=ttgir_dump python3 example_override.py --increase-accuracy + + - TRITON_ALWAYS_COMPILE=1: Forces recompilation on each run + - TRITON_KERNEL_OVERRIDE=1: Enables TTGIR override mechanism + - TRITON_OVERRIDE_DIR=ttgir_dump: Specifies directory containing modified TTGIR files +""" + +import argparse + +import torch +import triton +import triton.language as tl +import triton.profiler as proton +from triton.profiler.mode import Default + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + parser = argparse.ArgumentParser(description="TTGIR override example with Triton intra kernel profiling") + parser.add_argument( + "--increase-accuracy", + action="store_true", + default=False, + help="Enable increased-accuracy during profiling (default: False)", + ) + args = parser.parse_args() + + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + + if args.increase_accuracy: + proton.start( + "add", + data="trace", + backend="instrumentation", + mode=Default(optimizations="clock32,time_shift"), + ) + else: + proton.start("add", data="trace", backend="instrumentation") + + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + + proton.finalize() + return output + + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device=DEVICE) +y = torch.rand(size, device=DEVICE) +output_torch = x + y +output_triton = add(x, y) +torch.testing.assert_close(output_torch, output_triton, rtol=1e-3, atol=1e-1) diff --git a/third_party/proton/tutorials/intra_kernel/insert_proton_records b/third_party/proton/tutorials/intra_kernel/insert_proton_records new file mode 100755 index 0000000000..e98435e06e --- /dev/null +++ b/third_party/proton/tutorials/intra_kernel/insert_proton_records @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +Script to automatically add proton.record statements to the examplar vector-add ttgir. +""" + +import glob +import os +import re +import sys + + +def add_proton_records(input_file): + """Add proton.record statements to a ttgir file.""" + + with open(input_file, "r") as f: + content = f.read() + lines = f.readlines() + + # Assert no proton.record already exists + if "proton.record" in content: + raise AssertionError("File already contains `proton.record` statements! 
Please clean-up.") + + # Reset file pointer and read lines again + with open(input_file, "r") as f: + lines = f.readlines() + + result_lines = [] + load_and_add_started = False + + for i, line in enumerate(lines): + # Add kernel record start after function declaration + if "tt.func public @" in line and "{" in line: + result_lines.append(line) + result_lines.append(' proton.record start "kernel"\n') + continue + + # Add load_and_add record start before first load + if "tt.load" in line and not load_and_add_started: + result_lines.append(' proton.record start "load_and_add"\n') + load_and_add_started = True + + # Add individual load records + if "tt.load" in line: + # Extract variable name (x, y, etc.) - just the letters before '_' + match = re.search(r"%(\w+)_\d+\s*=\s*tt\.load", line) + if match: + var_name = match.group(1) + result_lines.append(f' proton.record start "load_{var_name}_issue"\n') + result_lines.append(line) + result_lines.append(f' proton.record end "load_{var_name}_issue"\n') + continue + + # Add load_and_add record end after arithmetic operation + if "arith.addf" in line and load_and_add_started: + result_lines.append(line) + result_lines.append(' proton.record end "load_and_add"\n') + load_and_add_started = False + continue + + # Add kernel record end before return + if "tt.return" in line: + result_lines.append(' proton.record end "kernel"\n') + result_lines.append(line) + continue + + # Default: just add the line + result_lines.append(line) + + # Write output in-place + with open(input_file, "w") as f: + f.writelines(result_lines) + + print(f"Added proton records to {input_file}") + + +def find_and_process_ttgir(): + """Find all ttgir files in ttgir_dump directory and process them.""" + + # Find ttgir_dump directory + ttgir_dump_path = None + for root, dirs, files in os.walk("."): + if "ttgir_dump" in dirs: + ttgir_dump_path = os.path.join(root, "ttgir_dump") + break + + if not ttgir_dump_path: + print("Error: ttgir_dump directory not found!") + sys.exit(1) + + # Process the ttgir file + ttgir_files = glob.glob(os.path.join(ttgir_dump_path, "**", "*.ttgir"), recursive=True) + + if not ttgir_files: + print(f"No ttgir files found in {ttgir_dump_path}") + return + + if len(ttgir_files) > 1: + print(f"Warning: Found {len(ttgir_files)} ttgir files, expected at most 1") + + ttgir_file = ttgir_files[0] # Take the first (and expected only) file + try: + print(f"Processing {ttgir_file}...") + add_proton_records(ttgir_file) + print("Successfully processed ttgir file") + except AssertionError as e: + print(f"Skipping {ttgir_file}: {e}") + except Exception as e: + print(f"Error processing {ttgir_file}: {e}") + + +if __name__ == "__main__": + find_and_process_ttgir() diff --git a/third_party/proton/tutorials/ttgir_override/ttgir_instrumentation.sh b/third_party/proton/tutorials/ttgir_override/ttgir_instrumentation.sh deleted file mode 100755 index f68230dbc5..0000000000 --- a/third_party/proton/tutorials/ttgir_override/ttgir_instrumentation.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash - -# This script is used to demonstrate the use of the TTGIR instrumentation feature in Proton. 
-# the IR may change over time so to make this script future proof we will -# instrument at the DSL level first and then use that IR to override the uninstrumented kernel -DUMP_DIR="$PWD/ttgir_dump" - -if [ -e "$DUMP_DIR" ]; - then rm -rf "$DUMP_DIR" ; -fi - -mkdir -p "$DUMP_DIR" - -TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_DUMP=1 TRITON_DUMP_DIR=$DUMP_DIR python vector-add-instrumented.py -# Iterate over all subdirectories in $DUMP_DIR and remove all except the .ttgir files -for dir in "$DUMP_DIR"/*; do - if [ -d "$dir" ]; then - find "$dir" -type f ! -name "*.ttgir" -delete - #Save off the actual hash directory (this will change across kernel/Triton/etc. versions) - TTGIR_DIR="$dir" - fi -done - -echo "TTGIR files dumped to $TTGIR_DIR" - -# Save the add_kernel.ttgir file from the DSL level instrumentation to the current directory temporarily -cp $TTGIR_DIR/add_kernel.ttgir $PWD -rm -rf "$DUMP_DIR" - -# Now run the uninstrumented kernel and overwrite the add_kernel.ttgir file from the DSL level instrumentation -TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_DUMP=1 TRITON_DUMP_DIR=$DUMP_DIR python vector-add.py - -for dir in "$DUMP_DIR"/*; do - if [ -d "$dir" ]; then - find "$dir" -type f ! -name "*.ttgir" -delete - TTGIR_DIR="$dir" - fi -done - -mv add_kernel.ttgir $TTGIR_DIR/add_kernel.ttgir - -TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_OVERRIDE=1 TRITON_OVERRIDE_DIR=$DUMP_DIR python vector-add.py - -echo "Now run `proton-viewer -m normalized_cycles vector-add.hatchet` to see the output" diff --git a/third_party/proton/tutorials/ttgir_override/vector-add-instrumented.py b/third_party/proton/tutorials/ttgir_override/vector-add-instrumented.py deleted file mode 100644 index 539c45d12d..0000000000 --- a/third_party/proton/tutorials/ttgir_override/vector-add-instrumented.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch - -import triton -import triton.language as tl -import triton.profiler.language as pl -import triton.profiler as proton -import pathlib -import os - -from typing import NamedTuple - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - -pl.enable_semantic("triton") - - -def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): - BLOCK_SIZE = args["BLOCK_SIZE"] - return {"name": f"add_{BLOCK_SIZE}"} - - -@triton.jit(launch_metadata=metadata_fn) -def add_kernel(x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. 
- ): - with pl.scope("kernel"): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - with pl.scope("load_ops"): - with pl.scope("load_x"): - x = tl.load(x_ptr + offsets, mask=mask) - with pl.scope("load_y"): - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - -def add(x: torch.Tensor, y: torch.Tensor): - output = torch.empty_like(x) - assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) - tmp_path = pathlib.Path(os.getcwd()) - temp_file = tmp_path / "vector-add.hatchet" - proton.start(str(temp_file.with_suffix("")), backend="instrumentation") - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) - proton.finalize() - return output - - -torch.manual_seed(0) -size = 98432 -x = torch.rand(size, device=DEVICE) -y = torch.rand(size, device=DEVICE) -output_torch = x + y -output_triton = add(x, y) diff --git a/third_party/proton/tutorials/ttgir_override/vector-add.py b/third_party/proton/tutorials/ttgir_override/vector-add.py deleted file mode 100644 index 63e67bdadc..0000000000 --- a/third_party/proton/tutorials/ttgir_override/vector-add.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch - -import triton -import triton.language as tl -import triton.profiler as proton -import pathlib -import os - -from typing import NamedTuple - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - - -def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): - BLOCK_SIZE = args["BLOCK_SIZE"] - return {"name": f"add_{BLOCK_SIZE}"} - - -@triton.jit(launch_metadata=metadata_fn) -def add_kernel(x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - -def add(x: torch.Tensor, y: torch.Tensor): - output = torch.empty_like(x) - assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - tmp_path = pathlib.Path(os.getcwd()) - temp_file = tmp_path / "vector-add.hatchet" - proton.start(str(temp_file.with_suffix("")), backend="instrumentation") - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) - proton.finalize() - return output - - -torch.manual_seed(0) -size = 98432 -x = torch.rand(size, device=DEVICE) -y = torch.rand(size, device=DEVICE) -output_torch = x + y -output_triton = add(x, y)