@@ -503,6 +503,90 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
503
503
intel::noUnwindWillReturnAttrs);
504
504
}
505
505
506
+ template <typename OpTy>
507
+ static void
508
+ validateMatrix2DBlockParameters (OpTy op,
509
+ mlir::ConversionPatternRewriter &rewriter) {
510
+ using namespace mlir ;
511
+ using namespace mlir ::LLVM;
512
+
513
+ Location loc = op->getLoc ();
514
+ auto b = TritonLLVMOpBuilder (loc, rewriter);
515
+ MLIRContext *ctx = rewriter.getContext ();
516
+
517
+ Value baseWidth = op.getBaseWidth ();
518
+ Value baseHeight = op.getBaseHeight ();
519
+ Value basePitch = op.getBasePitch ();
520
+ Value x = op.getX ();
521
+ unsigned elemSize = op.getElemSizeInBits () / 8 ;
522
+
523
+ if (!baseWidth.getType ().isInteger (32 ))
524
+ baseWidth = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), baseWidth);
525
+ if (!baseHeight.getType ().isInteger (32 ))
526
+ baseHeight =
527
+ rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), baseHeight);
528
+ if (!basePitch.getType ().isInteger (32 ))
529
+ basePitch = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), basePitch);
530
+ if (!x.getType ().isInteger (32 ))
531
+ x = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), x);
532
+
533
+ Value c0 = b.i32_val (0 );
534
+ Value c4 = b.i32_val (4 );
535
+ Value c64 = b.i32_val (64 );
536
+ Value c24m1 = b.i32_val ((1u << 24 ) - 1 ); // 2^24 - 1
537
+ Value cElemSize = b.i32_val (elemSize);
538
+
539
+ // ===== validation predicates =====
540
+
541
+ // width!=0 && width<2^24 && width%4==0
542
+ Value wZero = rewriter.create <ICmpOp>(loc, ICmpPredicate::eq, baseWidth, c0);
543
+ Value wTooLarge =
544
+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ugt, baseWidth, c24m1);
545
+ Value wRem = rewriter.create <URemOp>(loc, baseWidth, c4);
546
+ Value wNotAligned = rewriter.create <ICmpOp>(loc, ICmpPredicate::ne, wRem, c0);
547
+ Value badWidth = rewriter.create <OrOp>(
548
+ loc, wZero, rewriter.create <OrOp>(loc, wTooLarge, wNotAligned));
549
+
550
+ // height!=0 && height<2^24
551
+ Value hZero = rewriter.create <ICmpOp>(loc, ICmpPredicate::eq, baseHeight, c0);
552
+ Value hTooLarge =
553
+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ugt, baseHeight, c24m1);
554
+ Value badHeight = rewriter.create <OrOp>(loc, hZero, hTooLarge);
555
+
556
+ // pitch >= 64
557
+ Value badPitch =
558
+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ult, basePitch, c64);
559
+
560
+ // x*elemSize % 4 == 0
561
+ Value offsetBytes = rewriter.create <MulOp>(loc, x, cElemSize);
562
+ Value offsetRem = rewriter.create <URemOp>(loc, offsetBytes, c4);
563
+ Value badOffset =
564
+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ne, offsetRem, c0);
565
+
566
+ // assert on any
567
+ Value anyBad = rewriter.create <OrOp>(
568
+ loc, badWidth,
569
+ rewriter.create <OrOp>(loc, badHeight,
570
+ rewriter.create <OrOp>(loc, badPitch, badOffset)));
571
+
572
+ Block *curBlock = rewriter.getBlock ();
573
+ auto ip = rewriter.getInsertionPoint ();
574
+ Block *contBlock = rewriter.splitBlock (curBlock, ip);
575
+ Region *region = contBlock->getParent ();
576
+ Block *trapBlock = rewriter.createBlock (region, Region::iterator (contBlock));
577
+
578
+ // TODO: use __assert_fail instead of llvm.intr.trap
579
+ rewriter.setInsertionPointToStart (trapBlock);
580
+ rewriter.create <Trap>(loc);
581
+ rewriter.create <UnreachableOp>(loc);
582
+
583
+ rewriter.setInsertionPointToEnd (curBlock);
584
+ rewriter.create <CondBrOp>(loc, anyBad, trapBlock, ValueRange{}, contBlock,
585
+ ValueRange{});
586
+
587
+ rewriter.setInsertionPointToStart (contBlock);
588
+ }
589
+
506
590
namespace {
507
591
508
592
// ===----------------------------------------------------------------------===//
@@ -636,6 +720,8 @@ struct TritonMatrix2DBlockLoadLowering
636
720
LogicalResult
637
721
matchAndRewrite (TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
638
722
ConversionPatternRewriter &rewriter) const override {
723
+ validateMatrix2DBlockParameters (op, rewriter);
724
+
639
725
if (!isSPVBuiltinAvailable (op)) {
640
726
// Fallback to GenISA interface.
641
727
rewriter.replaceOp (op, createGenISA2DBlockRead (op, rewriter));
@@ -711,6 +797,8 @@ struct TritonMatrix2DBlockStoreLowering
711
797
LogicalResult
712
798
matchAndRewrite (TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
713
799
ConversionPatternRewriter &rewriter) const override {
800
+ validateMatrix2DBlockParameters (op, rewriter);
801
+
714
802
if (!isSPVBuiltinAvailable (op)) {
715
803
// Fallback to GenISA interface.
716
804
rewriter.replaceOp (op, createGenISA2DBlockWrite (op, rewriter));
@@ -785,6 +873,8 @@ struct TritonMatrix2DBlockPrefetchLowering
785
873
LogicalResult
786
874
matchAndRewrite (TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
787
875
ConversionPatternRewriter &rewriter) const override {
876
+ validateMatrix2DBlockParameters (op, rewriter);
877
+
788
878
if (!isSPVBuiltinAvailable (op)) {
789
879
// Fallback to GenISA interface.
790
880
rewriter.replaceOp (op, createGenISA2DBlockPrefetch (op, rewriter));
0 commit comments