diff --git a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir
index 658f83db75..eb951c0baa 100644
--- a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir
+++ b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir
@@ -57,7 +57,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
     %65 = tt.splat %64 : i32 -> tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %66 = arith.cmpi slt, %38, %65 : tensor<1x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %67 = tt.broadcast %66 : tensor<1x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<128x64xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
-    // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 2
+    // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
     %68 = tt.load %60, %67, %cst_3 {ttig.block_io = "row_major"} : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %74 = tt.addptr %60, %cst_0 : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
     %76 = arith.addi %58, %c1_i32 : i32
@@ -69,72 +69,6 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
 // -----
 
-#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
-module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
-  tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} {
-    %c63_i32 = arith.constant 63 : i32
-    %c255_i32 = arith.constant 255 : i32
-    %c127_i32 = arith.constant 127 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %c0_i32 = arith.constant 0 : i32
-    %c64_i32 = arith.constant 64 : i32
-    %c8_i32 = arith.constant 8 : i32
-    %c128_i32 = arith.constant 128 : i32
-    %c256_i32 = arith.constant 256 : i32
-    %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
-    %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %0 = tt.get_program_id x : i32
-    %1 = arith.addi %arg3, %c127_i32 : i32
-    %2 = arith.divsi %1, %c128_i32 : i32
-    %3 = arith.addi %arg4, %c255_i32 : i32
-    %4 = arith.divsi %3, %c256_i32 : i32
-    %5 = arith.muli %4, %c8_i32 : i32
-    %6 = arith.divsi %0, %5 : i32
-    %7 = arith.muli %6, %c8_i32 : i32
-    %8 = arith.subi %2, %7 : i32
-    %9 = arith.minsi %8, %c8_i32 : i32
-    %12 = arith.remsi %0, %5 : i32
-    %13 = arith.divsi %12, %9 : i32
-    %15 = arith.muli %13, %c256_i32 : i32
-    %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
-    %24 = tt.splat %15 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
-    %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
-    %31 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
-    %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>
-    %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %cst_2 = arith.constant dense<512> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %47 = arith.muli %45, %cst_2 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %48 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %49 = tt.broadcast %47 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %50 = tt.broadcast %48 : tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %51 = arith.addi %49, %50 : tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %52 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %53 = tt.addptr %52, %51 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %54 = arith.addi %arg5, %c63_i32 : i32
-    %55 = arith.divsi %54, %c64_i32 : i32
-    %56 = arith.muli %arg7, %c64_i32 : i32
-    %57 = tt.splat %56 : i32 -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    cf.br ^bb1(%c0_i32, %53 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>)
-  ^bb1(%58: i32, %61: tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>):  // 2 preds: ^bb0, ^bb2
-    %62 = arith.cmpi slt, %58, %55 : i32
-    cf.cond_br %62, ^bb2, ^bb3
-  ^bb2:  // pred: ^bb1
-    %63 = arith.muli %58, %c64_i32 : i32
-    %64 = arith.subi %arg5, %63 : i32
-    %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
-    %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %76 = arith.addi %58, %c1_i32 : i32
-    cf.br ^bb1(%76, %75 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>)
-  ^bb3:  // pred: ^bb1
-    tt.return
-  }
-}
-
-// -----
-
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
 #mma_1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1]}>
 #mma_2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [4, 2]}>
@@ -259,6 +193,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-COUNT-2: llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8
@@ -267,22 +202,25 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
+    // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
     // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]])
     // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
     // CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
+    // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
     // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]])
     // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
     // CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index fc874fa357..df0b2618d7 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -2458,18 +2458,6 @@ struct LoadOpToBlockIOConversion
     if (op.getPadding() && op.getPadding() == PaddingOption::PAD_NAN)
       return failure();
 
-    // 2D block io lowering steps:
-    // 1. Get the 2 dims for 2D block io: one of the dimension chosen correspond
-    //    to the dimension where the access pattern has stride one. The other
-    //    dimension should be the one with the constancy stride.
-    // 2. Check the DPAS layout, fallback to gather IO lowering if the block
-    //    IO is not supported for the layout. TODO: to generalize the code for
-    //    different layout with the linear layout.
-    // 3. Compute the maximum tile size for the 2D block io from the layout
-    //    information.
-    // 4. Generates the 2D block IO instructions.
-    // 5. Unpacked the loaded values into expected order required by the layout.
-
     Value ptr = op.getPtr();
     if (isTensorPointerType(ptr.getType()))
       return rewriteTensorPointerLoad(op, adaptor, rewriter);
@@ -2525,159 +2513,15 @@ struct LoadOpToBlockIOConversion
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     MLIRContext *ctx = rewriter.getContext();
-    Value warpId = rewriter.create(
-        loc, i32_ty,
-        rewriter.create(loc, /*upperBound=*/nullptr));
-
-    Value llPtr = adaptor.getPtr();
-    unsigned numElems = getTotalElemsPerThread(resultType);
 
     // Get the LLVM values for pointers
