Skip to content

Commit 6cf28c1

Browse files
committed
[intel] 2Dblock runtime HW checks
1 parent 4249232 commit 6cf28c1

File tree

3 files changed

+158
-0
lines changed

3 files changed

+158
-0
lines changed

python/test/unit/intel/test_block_load.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,42 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
207207
result_tor = fn_tor()
208208
result_tri = fn_tri()
209209
torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
210+
211+
212+
def test_block_load_asserts(tmp_path: pathlib.Path):
213+
ir = r"""
214+
module attributes {
215+
ttg.target = "xpu",
216+
"ttg.num-warps" = 32 : i32,
217+
"ttg.num-ctas" = 1 : i32,
218+
"ttg.threads-per-warp" = 16 : i32
219+
} {
220+
tt.func @dyn_block(
221+
%iptr : i64, %base_width : i32,
222+
%base_height : i32, %base_pitch : i32,
223+
%x : i32, %y : i32) {
224+
%p0 = llvm.inttoptr %iptr : i64 to !llvm.ptr
225+
226+
%0 = triton_gen.2Dblockload %p0, %base_width, %base_height,
227+
%base_pitch, %x, %y
228+
{ elem_size_in_bits = 8, tile_width = 8, tile_height = 8,
229+
v_blocks = 1, transpose = false,
230+
vnni_transform = false, cache_control = Default }
231+
: (!llvm.ptr, i32, i32, i32, i32, i32)
232+
-> vector<2xi16>
233+
tt.return
234+
}
235+
}
236+
"""
237+
238+
temp_file = tmp_path / "test_block_load_asserts.ttgir"
239+
temp_file.write_text(ir)
240+
kernel = triton.compile(str(temp_file))
241+
242+
a = torch.randn((256, 64), dtype=torch.float32, device="xpu")
243+
244+
import ctypes
245+
addr = ctypes.c_int64(a.data_ptr()).value
246+
247+
# TODO catch the assert from __assert_fail
248+
kernel[(1, 1, 1)](addr, 64, 64, 1, 0, 0)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: env TRITON_INTEL_2DBLOCK_ASSERT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
2+
3+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
4+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
5+
// CHECK: "llvm.intr.trap"() : () -> ()
6+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi16>
7+
llvm.return
8+
}
9+
}
10+
11+
// -----
12+
13+
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
14+
// CHECK: "llvm.intr.trap"() : () -> ()
15+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
16+
llvm.return
17+
}
18+
19+
// -----
20+
21+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
22+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
23+
// CHECK: "llvm.intr.trap"() : () -> ()
24+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
25+
llvm.return
26+
}
27+
}
28+
29+
// TODO: change checks to use __assert_fail

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,90 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
503503
intel::noUnwindWillReturnAttrs);
504504
}
505505

