@@ -503,6 +503,82 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
503
503
intel::noUnwindWillReturnAttrs);
504
504
}
505
505
506
+ template <typename OpTy>
507
+ static void validateMatrix2DBlockParameters (OpTy op,
508
+ mlir::ConversionPatternRewriter &rewriter) {
509
+ using namespace mlir ;
510
+ using namespace mlir ::LLVM;
511
+
512
+ Location loc = op->getLoc ();
513
+ auto b = TritonLLVMOpBuilder (loc, rewriter);
514
+ MLIRContext *ctx = rewriter.getContext ();
515
+
516
+ Value baseWidth = op.getBaseWidth ();
517
+ Value baseHeight = op.getBaseHeight ();
518
+ Value basePitch = op.getBasePitch ();
519
+ Value x = op.getX ();
520
+ unsigned elemSize = op.getElemSizeInBits () / 8 ;
521
+
522
+ if (!baseWidth.getType ().isInteger (32 ))
523
+ baseWidth = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), baseWidth);
524
+ if (!baseHeight.getType ().isInteger (32 ))
525
+ baseHeight = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), baseHeight);
526
+ if (!basePitch.getType ().isInteger (32 ))
527
+ basePitch = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), basePitch);
528
+ if (!x.getType ().isInteger (32 ))
529
+ x = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), x);
530
+
531
+ Value c0 = b.i32_val (0 );
532
+ Value c4 = b.i32_val (4 );
533
+ Value c64 = b.i32_val (64 );
534
+ Value c24m1 = b.i32_val ((1u <<24 ) - 1 ); // 2^24 - 1
535
+ Value cElemSize = b.i32_val (elemSize);
536
+
537
+ // ===== validation predicates =====
538
+
539
+ // width!=0 && width<2^24 && width%4==0
540
+ Value wZero = rewriter.create <ICmpOp>(loc, ICmpPredicate::eq, baseWidth, c0);
541
+ Value wTooLarge = rewriter.create <ICmpOp>(loc, ICmpPredicate::ugt, baseWidth, c24m1);
542
+ Value wRem = rewriter.create <URemOp>(loc, baseWidth, c4);
543
+ Value wNotAligned = rewriter.create <ICmpOp>(loc, ICmpPredicate::ne, wRem, c0);
544
+ Value badWidth = rewriter.create <OrOp>(loc, wZero,
545
+ rewriter.create <OrOp>(loc, wTooLarge, wNotAligned));
546
+
547
+ // height!=0 && height<2^24
548
+ Value hZero = rewriter.create <ICmpOp>(loc, ICmpPredicate::eq, baseHeight, c0);
549
+ Value hTooLarge = rewriter.create <ICmpOp>(loc, ICmpPredicate::ugt, baseHeight, c24m1);
550
+ Value badHeight = rewriter.create <OrOp>(loc, hZero, hTooLarge);
551
+
552
+ // pitch >= 64
553
+ Value badPitch = rewriter.create <ICmpOp>(loc, ICmpPredicate::ult, basePitch, c64);
554
+
555
+ // x*elemSize % 4 == 0
556
+ Value offsetBytes = rewriter.create <MulOp>(loc, x, cElemSize);
557
+ Value offsetRem = rewriter.create <URemOp>(loc, offsetBytes, c4);
558
+ Value badOffset = rewriter.create <ICmpOp>(loc, ICmpPredicate::ne, offsetRem, c0);
559
+
560
+ // assert on any
561
+ Value anyBad = rewriter.create <OrOp>(loc, badWidth,
562
+ rewriter.create <OrOp>(loc, badHeight,
563
+ rewriter.create <OrOp>(loc, badPitch, badOffset)));
564
+
565
+ Block *curBlock = rewriter.getBlock ();
566
+ auto ip = rewriter.getInsertionPoint ();
567
+ Block *contBlock = rewriter.splitBlock (curBlock, ip);
568
+ Region *region = contBlock->getParent ();
569
+ Block *trapBlock = rewriter.createBlock (region, Region::iterator (contBlock));
570
+
571
+ // TODO: use __assert_fail instead of llvm.intr.trap
572
+ rewriter.setInsertionPointToStart (trapBlock);
573
+ rewriter.create <Trap>(loc);
574
+ rewriter.create <UnreachableOp>(loc);
575
+
576
+ rewriter.setInsertionPointToEnd (curBlock);
577
+ rewriter.create <CondBrOp>(loc, anyBad, trapBlock, ValueRange{}, contBlock, ValueRange{});
578
+
579
+ rewriter.setInsertionPointToStart (contBlock);
580
+ }
581
+
506
582
namespace {
507
583
508
584
// ===----------------------------------------------------------------------===//
@@ -636,6 +712,8 @@ struct TritonMatrix2DBlockLoadLowering
636
712
LogicalResult
637
713
matchAndRewrite (TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
638
714
ConversionPatternRewriter &rewriter) const override {
715
+ validateMatrix2DBlockParameters (op, rewriter);
716
+
639
717
if (!isSPVBuiltinAvailable (op)) {
640
718
// Fallback to GenISA interface.
641
719
rewriter.replaceOp (op, createGenISA2DBlockRead (op, rewriter));
@@ -711,6 +789,8 @@ struct TritonMatrix2DBlockStoreLowering
711
789
LogicalResult
712
790
matchAndRewrite (TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
713
791
ConversionPatternRewriter &rewriter) const override {
792
+ validateMatrix2DBlockParameters (op, rewriter);
793
+
714
794
if (!isSPVBuiltinAvailable (op)) {
715
795
// Fallback to GenISA interface.
716
796
rewriter.replaceOp (op, createGenISA2DBlockWrite (op, rewriter));
@@ -785,6 +865,8 @@ struct TritonMatrix2DBlockPrefetchLowering
785
865
LogicalResult
786
866
matchAndRewrite (TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
787
867
ConversionPatternRewriter &rewriter) const override {
868
+ validateMatrix2DBlockParameters (op, rewriter);
869
+
788
870
if (!isSPVBuiltinAvailable (op)) {
789
871
// Fallback to GenISA interface.
790
872
rewriter.replaceOp (op, createGenISA2DBlockPrefetch (op, rewriter));
0 commit comments