+    Value llPtr = adaptor.getPtr();
     SmallVector ptrElems = unpackLLElements(loc, llPtr, rewriter);
+    unsigned numElems = getTotalElemsPerThread(resultType);
     assert(ptrElems.size() == numElems &&
            "the number of pointer values is not matched with the number of "
           "elements");
 
-    // Step 2: Right now we only support DPAS related layout to simplify the
-    // lowering.
-    DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
-    DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
-    const ArrayRef tensorShape = tensorType.getShape();
-    SmallVector repetitons =
-        dpasLayout.getDPASRepetitions(tensorShape, opIdx);
-    assert(repetitons.size() == 3 &&
-           "getDPASRepetitions always return rank 3 size");
-    assert(repetitons[0] == 1 && "Only supports rank of 2 for now");
-    SmallVector numReps{repetitons[1], repetitons[2]};
-    ArrayRef warpsPerCTA = dpasLayout.getWarpsPerCTA();
-    SmallVector dpasWarpsOrder =
-        getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true);
-    unsigned threadsPerWarp =
-        product(getThreadsPerWarp(dpasLayout, tensorShape));
-
-    SmallVector multiDimWarpId =
-        delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
-
-    // By default, use the unpacked type for the 2D load result type.
-    Type loadResultElemType = typeConverter->convertType(eltTy);
-    bool usePackedType = false;
-    unsigned packedElemsNum = 1;
-    // The tensor values are distributed as DotOp layout of DPAS.
-    // If the element size of the tensor matches the DPAS packed layout, then
-    // use the packed type for the 2D load result type. For example,
-    // The intermediate ops generated by ConvertTritonGPUToLLVM:
-    // %0 = load_2d %ptr : vector<8 x i32>
-    // %1 = bitcast %0 : vector<8 x i32> -> vector<16 x f16>
-    // %2 = bitcast %1 : vector<16 x f16> -> vector<8 x i32>
-    // %3 = dpas %2
-    // And the LLVM dialect optimization pass can eliminate the duplicated
-    // bitcast. Then there is a shortcut to use the load result directly as the
-    // input operands to DPAS.
-    // TODO: add support for int4 and int2.
-
-    // OperandA: outer dim -> M, inner dim -> K.
-    // OperandB: outer dim -> N, inner dim -> K.
-    // OperandC: outer dim -> M, inner dim -> N.
-    // Round the warp id fit into the tensor shape.
-    unsigned dimOuter;
-    unsigned dimInner;
-    SmallVector repCluster(dpasLayout.getRepCluster());
-    SmallVector warpShape;
-    SmallVector dpasInstShape;
-    auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding);
-    SmallVector threadOrder(llAttr.getThreadOrder());
-    size_t rank = threadOrder.size();
-
-    switch (opIdx) {
-    case DpasEncodingAttr::OpIdx::OperandA: {
-      warpShape = std::move(dpasLayout.getShapeA());
-      dpasInstShape = std::move(dpasLayout.getDPASInstShapeA());
-      dimOuter = rank - 2;
-      dimInner = rank - 1;
-      repCluster[dimInner] = 1;
-
-      unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
-      if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
-          (opsPerChannel == 2 && elemSizeInBits == 16) ||
-          (opsPerChannel == 1 && elemSizeInBits == 32)) {
-        loadResultElemType = elemSizeInBits == 32 ? i32_ty : i16_ty;
-        packedElemsNum = opsPerChannel == 4 ? 2 : 1;
-        usePackedType = true;
-      } else if (opsPerChannel == 4) {
-        packedElemsNum = 2;
-        unsigned packedBitWidht = elemSizeInBits * packedElemsNum;
-        if (packedBitWidht > 64) {
-          // Be conservative to avoid the packed type exceeds 64 bits.
-          return failure();
-        }
-        // Need to pack two column into one to work around vectorization
-        // limitation.
-        loadResultElemType = int_ty(packedBitWidht);
-        usePackedType = true;
-      }
-    } break;
-    case DpasEncodingAttr::OpIdx::OperandB: {
-      warpShape = std::move(dpasLayout.getShapeB());
-      dpasInstShape = std::move(dpasLayout.getDPASInstShapeB());
-      dimOuter = rank - 1;
-      dimInner = rank - 2;
-      repCluster[dimInner] = 1;
-
-      unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
-      if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
-          (opsPerChannel == 2 && elemSizeInBits == 16) ||
-          (opsPerChannel == 1 && elemSizeInBits == 32)) {
-        loadResultElemType = i32_ty;
-        packedElemsNum = opsPerChannel;
-        usePackedType = true;
-      }
-    } break;
-    case DpasEncodingAttr::OpIdx::OperandC:
-      warpShape = std::move(dpasLayout.getShapeC());
-      dpasInstShape = std::move(dpasLayout.getDPASInstShapeC());
-      dimOuter = rank - 2;
-      dimInner = rank - 1;
-      usePackedType = false;
-      break;
-    default:
-      llvm_unreachable("unknown DPAS operands index type.");
-      break;
-    }
-    unsigned elemsPerLanePerDPASInst =
-        product(dpasInstShape) / threadsPerWarp;
-    LLVMTypeConverter *typeConverter = getTypeConverter();
-    Type unpackedDPASOperandType = LLVM::getVectorType(
-        typeConverter->convertType(eltTy), elemsPerLanePerDPASInst);
-
-    unsigned packedElemsPerLanePerDPASInst =
-        elemsPerLanePerDPASInst / packedElemsNum;
-    Type packedDPASOperandType =
-        LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst);
-
-    unsigned outerDimTileNum =
-        mlir::ceil(tensorShape[dimOuter], warpShape[dimOuter]);
-    unsigned outerDimWarpNum =
-        std::min(warpsPerCTA[dimOuter], outerDimTileNum);
-    Value outerDimWarpId =
-        b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum));
-    unsigned innerDimRequiredWarpNum =
-        mlir::ceil(tensorShape[dimInner], warpShape[dimInner]);
-    unsigned innerDimWarpNum =
-        std::min(warpsPerCTA[dimInner], innerDimRequiredWarpNum);
-
-    // Step 3: Get the tile size of load.
-    tileWidth = dpasInstShape[threadOrder[rank - 2]];
-    tileHeight = dpasInstShape[threadOrder[rank - 1]];
-    vBlocks = 1;
-    unsigned numOperandsOuterDimPerLoad = 1;
-    unsigned numOperandsInnerDimPerLoad = 1;
-    unsigned maskConstancyHor = 1, maskConstancyVer = 1;
-    unsigned instWidth = dpasInstShape[threadOrder[rank - 2]];
-    unsigned instHeight = dpasInstShape[threadOrder[rank - 1]];
-
-    std::map, Value> ptrs;
-    std::map, Value> masks;
-    std::map, Value> others;
 
     SmallVector maskElems;
     Value llMask = adaptor.getMask();
     // Get the LLVM values for mask