506+
template <typename OpTy>
507+
static void
508+
validateMatrix2DBlockParameters(OpTy op,
509+
mlir::ConversionPatternRewriter &rewriter) {
510+
using namespace mlir;
511+
using namespace mlir::LLVM;
512+
513+
Location loc = op->getLoc();
514+
auto b = TritonLLVMOpBuilder(loc, rewriter);
515+
MLIRContext *ctx = rewriter.getContext();
516+
517+
Value baseWidth = op.getBaseWidth();
518+
Value baseHeight = op.getBaseHeight();
519+
Value basePitch = op.getBasePitch();
520+
Value x = op.getX();
521+
unsigned elemSize = op.getElemSizeInBits() / 8;
522+
523+
if (!baseWidth.getType().isInteger(32))
524+
baseWidth = rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), baseWidth);
525+
if (!baseHeight.getType().isInteger(32))
526+
baseHeight =
527+
rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), baseHeight);
528+
if (!basePitch.getType().isInteger(32))
529+
basePitch = rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), basePitch);
530+
if (!x.getType().isInteger(32))
531+
x = rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), x);
532+
533+
Value c0 = b.i32_val(0);
534+
Value c4 = b.i32_val(4);
535+
Value c64 = b.i32_val(64);
536+
Value c24m1 = b.i32_val((1u << 24) - 1); // 2^24 - 1
537+
Value cElemSize = b.i32_val(elemSize);
538+
539+
// ===== validation predicates =====
540+
541+
// width!=0 && width<2^24 && width%4==0
542+
Value wZero = rewriter.create<ICmpOp>(loc, ICmpPredicate::eq, baseWidth, c0);
543+
Value wTooLarge =
544+
rewriter.create<ICmpOp>(loc, ICmpPredicate::ugt, baseWidth, c24m1);
545+
Value wRem = rewriter.create<URemOp>(loc, baseWidth, c4);
546+
Value wNotAligned = rewriter.create<ICmpOp>(loc, ICmpPredicate::ne, wRem, c0);
547+
Value badWidth = rewriter.create<OrOp>(
548+
loc, wZero, rewriter.create<OrOp>(loc, wTooLarge, wNotAligned));
549+
550+
// height!=0 && height<2^24
551+
Value hZero = rewriter.create<ICmpOp>(loc, ICmpPredicate::eq, baseHeight, c0);
552+
Value hTooLarge =
553+
rewriter.create<ICmpOp>(loc, ICmpPredicate::ugt, baseHeight, c24m1);
554+
Value badHeight = rewriter.create<OrOp>(loc, hZero, hTooLarge);
555+
556+
// pitch >= 64
557+
Value badPitch =
558+
rewriter.create<ICmpOp>(loc, ICmpPredicate::ult, basePitch, c64);
559+
560+
// x*elemSize % 4 == 0
561+
Value offsetBytes = rewriter.create<MulOp>(loc, x, cElemSize);
562+
Value offsetRem = rewriter.create<URemOp>(loc, offsetBytes, c4);
563+
Value badOffset =
564+
rewriter.create<ICmpOp>(loc, ICmpPredicate::ne, offsetRem, c0);
565+
566+
// assert on any
567+
Value anyBad = rewriter.create<OrOp>(
568+
loc, badWidth,
569+
rewriter.create<OrOp>(loc, badHeight,
570+
rewriter.create<OrOp>(loc, badPitch, badOffset)));
571+
572+
Block *curBlock = rewriter.getBlock();
573+
auto ip = rewriter.getInsertionPoint();
574+
Block *contBlock = rewriter.splitBlock(curBlock, ip);
575+
Region *region = contBlock->getParent();
576+
Block *trapBlock = rewriter.createBlock(region, Region::iterator(contBlock));
577+
578+
// TODO: use __assert_fail instead of llvm.intr.trap
579+
rewriter.setInsertionPointToStart(trapBlock);
580+
rewriter.create<Trap>(loc);
581+
rewriter.create<UnreachableOp>(loc);
582+
583+
rewriter.setInsertionPointToEnd(curBlock);
584+
rewriter.create<CondBrOp>(loc, anyBad, trapBlock, ValueRange{}, contBlock,
585+
ValueRange{});
586+
587+
rewriter.setInsertionPointToStart(contBlock);
588+
}
589+
506590
namespace {
507591

508592
//===----------------------------------------------------------------------===//
@@ -636,6 +720,8 @@ struct TritonMatrix2DBlockLoadLowering
636720
LogicalResult
637721
matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
638722
ConversionPatternRewriter &rewriter) const override {
723+
validateMatrix2DBlockParameters(op, rewriter);
724+
639725
if (!isSPVBuiltinAvailable(op)) {
640726
// Fallback to GenISA interface.
641727
rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter));
@@ -711,6 +797,8 @@ struct TritonMatrix2DBlockStoreLowering
711797
LogicalResult
712798
matchAndRewrite(TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
713799
ConversionPatternRewriter &rewriter) const override {
800+
validateMatrix2DBlockParameters(op, rewriter);
801+
714802
if (!isSPVBuiltinAvailable(op)) {
715803
// Fallback to GenISA interface.
716804
rewriter.replaceOp(op, createGenISA2DBlockWrite(op, rewriter));
@@ -785,6 +873,8 @@ struct TritonMatrix2DBlockPrefetchLowering
785873
LogicalResult
786874
matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
787875
ConversionPatternRewriter &rewriter) const override {
876+
validateMatrix2DBlockParameters(op, rewriter);
877+
788878
if (!isSPVBuiltinAvailable(op)) {
789879
// Fallback to GenISA interface.
790880
rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter));

0 commit comments

Comments
 (0)