From 234897feaf459afe6080afa16d4bbf55d78fdecd Mon Sep 17 00:00:00 2001
From: Alex Baden
Date: Fri, 20 Jun 2025 15:05:48 +0000
Subject: [PATCH] Store transpose attribute in Subgroup2DBlockIO layouts

---
 .../optimize-block-io-encoding.mlir           | 74 ++++++++++++++++++-
 .../IR/TritonIntelGPUAttrDefs.td              |  4 +-
 .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 46 ++++++++----
 .../IR/LinearLayoutConversions.cpp            |  6 +-
 .../LoadStoreOpToLLVM.cpp                     |  3 +-
 .../OptimizeBlockIOEncoding.cpp               |  6 +-
 .../LinearLayoutConversionsTest.cpp           | 53 ++++++++++---
 7 files changed, 156 insertions(+), 36 deletions(-)

diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir
index 8a9b6f184a..0e135a956e 100644
--- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir
+++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir
@@ -3,8 +3,8 @@
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
-// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
-// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}>
+// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}>
 // CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
@@ -66,11 +66,79 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar
 
 // -----
 
+// COM: Dot Operand B transpose is supported
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 4], warpsPerCTA = [1, 32], order = [0, 1]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}>
+// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 8], numBlocks = 1, isTransposed = true, order = [0, 1], kWidth = 2, threadsPerWarp = 16}>
+// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c4_i32 = arith.constant 4 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c1024_i64 = arith.constant 1024 : i64
+    %c5120_i64 = arith.constant 5120 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c5120_i32 = arith.constant 5120 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.divsi %0, %c64_i32 : i32
+    %2 = arith.muli %1, %c4_i32 : i32
+    %3 = arith.subi %c4_i32, %2 : i32
+    %4 = arith.minsi %3, %c4_i32 : i32
+    %5 = arith.remsi %0, %4 : i32
+    %6 = arith.addi %2, %5 : i32
+    %7 = arith.remsi %0, %c64_i32 : i32
+    %8 = arith.divsi %7, %4 : i32
+    %9 = arith.muli %6, %c256_i32 : i32
+    // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<256x32xf16, #mma>>
+    %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : <tensor<256x32xf16, #blocked1>>
+    %11 = arith.muli %8, %c256_i32 : i32
+    // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<32x256xf16, #mma1>>
+    %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array} : <tensor<32x256xf16, #blocked2>>
+    %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>) : i32 {
+      // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #mma>>
+      // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1>
+      %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
+      // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr<tensor<32x256xf16, #mma1>>
+      // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2>
+      %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr<tensor<32x256xf16, #blocked2>>
+      %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+      %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+      %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+      %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+      %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2>
+      %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+      // CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #mma>>
+      %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
+      // CHECK: tt.advance {{.*}} : <tensor<32x256xf16, #mma1>>
+      %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked2>>
+      scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>
+    }
+    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : <tensor<256x256xf16, #blocked3>>
+    %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked3>
+    tt.store %14, %16 {boundaryCheck = array} : !tt.ptr<tensor<256x256xf16, #blocked3>>
+    tt.return
+  }
+}
+
+// -----
+
 // COM: Dot operand A transpose currently not supported by subgroup 2d block io encoding
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
-// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}>
 // CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 // CHECK-NOT: #mma2
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index 254ea42b47..bb3198c3ce 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -297,6 +297,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
   For the layout, the following parameters are required:
   - `instrShape` : contains the (height, width) block parameters for the block io operation
   - `numBlocks` : the block count parameter allows a single load to load multiple blocks in row-major order (useful for increasing cache line utilization)
+  - `isTransposed` : indicates whether the data should be transposed post-load. The `instrShape` describes the shape of the data to load pre-transpose, i.e. if this is true then the output from the instruction (load + transpose) will be the transposed `instrShape`.
   - `threadsPerWarp` : currently a scalar, this parameter allows us to support different subgroup / warp configurations. Because the 2d block io operation is a subgroup operation, the size of the subgroup is important in determining the ordering of the loaded tensor.
   - `warpsPerCTA` : the number of warps per block / subgroups per workgroup and their distribution
   - `order` : The order within the block, used to determine along which dimension to broadcast.
@@ -310,6 +311,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
     "CTALayoutAttr":$CTALayout,
     ArrayRefParameter<"unsigned">:$instrShape,
     "unsigned":$numBlocks,
+    "bool":$isTransposed,
     ArrayRefParameter<"unsigned">:$order,
     "unsigned":$kWidth,
     "unsigned":$threadsPerWarp
@@ -317,7 +319,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
   let extraClassDeclaration = extraDistributedDeclaration # [{
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
-    static SmallVector<unsigned> getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef<int64_t> shape, bool memoryRowMajor, unsigned kWidth, MLIRContext* context);
+    static SmallVector<unsigned> getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef<int64_t> shape, bool memoryRowMajor, bool isTransposed, unsigned kWidth, MLIRContext* context);
   }];
 
   let hasCustomAssemblyFormat = 1;
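The attribute is built in C++ through the generated builder. As a hedged sketch (argument order taken from the call sites later in this patch; `ctx` and `ctaLayout` are assumed to be in scope), the transposed B-operand layout exercised by the new unit test at the end of this patch would be constructed as:

    // A transposed B-operand block-IO layout. instrShape is the
    // pre-transpose tile, so the load + transpose yields a 16x32 tile.
    auto layout = Subgroup2DBlockEncodingAttr::get(
        &ctx, /*warpsPerCTA=*/{8, 4}, ctaLayout,
        /*instrShape=*/{32, 16}, /*numBlocks=*/1, /*isTransposed=*/true,
        /*order=*/{0, 1}, /*kWidth=*/2, /*threadsPerWarp=*/16);

This prints with the `isTransposed = true` field shown in the lit-test CHECK lines above.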
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index 138fccf6c0..a0d5996398 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -59,6 +59,17 @@ static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
   return success();
 }
 
+static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
+                                        bool &value, StringRef desc) {
+  auto boolAttr = mlir::dyn_cast<BoolAttr>(attr);
+  if (!boolAttr) {
+    parser.emitError(parser.getNameLoc(), "expected a bool type in ") << desc;
+    return failure();
+  }
+  value = boolAttr.getValue();
+  return success();
+}
+
 // parse an array of integers
 static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                        const NamedAttribute &attr,
@@ -83,6 +94,11 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
   return parseIntAttrValue(parser, attr.getValue(), value, desc);
 };
 
+static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr,
+                               bool &value, StringRef desc) {
+  return parseBoolAttrValue(parser, attr.getValue(), value, desc);
+};
+
 //===----------------------------------------------------------------------===//
 // Attribute methods
 //===----------------------------------------------------------------------===//
@@ -531,8 +547,8 @@ void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
 
 LogicalResult Subgroup2DBlockEncodingAttr::verify(
     function_ref<InFlightDiagnostic()> emitError,
     ArrayRef<unsigned> warpsPerCTA, CTALayoutAttr CTALayout,
-    ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order,
-    unsigned kWidth, unsigned threadsPerWarp) {
+    ArrayRef<unsigned> instrShape, unsigned numBlocks, bool isTransposed,
+    ArrayRef<unsigned> order, unsigned kWidth, unsigned threadsPerWarp) {
   if (instrShape.size() != 2) {
     return emitError() << "instrShape must be rank 2 but was: "
                        << instrShape.size();
   }
@@ -569,6 +585,7 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) {
   std::optional<SmallVector<unsigned>> CTAOrder;
   SmallVector<unsigned> instrShape;
   unsigned numBlocks = 0;
+  bool isTransposed = false;
   SmallVector<unsigned> order;
   unsigned kWidth = 0;
   unsigned threadsPerWarp = 0;
@@ -601,6 +618,10 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) {
       if (parseUInt(parser, attr, numBlocks, "numBlocks").failed())
        return {};
     }
+    if (attr.getName() == "isTransposed") {
+      if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
+        return {};
+    }
     if (attr.getName() == "order") {
       if (parseIntArrayAttr(parser, attr, order, "order").failed())
         return {};
@@ -622,7 +643,7 @@
 
   return parser.getChecked<Subgroup2DBlockEncodingAttr>(
       parser.getContext(), warpsPerCTA, *CTALayout, instrShape, numBlocks,
-      order, kWidth, threadsPerWarp);
+      isTransposed, order, kWidth, threadsPerWarp);
 }
 
 SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getRepOrder() const {
@@ -652,9 +673,10 @@ void Subgroup2DBlockEncodingAttr::print(AsmPrinter &printer) const {
   maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank());
 
   printer << ", instrShape = [" << getInstrShape()
-          << "], numBlocks=" << getNumBlocks() << ", order=[" << getOrder()
-          << "], kWidth=" << getKWidth()
-          << ", threadsPerWarp=" << getThreadsPerWarp() << "}>";
+          << "], numBlocks = " << getNumBlocks()
+          << ", isTransposed = " << getIsTransposed() << ", order = ["
+          << getOrder() << "], kWidth = " << getKWidth()
+          << ", threadsPerWarp = " << getThreadsPerWarp() << "}>";
 }
 
 LinearLayout
@@ -664,7 +686,8 @@ Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
     DistributedEncodingTrait layout, ArrayRef<int64_t> tensorShape,
-    bool memoryRowMajor, unsigned kWidth, MLIRContext *context) {
+    bool memoryRowMajor, bool isTransposed, unsigned kWidth,
+    MLIRContext *context) {
   const auto rank = tensorShape.size();
   std::optional<LinearLayout> llEncoding = layout.toLinearLayout(tensorShape);
@@ -672,13 +695,6 @@ SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
   LinearEncodingAttr llAttr = LinearEncodingAttr::get(context, *llEncoding);
   SmallVector<unsigned> threadOrder = llAttr.getThreadOrder();
 
-  const bool valueRowMajor =
-      (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0);
-  assert((valueRowMajor ||
-          (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) &&
-         "Only row_major or column_major is allowed");
-  const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
-
   auto dotEncodingAttr = dyn_cast<DotOperandEncodingAttr>(layout);
   const unsigned opIdx = dotEncodingAttr ? dotEncodingAttr.getOpIdx() : 2;
@@ -725,7 +741,7 @@ SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
   unsigned dpasOperandsPerTileY =
       isOperandA ? numReps[2] : repCluster[dimOuter];
 
-  if (isTransposeRequired) {
+  if (isTransposed) {
     std::swap(tileWidth, tileHeight);
 
     const unsigned threadsPerWarp = dpasLayout.getThreadsPerWarp();
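With this change the transpose decision is no longer derived inside getInstrShapeForLayout; callers compute it and pass it in. A sketch of that caller-side decision, mirroring the logic removed above (`llAttr`, `rank`, and `memoryRowMajor` are assumed to be in scope):

    // The value layout is row-major when its thread order is (1, 0).
    SmallVector<unsigned> threadOrder = llAttr.getThreadOrder();
    const bool valueRowMajor =
        threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0;
    // A transpose-on-load is required whenever the value order and the
    // memory order disagree.
    const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;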
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
index 64cc423629..ddf46f5f0a 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
@@ -602,11 +602,15 @@ subgroup2DBlockToLinearLayout(ArrayRef<int64_t> blockShape,
   assert(rank == layout.getRank() && "unexpected block shape rank, layout rank "
                                      "and block shape rank must be equal");
   auto dimNames = standardOutDimNames(ctx, rank);
-  auto loadTileSize = layout.getInstrShape();
+  auto loadTileSize = SmallVector<unsigned>(layout.getInstrShape());
+  assert(loadTileSize.size() == 2);
 
   StringAttr kRegister = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
 
+  if (layout.getIsTransposed())
+    std::swap(loadTileSize[0], loadTileSize[1]);
+
   // Start by creating register/lane bases corresponding to the desired load
   // tile size
   auto [regBases, laneBases] = createRegisterLaneBases(
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 7fcc5eb253..7ca9a2822c 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1467,7 +1467,8 @@ struct LoadOpConversion
     } else {
       auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
           cast<DistributedEncodingTrait>(encoding), tensorType.getShape(),
-          memoryRowMajor, elemSizeInBits / 8, rewriter.getContext());
+          memoryRowMajor, isTransposeRequired, elemSizeInBits / 8,
+          rewriter.getContext());
       return std::make_tuple(tileParams[0], tileParams[1], tileParams[2]);
     }
   };
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp
index db912b34ba..effd387df6 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp
@@ -287,14 +287,14 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass
 
       auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
           cast<DistributedEncodingTrait>(dotOperandEncoding),
-          oldTensorType.getShape(), memoryRowMajor, elemSizeInBits / 8,
-          &getContext());
+          oldTensorType.getShape(), memoryRowMajor, isTransposeRequired,
+          elemSizeInBits / 8, &getContext());
       SmallVector<unsigned> instrShape{tileParams[0], tileParams[1]};
       const unsigned vBlocks = tileParams[2];
 
       auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get(
           &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape,
-          tileParams[2],
+          tileParams[2], isTransposeRequired,
           getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ rank,
                                 /*kContig*/ true),
           kWidth, dpasLayout.getThreadsPerWarp());
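Both call sites now follow the same shape; condensed, the updated helper contract looks like this (a sketch using names from the hunks above, with `ctx` standing in for the MLIRContext argument):

    // Returns {tileHeight, tileWidth, vBlocks}; the new bool threads the
    // caller's transpose decision through to the tile-shape computation.
    auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
        cast<DistributedEncodingTrait>(encoding), tensorType.getShape(),
        memoryRowMajor, isTransposeRequired, elemSizeInBits / 8, ctx);
    SmallVector<unsigned> instrShape{tileParams[0], tileParams[1]};
    const unsigned vBlocks = tileParams[2];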
diff --git a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp
index 2e1eee2adf..70ecacf335 100644
--- a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp
+++ b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp
@@ -21,9 +21,10 @@ class LinearLayoutConversionsTest : public ::testing::Test {
   // Create a Subgroup2DBlockEncoding layout based on a DPAS layout
   Subgroup2DBlockEncodingAttr
-  sdb(ArrayRef<unsigned> instrShape, unsigned numBlocks, unsigned kWidth,
-      ArrayRef<unsigned> warpsPerCTA, ArrayRef<unsigned> repCluster,
-      ArrayRef<int64_t> blockShape, unsigned opsPerChannel, unsigned opIdx) {
+  sdb(ArrayRef<unsigned> instrShape, unsigned numBlocks, bool isTransposed,
+      unsigned kWidth, ArrayRef<unsigned> warpsPerCTA,
+      ArrayRef<unsigned> repCluster, ArrayRef<int64_t> blockShape,
+      unsigned opsPerChannel, unsigned opIdx) {
     auto dpasLayout = DpasEncodingAttr::get(
         &ctx, /*repeatCount=*/8, /*systolicDepth=*/8, /*executionSize=*/16,
         opsPerChannel, warpsPerCTA, repCluster,
@@ -35,7 +36,7 @@ class LinearLayoutConversionsTest : public ::testing::Test {
         CTALayoutAttr::get(
             &ctx, dpasLayout.getCTAsPerCGA(), // TODO: add to DpasLayout?
             dpasLayout.getCTASplitNum(), dpasLayout.getCTAOrder()),
-        instrShape, numBlocks,
+        instrShape, numBlocks, isTransposed,
         getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth,
         dpasLayout.getThreadsPerWarp());
     return layout;
@@ -51,7 +52,8 @@ TEST_F(LinearLayoutConversionsTest, FP32_32x8x2_M256_N128_K32_A) {
   EXPECT_EQ(
       subgroup2DBlockToLinearLayout(
          /*blockShape*/ {256, 32},
-          sdb(/*instrShape*/ {32, 8}, /*numBlocks*/ 2, /*kWidth*/ 4,
+          sdb(/*instrShape*/ {32, 8}, /*numBlocks*/ 2, /*isTransposed*/ false,
+              /*kWidth*/ 4,
              /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1},
              /*blockShape*/ {256, 32}, /*opsPerChannel*/ 1, /*opIdx*/ 0),
          /*kWidth*/ 4),
@@ -67,7 +69,8 @@ TEST_F(LinearLayoutConversionsTest, FP32_32x16x1_M256_N128_K32_B) {
   EXPECT_EQ(
      subgroup2DBlockToLinearLayout(
          /*blockShape*/ {32, 128},
-          sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 4,
+          sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ false,
+              /*kWidth*/ 4,
              /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1},
              /*blockShape*/ {32, 128}, /*opsPerChannel*/ 1, /*opIdx*/ 1),
          /*kWidth*/ 4),
@@ -83,7 +86,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x32x1_M256_N32_K32_A) {
   EXPECT_EQ(
      subgroup2DBlockToLinearLayout(
          /*blockShape*/ {256, 32},
-          sdb(/*instrShape*/ {32, 32}, /*numBlocks*/ 1, /*kWidth*/ 2,
+          sdb(/*instrShape*/ {32, 32}, /*numBlocks*/ 1, /*isTransposed*/ false,
+              /*kWidth*/ 2,
              /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2},
              /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0),
          /*kWidth*/ 2),
@@ -99,7 +103,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_A) {
   EXPECT_EQ(
      subgroup2DBlockToLinearLayout(
          /*blockShape*/ {256, 32},
-          sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2,
+          sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*isTransposed*/ false,
+              /*kWidth*/ 2,
              /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2},
              /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0),
          /*kWidth*/ 2),
@@ -114,7 +119,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_A) {
 TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_B) {
   EXPECT_EQ(subgroup2DBlockToLinearLayout(
                /*shape*/ {32, 256},
-                sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2,
+                sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2,
+                    /*isTransposed*/ false, /*kWidth*/ 2,
                    /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2},
                    /*blockShape*/ {32, 256}, /*opsPerChannel*/ 2,
                    /*opIdx*/ 1),
@@ -131,7 +137,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_B) {
 TEST_F(LinearLayoutConversionsTest, FP16_16x16x2_M256_N32_K32_B) {
   EXPECT_EQ(subgroup2DBlockToLinearLayout(
                /*shape*/ {32, 256},
-                sdb(/*instrShape*/ {16, 16}, /*numBlocks*/ 2, /*kWidth*/ 2,
+                sdb(/*instrShape*/ {16, 16}, /*numBlocks*/ 2,
+                    /*isTransposed*/ false, /*kWidth*/ 2,
                    /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2},
                    /*blockShape*/ {32, 256}, /*opsPerChannel*/ 2,
                    /*opIdx*/ 1),
@@ -145,11 +152,32 @@ TEST_F(LinearLayoutConversionsTest, FP16_16x16x2_M256_N32_K32_B) {
                {S("dim0"), S("dim1")}));
 }
 
+TEST_F(LinearLayoutConversionsTest, FP16_32x16x1_M256_N32_K32_TRANSPOSE_B) {
+  // Note that the instrShape is pre-transpose
+  EXPECT_EQ(
+      subgroup2DBlockToLinearLayout(
+          /*shape*/ {32, 256},
+          sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ true,
+              /*kWidth*/ 2,
+              /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2},
+              /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2,
+              /*opIdx*/ 1),
+          /*kWidth*/ 2),
+      LinearLayout(
+          {{S("register"),
+            {{0, 1}, {1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 128}}},
+           {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {0, 16}}},
+           {S("warp"), {{0, 32}, {0, 64}, {0, 0}, {0, 0}, {0, 0}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, I8_16x32x1_M64_N128_K32_A) {
   EXPECT_EQ(
       subgroup2DBlockToLinearLayout(
           /*shape*/ {64, 32},
-          sdb(/*instrShape*/ {16, 32}, /*numBlocks*/ 1, /*kWidth*/ 1,
+          sdb(/*instrShape*/ {16, 32}, /*numBlocks*/ 1, /*isTransposed*/ false,
+              /*kWidth*/ 1,
              /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1},
              /*blockShape*/ {64, 32}, /*opsPerChannel*/ 4,
              /*opIdx*/ 0),
@@ -165,7 +193,8 @@ TEST_F(LinearLayoutConversionsTest, I8_32x32x1_M64_N128_K32_B) {
   EXPECT_EQ(
      subgroup2DBlockToLinearLayout(
          /*shape*/ {32, 128},
-          sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 1,
+          sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ false,
+              /*kWidth*/ 1,
              /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1},
              /*blockShape*/ {32, 128}, /*opsPerChannel*/ 4, /*opIdx*/ 1),