@@ -2690,9 +2534,10 @@ struct LoadOpToBlockIOConversion
       auto axisInfo = const_cast(axisAnalysisPass)
                           .getAxisInfo(mask);
+      unsigned maskConstancyHor = 1, maskConstancyVer = 1;
       if (axisInfo) {
-        maskConstancyHor = axisInfo->getConstancy(rank - 1);
-        maskConstancyVer = axisInfo->getConstancy(rank - 2);
+        maskConstancyHor = axisInfo->getConstancy(colDim);
+        maskConstancyVer = axisInfo->getConstancy(rowDim);
         // The mask constancy has to be power of 2 for block IO.
         if (!llvm::isPowerOf2_64(maskConstancyHor) ||
             !llvm::isPowerOf2_64(maskConstancyVer))
@@ -2700,12 +2545,14 @@ struct LoadOpToBlockIOConversion
       }
 
       // Check the constancy of the mask support to load the memory in 2D block.
-      if (!(maskConstancyHor >= instWidth && maskConstancyVer >= instHeight))
+      if (!(maskConstancyHor >= (tileWidth * numPackedVals) &&
+            maskConstancyVer >= tileHeight))
         return failure();
-    } else {
-      // no mask
-      maskConstancyHor = std::numeric_limits::max();
-      maskConstancyVer = std::numeric_limits::max();
+
+      // Adjust vBlock to fit the constancy of mask.
+      vBlocks = std::min(vBlocks, mlir::ceil(maskConstancyHor,
+                                             tileWidth * numPackedVals));
+      assert(llvm::isPowerOf2_64(vBlocks) && "vBlocks has to be power of 2");
     }
 
     // Get the LLVM values for `other`
@@ -2735,109 +2582,14 @@ struct LoadOpToBlockIOConversion
       otherElems = unpackLLElements(loc, llOther, rewriter);
     }
 
-    // re-arrange the ptrs and masks to for large 2D block IO.
-    // Layout is unrelated to the scalar type.
-    SmallVector> offsets =
-        mlir::emitOffsetForLayout(encoding, tensorType);
-    for (size_t i = 0; i < ptrElems.size(); ++i) {
-      SmallVector offset = offsets[i];
-      ptrs[offset] = ptrElems[i];
-      if (llMask)
-        masks[offset] = maskElems[i];
-      if (otherElems.size())
-        others[offset] = otherElems[i];
-    }
-
-    unsigned numOperandsPer2DLoadM, numOperandsPer2DLoadN;
-    assert(!isTransposeRequired && "Expected no transpose requirement");
-    // Set the number of operands per 2D load to the maximum number from
-    // layout information. The number will be adjusted to fit the
-    // tensor pointers's shape, constancy and contiguity.
-    switch (opIdx) {
-    case DpasEncodingAttr::OpIdx::OperandA:
-      numOperandsPer2DLoadM = repCluster[dimOuter];
-      numOperandsPer2DLoadN = numReps[dimInner];
-      break;
-    case DpasEncodingAttr::OpIdx::OperandB:
-      numOperandsPer2DLoadM = numReps[dimInner];
-      numOperandsPer2DLoadN = repCluster[dimOuter];
-      break;
-    case DpasEncodingAttr::OpIdx::OperandC:
-      numOperandsPer2DLoadM = repCluster[dimOuter];
-      numOperandsPer2DLoadN = repCluster[dimInner];
-      break;
-    default:
-      llvm_unreachable("unknown DPAS operands index type.");
-      break;
-    }
-
-    // adjust the mask constancy to fit the 2D load.
-    numOperandsPer2DLoadM =
-        std::min(numOperandsPer2DLoadM, maskConstancyHor / instWidth);
-    numOperandsPer2DLoadN =
-        std::min(numOperandsPer2DLoadN, maskConstancyVer / instHeight);
-
-    // PVC 2D load supports 32 rows at most. Load multiple dot operands in by
-    // enlarging the tileHeight.
-    numOperandsPer2DLoadM =
-        std::min(numOperandsPer2DLoadM,
-                 static_cast(MAX_TILE_HEIGHT / tileHeight));
-
-    // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
-    // by enlarging the vBlocks.
-    unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
-    if (totalBytesPerRowPerDPASOp > MAX_WIDTH)
-      return failure();
-    numOperandsPer2DLoadN =
-        std::min(numOperandsPer2DLoadN, MAX_WIDTH / totalBytesPerRowPerDPASOp);
-    // vBlocks has HW limitation of 4.
-    numOperandsPer2DLoadN = std::min(numOperandsPer2DLoadN, 4u);
-
-    tileHeight = instHeight * numOperandsPer2DLoadM;
-    tileWidth = instWidth;
-    vBlocks = numOperandsPer2DLoadN;
-
-    numOperandsOuterDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB
-                                     ? numOperandsPer2DLoadM
-                                     : numOperandsPer2DLoadN;
-    numOperandsInnerDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB
-                                     ? numOperandsPer2DLoadN
-                                     : numOperandsPer2DLoadM;
-
-    unsigned numLoadPerOutRepCluster =
-        mlir::ceil(repCluster[dimOuter], numOperandsOuterDimPerLoad);
-    unsigned numLoadPerInnerRepCluster =
-        mlir::ceil(repCluster[dimInner], numOperandsInnerDimPerLoad);
-
-    unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst *
-                                numOperandsOuterDimPerLoad *
-                                numOperandsInnerDimPerLoad;
-    Type load2DGenXType =
-        LLVM::getVectorType(loadResultElemType, numValuesPerLoad);
-
-    // Step 4: Generates the load instruction.
-    // The stride for the tile replicates.
-    unsigned numRepOuter;
-    unsigned numRepInner;
-    unsigned repOuterStride = warpShape[dimOuter] * outerDimWarpNum;
-    unsigned repInnerStride;
-    switch (opIdx) {
-    case DpasEncodingAttr::OpIdx::OperandA:
-    case DpasEncodingAttr::OpIdx::OperandB:
-      numRepOuter = numReps[dimOuter];
-      numRepInner =
-          mlir::ceil(numReps[dimInner], numOperandsInnerDimPerLoad);
-      repInnerStride = warpShape[dimInner] * numOperandsInnerDimPerLoad;
-      break;
-    case DpasEncodingAttr::OpIdx::OperandC:
-      numRepOuter = numReps[dimOuter];
-      numRepInner = numReps[dimInner];
-      repInnerStride = warpShape[dimInner] * innerDimWarpNum;
-      break;
-    default:
-      llvm_unreachable("unknown DPAS operands index type.");
-      break;
-    }
+    unsigned threadsPerWarp =
+        TritonGPUDialect::getThreadsPerWarp(op->getParentOfType());
+    int64_t numElemsPerLoad = mlir::ceil(
+        tileHeight * tileWidth * numPackedVals * vBlocks, (int)threadsPerWarp);
+    unsigned numValuesPerLoad = mlir::ceil((int)numElemsPerLoad, numPackedVals);
+    Type packedType = IntegerType::get(ctx, packedElemSizeInBits);
+    Type load2DGenXType = LLVM::getVectorType(packedType, numValuesPerLoad);
+    Type unpackedType = LLVM::getVectorType(eltTy, numElemsPerLoad);
 
     Value pitch =
         getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1);
@@ -2848,8 +2600,8 @@ struct LoadOpToBlockIOConversion
     int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
     unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);
     Value baseHeight = b.i32_val(baseHeightInt);
-    Value baseWidth =
-        b.i32_val(std::max(64u, vBlocks * tileWidth * (elemSizeInBits / 8)));
+    Value baseWidth = b.i32_val(
+        std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
 
     StringAttr kRegister = str_attr("register");
     StringAttr kLane = str_attr("lane");
@@ -2865,276 +2617,171 @@ struct LoadOpToBlockIOConversion
         {{kRegister, llEncoding->getInDimSize(kRegister)}},
         /*requireSurjective=*/true);
 
-    const unsigned originalElemBits = elemSizeInBits;
-
-    LDBG("Block io tile shape: ["
-         << tileHeight << ", " << tileWidth << "], vblocks: " << vBlocks
-         << ", numOperandsPerLoad: ["
-         << (opIdx != DpasEncodingAttr::OpIdx::OperandB
-                 ? numOperandsOuterDimPerLoad
-                 : numOperandsInnerDimPerLoad)
-         << ", "
-         << (opIdx != DpasEncodingAttr::OpIdx::OperandB
-                 ? numOperandsInnerDimPerLoad
-                 : numOperandsOuterDimPerLoad)
-         << "], number loads per repCluster: ["
-         << (opIdx != DpasEncodingAttr::OpIdx::OperandB
-                 ? numLoadPerOutRepCluster
-                 : numLoadPerInnerRepCluster)
-         << ", "
-         << (opIdx != DpasEncodingAttr::OpIdx::OperandB
-                 ? numLoadPerInnerRepCluster
-                 : numLoadPerOutRepCluster)
-         << "], number repCluster: ["
-         << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepOuter
-                                                        : numRepInner)
-         << ", "
-         << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepInner
-                                                        : numRepOuter)
-         << "]");
-
-    ValueTable loadVals;
-    for (int inner = 0; inner < numRepInner; ++inner) {
-      for (int outer = 0; outer < numRepOuter; ++outer) {
-        for (int loadInner = 0; loadInner < numLoadPerInnerRepCluster;
-             ++loadInner) {
-          for (int loadOuter = 0; loadOuter < numLoadPerOutRepCluster;
-               ++loadOuter) {
-            unsigned offsetOuter =
-                outer * repOuterStride + loadOuter * dpasInstShape[dimOuter] *
-                                             numOperandsOuterDimPerLoad;
-            unsigned offsetInner =
-                inner * repInnerStride + loadInner * dpasInstShape[dimInner] *
-                                             numOperandsInnerDimPerLoad;
-            unsigned offsetM =
-                (opIdx != DpasEncodingAttr::OpIdx::OperandB ? offsetOuter
                                                             : offsetInner);
-            unsigned offsetN =
-                (opIdx != DpasEncodingAttr::OpIdx::OperandB ? offsetInner
                                                             : offsetOuter);
-
-            LDBG("Block load iterator: inner: "
                 << inner << ", outer:" << outer << ", loadInner:" << loadInner
                 << ", loadOuter:" << loadOuter << " offset: [" << offsetM
                 << ", " << offsetN << "]");
-
-            Value offsetY = b.i32_val(0);
-            Value pred;
-            if (llMask) {
-              assert(masks.size() && "Invalid size of the masks.");
-              pred = targetInfo.shuffleIdx(rewriter, loc,
                                           masks[{offsetM, offsetN}], 0);
-              // We leverage the GPU block I/O hardware out-of-bound protection
-              // feature by setting the offset to an invalid value when 'pred'
-              // is false (the HW will not read out-of-bounds values). Later on,
-              // after issuing the 2d block read operation, we will select the
-              // result of the load only if the mask evaluate to true, otherwise
-              // we will use 'other'.
-              offsetY = b.select(pred, offsetY, baseHeight);
-            }
-
-            // Use the top-left address of the block to load the data.
-            Value addrElem =
-                b.bitcast(ptrs[{offsetM, offsetN}], ptr_ty(ctx, 1 /*global*/));
-            addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
-
-            Value ret = rewriter.create(
-                loc, load2DGenXType,
-                /*ptr*/ addrElem,
-                /*base_width*/ baseWidth,
-                /*base_height*/ baseHeight,
-                /*base_pitch*/ pitch,
-                /*x*/ b.i32_val(0),
-                /*y*/ offsetY,
-                /*elem_size_in_bits*/ elemSizeInBits,
-                /*tile_width*/ tileWidth,
-                /*tile_height*/ tileHeight,
-                /*v_blocks*/ vBlocks,
-                /*transpose*/ false,
-                /*vnni_transform*/
-                (usePackedType && opIdx == DpasEncodingAttr::OpIdx::OperandB &&
-                 !isTransposeRequired && originalElemBits != 32));
-
-            // When strides[0] is 0, we only want to load the first row, so we
-            // set the base height to be 1. If tile height is bigger than 1,
-            // then only the first row contain valid data. To ensure the entire
-            // tile is filled with valid data, we must replicate the first row
-            // throughout the tile.
-            if (baseHeightInt < tileHeight && baseHeightInt == 1) {
-              unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks;
-              SmallVector shuffleIndices(numValuesPerLoad);
-
-              // Create a vector to store the data of the first index of each
-              // matrix.
-              VectorType vecTy = vec_ty(loadResultElemType, vBlocks);
-              Value firstIndexVec = b.undef(vecTy);
-
-              for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad;
-                   ++valueIndex) {
-                unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix;
-                // Handle case where an index spans two rows.
-                if (valueIndex % numIndicesPerMatrix == 0) {
-                  Value oldVal = b.extract_element(ret, b.i32_val(valueIndex));
-                  Value newVal = oldVal;
-                  if (tileWidth < threadsPerWarp) {
-                    assert(tileWidth * 2 == threadsPerWarp &&
-                           "Expecting tileWidth to be 2x threadsPerWarp");
-                    Value threadId = getThreadId(rewriter, loc);
-                    newVal = targetInfo.shuffleIdx(
-                        rewriter, loc, oldVal,
-                        b.urem(threadId, b.i32_val(tileWidth)));
-                  }
-                  firstIndexVec =
-                      b.insert_element(firstIndexVec.getType(), firstIndexVec,
-                                       newVal, b.i32_val(firstIndexVecIdx));
-                }
+    bool useVNNIFormat = false;
+    Type packedDPASOperandType;
+    if (hasDotDpasEncoding(tensorType)) {
+      DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
+      auto dpasLayout = getDpasLayout(tensorType);
+      if (opIdx == DpasEncodingAttr::OpIdx::OperandB) {
+        unsigned elemsPerLanePerDPASInst =
+            product(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
+        // Block 2D contain at least one DotOp B.
+        if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
+          unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
+          unsigned sysDepth = dpasLayout.getSystolicDepth();
+          if (tileHeight >= (opsPerChannel * sysDepth) &&
+              ((opsPerChannel == 4 && elemSizeInBits == 8) ||
+               (opsPerChannel == 2 && elemSizeInBits == 16))) {
+            // Use the VNNI packing format for DotOp B layout.
+            numValuesPerLoad = numElemsPerLoad / opsPerChannel;
+            packedType = i32_ty;
+            load2DGenXType = LLVM::getVectorType(packedType, numValuesPerLoad);
+            packedDPASOperandType = LLVM::getVectorType(
+                packedType, elemsPerLanePerDPASInst / opsPerChannel);
+            useVNNIFormat = true;
+          } else {
+            packedDPASOperandType =
+                LLVM::getVectorType(IntegerType::get(ctx, packedElemSizeInBits),
+                                    elemsPerLanePerDPASInst / numPackedVals);
+          }
+          unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
+        }
+      }
+    }
+    SmallVector unpackedLoadedVals(numElems);
+    for (size_t elemIdx = 0; elemIdx < numElems; elemIdx += numElemsPerLoad) {
+      unsigned registerIdx = regMapping.apply({{kRegister, elemIdx}})[0].second;
 
-                shuffleIndices[valueIndex] = firstIndexVecIdx;
-              }
-              DenseI32ArrayAttr attr =
-                  rewriter.getDenseI32ArrayAttr(shuffleIndices);
-              ret = rewriter.create(
-                  loc, load2DGenXType, firstIndexVec, firstIndexVec, attr);
-            }
+      // Use the top-left address of the block to load the data.
+      Value addrElem = ptrElems[registerIdx];
+      addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
 
-            if (others.size()) {
-              assert(masks.size() == others.size() &&
-                     "The mask value has to be provided when "
-                     "the other value is provided.");
-              VectorType vecTy =
-                  vec_ty(eltTy, numValuesPerLoad * packedElemsNum);
+      Value offsetX = b.i32_val(0);
+      Value offsetY = b.i32_val(0);
+      Value pred;
+      if (maskElems.size()) {
+        pred = targetInfo.shuffleIdx(rewriter, loc, maskElems[registerIdx], 0);
+        // We leverage the GPU block I/O hardware out-of-bound protection
+        // feature by setting the offset to an invalid value when 'pred'
+        // is false (the HW will not read out-of-bounds values). Later on,
+        // after issuing the 2d block read operation, we will select the
+        // result of the load only if the mask evaluate to true, otherwise
+        // we will use 'other'.
+        offsetY = b.select(pred, offsetY, baseHeight);
+      }
 
-              Value v = b.undef(vecTy);
-              unsigned nWords = 0;
-              for (int vblk = 0; vblk < vBlocks; ++vblk)
-                for (int i = 0; i < tileHeight; ++i) {
-                  unsigned numColPerPackedValue =
-                      opIdx == DpasEncodingAttr::OpIdx::OperandA
-                          ? packedElemsNum
-                          : 1;
-                  unsigned numPackedValuesPerRow = mlir::ceil(
-                      (tileWidth / numColPerPackedValue), threadsPerWarp);
-                  for (int col = 0; col < numPackedValuesPerRow; ++col) {
-                    for (int packedCol = 0; packedCol < numColPerPackedValue;
-                         ++packedCol) {
-                      unsigned N = packedCol +
-                                   col * threadsPerWarp * numColPerPackedValue +
-                                   vblk * tileWidth + offsetN;
-                      unsigned M = i + offsetM;
-                      Value falseVal = others[{M, N}];
-                      Value sVal = createIndexAttrConstant(
-                          rewriter, loc, typeConverter->getIndexType(),
-                          nWords++);
-                      v = b.insert_element(vecTy, v, falseVal, sVal);
-                    }
-                  }
-                }
-              Value others = b.bitcast(v, load2DGenXType);
-              ret = b.select(pred, ret, others);
+      Value ret = rewriter.create(
+          loc, load2DGenXType,
+          /*ptr*/ addrElem,
+          /*base_width*/ baseWidth,
+          /*base_height*/ baseHeight,
+          /*base_pitch*/ pitch,
+          /*x*/ offsetX,
+          /*y*/ offsetY,
+          /*elem_size_in_bits*/ packedElemSizeInBits,
+          /*tile_width*/ tileWidth,
+          /*tile_height*/ tileHeight,
+          /*v_blocks*/ vBlocks,
+          /*transpose*/ false,
+          /*vnni_transform*/ useVNNIFormat);
+
+      // When strides[0] is 0, we only want to load the first row, so we
+      // set the base height to be 1. If tile height is bigger than 1,
+      // then only the first row contain valid data. To ensure the entire
+      // tile is filled with valid data, we must replicate the first row
+      // throughout the tile.
+      if (baseHeightInt < tileHeight && baseHeightInt == 1) {
+        unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks;
+        SmallVector shuffleIndices(numValuesPerLoad);
+
+        // Create a vector to store the data of the first index of each
+        // matrix.
+        VectorType vecTy = vec_ty(packedType, vBlocks);
+        Value firstIndexVec = b.undef(vecTy);
+
+        for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad;
+             ++valueIndex) {
+          unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix;
+          // Handle case where an index spans two rows.
+          if (valueIndex % numIndicesPerMatrix == 0) {
+            Value oldVal = b.extract_element(ret, b.i32_val(valueIndex));
+            Value newVal = oldVal;
+            if (tileWidth < threadsPerWarp) {
+              assert(tileWidth * 2 == threadsPerWarp &&
+                     "Expecting tileWidth to be 2x threadsPerWarp");
+              Value threadId = getThreadId(rewriter, loc);
+              newVal =
+                  targetInfo.shuffleIdx(rewriter, loc, oldVal,
+                                        b.urem(threadId, b.i32_val(tileWidth)));
             }
+            firstIndexVec =
+                b.insert_element(firstIndexVec.getType(), firstIndexVec, newVal,
+                                 b.i32_val(firstIndexVecIdx));
+          }
 
-            unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB
-                                        ? numOperandsOuterDimPerLoad
-                                        : numOperandsInnerDimPerLoad;
-            unsigned numOperandsN = opIdx != DpasEncodingAttr::OpIdx::OperandB
-                                        ? numOperandsInnerDimPerLoad
-                                        : numOperandsOuterDimPerLoad;
-
-            // Split the return matrix by large 2d block io size into multiple
-            // DPAS operands.
-            assert(numOperandsN >= vBlocks &&
-                   "numOperandsN has to be >= vBlocks");
-            unsigned numOperandsPerVBlockN = numOperandsN / vBlocks;
-            for (int vblk = 0; vblk < vBlocks; ++vblk)
-              for (int row = 0; row < numOperandsM; ++row)
-                for (int col = 0; col < numOperandsPerVBlockN; ++col) {
-
-                  unsigned operandStartOffset = (vblk * numOperandsM + row) *
-                                                numOperandsPerVBlockN *
-                                                packedElemsPerLanePerDPASInst;
-
-                  SmallVector indices(packedElemsPerLanePerDPASInst);
-                  for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst;
-                       ++elemIdx) {
-                    indices[elemIdx] = operandStartOffset +
-                                       elemIdx * numOperandsPerVBlockN + col;
-                  }
+          shuffleIndices[valueIndex] = firstIndexVecIdx;
+        }
+        DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(shuffleIndices);
+        ret = rewriter.create(
+            loc, load2DGenXType, firstIndexVec, firstIndexVec, attr);
+      }
 
-                  LLVM_DEBUG({
-                    DBGS() << "shuffle idx: [";
-                    for (int elemIdx = 0;
-                         elemIdx < packedElemsPerLanePerDPASInst; ++elemIdx) {
-                      llvm::dbgs() << indices[elemIdx] << ", ";
-                    }
-                    llvm::dbgs() << "]\n";
-                  });
+      unsigned numElemsPerUnpackedType =
+          LLVM::getVectorNumElements(unpackedType).getKnownMinValue();
+      unsigned numValsPerDPASOperand =
+          packedDPASOperandType
+              ? LLVM::getVectorNumElements(packedDPASOperandType)
+                    .getKnownMinValue()
+              : numValuesPerLoad;
+      unsigned numOperandsPerLoad = numValuesPerLoad / numValsPerDPASOperand;
+      for (size_t opsIdx = 0; opsIdx < numOperandsPerLoad; ++opsIdx) {
+        Value unpackedVal;
+        if (numValsPerDPASOperand != numValuesPerLoad) {
+          // Decompose the return value to multiple DPAS operands.
+          SmallVector indices(numValsPerDPASOperand);
+          for (int i = 0; i < numValsPerDPASOperand; ++i) {
+            indices[i] = opsIdx * numValsPerDPASOperand + i;
+          }
+          DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices);
+          Value dpasOperand = rewriter.create(
+              loc, packedDPASOperandType, ret, ret, attr);
 
-                  DenseI32ArrayAttr attr =
-                      rewriter.getDenseI32ArrayAttr(indices);
-                  Value loadVal = rewriter.create(
-                      loc, packedDPASOperandType, ret, ret, attr);
+          unpackedVal = b.bitcast(dpasOperand, unpackedType);
 
-                  // Save the decomposed vals to the map;
-                  switch (opIdx) {
-                  case DpasEncodingAttr::OpIdx::OperandC:
-                  case DpasEncodingAttr::OpIdx::OperandA: {
-                    unsigned o = outer * numLoadPerOutRepCluster *
-                                     numOperandsOuterDimPerLoad +
-                                 loadOuter * numOperandsOuterDimPerLoad + row;
-                    unsigned i = inner * numLoadPerInnerRepCluster *
-                                     numOperandsInnerDimPerLoad +
-                                 loadInner * numOperandsInnerDimPerLoad +
-                                 vblk * numOperandsPerVBlockN + col;
+        } else {
+          unpackedVal = b.bitcast(ret, unpackedType);
+        }
 
-                    LDBG("insert: [" << o << ", " << i << "]");
-                    loadVals[{o, i}] =
-                        b.bitcast(loadVal, unpackedDPASOperandType);
-                  } break;
-                  case DpasEncodingAttr::OpIdx::OperandB: {
-                    unsigned o = outer * numLoadPerOutRepCluster *
-                                     numOperandsOuterDimPerLoad +
-                                 loadOuter * numOperandsOuterDimPerLoad +
-                                 vblk * numOperandsPerVBlockN + col;
-                    unsigned i = inner * numOperandsInnerDimPerLoad + row;
-                    LDBG("insert: [" << o << ", " << i << "]");
-                    loadVals[{o, i}] =
-                        b.bitcast(loadVal, unpackedDPASOperandType);
-                  } break;
-                  default: {
-                    llvm_unreachable("unknown DPAS operands index type.");
-                  } break;
-                  }
-                }
+        if (otherElems.size()) {
+          assert(maskElems.size() == otherElems.size() &&
+                 "Invalid size of the masks");
+          Value other = b.undef(unpackedType);
+          for (size_t i = 0; i < numElemsPerUnpackedType; ++i) {
+            unsigned registerIdx =
+                regMapping
+                    .apply(
+                        {{kRegister,
+                          elemIdx + opsIdx * numElemsPerUnpackedType + i}})[0]
+                    .second;
+            Value falseVal = otherElems[registerIdx];
+            other = b.insert_element(other, falseVal, b.i32_val(i));
          }
+          unpackedVal = b.select(pred, unpackedVal, other);
        }
-          }
-        }
-    // Step 5: Unpack the load values.
-    // Extract the value returned by the load ops. And put the values in the
-    // expected order for the layout.
-    SmallVector unpackedLoadedVals;
-    for (int outer = 0; outer < numReps[dimOuter]; ++outer) {
-      for (int inner = 0; inner < numReps[dimInner]; ++inner) {
-        for (int repOuter = 0; repOuter < repCluster[dimOuter]; ++repOuter) {
-          for (int repInner = 0; repInner < repCluster[dimInner]; ++repInner) {
-            unsigned o = outer * repCluster[dimOuter] + repOuter;
-            unsigned i = inner * repCluster[dimInner] + repInner;
-            LDBG("extract: [" << o << ", " << i << "]");
-            Value loadVal = loadVals.at({o, i});
-            VectorType loadTy = cast(loadVal.getType());
-            for (int i = 0; i < loadTy.getNumElements(); ++i) {
-              auto val = b.extract_element(loadVal, b.i32_val(i));
-              unpackedLoadedVals.push_back(val);
-            }
-            loadVals.erase({o, i});
-          }
+        for (int i = 0; i < numElemsPerUnpackedType; ++i) {
+          unsigned registerIdx =
+              regMapping
+                  .apply({{kRegister,
+                           elemIdx + opsIdx * numElemsPerUnpackedType + i}})[0]
+                  .second;
+          unpackedLoadedVals[registerIdx] =
+              b.extract_element(unpackedVal, b.i32_val(i));
        }
      }
    }
-    assert(loadVals.empty() && "not all loaded values is unpacked.");
-
+    auto typeConverter = getTypeConverter();
     Type llvmResultStructTy = typeConverter->convertType(op.getType());
     Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
                                         rewriter, llvmResultStructTy);