Skip to content

Commit afe2cc0

Browse files
committed
[intel] 2Dblock runtime HW checks
1 parent 4249232 commit afe2cc0

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

python/test/unit/intel/test_block_load.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,47 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
207207
result_tor = fn_tor()
208208
result_tri = fn_tri()
209209
torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
210+
211+
def test_block_load_asserts(M, N, dtype_str, transpose, device, tmp_path: pathlib.Path):
212+
TTGIR = r"""
213+
module attributes {
214+
ttg.target = "xpu",
215+
"ttg.num-warps" = 32 : i32,
216+
"ttg.num-ctas" = 1 : i32,
217+
"ttg.threads-per-warp" = 16 : i32
218+
} {
219+
tt.func @dyn_block(
220+
%iptr : i64, %base_width : i32,
221+
%base_height : i32, %base_pitch : i32,
222+
%x : i32, %y : i32) {
223+
%p0 = llvm.inttoptr %iptr : i64 to !llvm.ptr
224+
225+
%0 = triton_gen.2Dblockload %p0, %base_width, %base_height,
226+
%base_pitch, %x, %y
227+
{ elem_size_in_bits = 8, tile_width = 8, tile_height = 8,
228+
v_blocks = 1, transpose = false,
229+
vnni_transform = false, cache_control = Default }
230+
: (!llvm.ptr, i32, i32, i32, i32, i32)
231+
-> vector<2xi16>
232+
tt.return
233+
}
234+
}
235+
"""
236+
237+
tmp = pathlib.Path("test_block_load_asserts.ttgir").resolve()
238+
tmp.write_text(TTGIR)
239+
240+
path = str(tmp)
241+
kernel = triton.compile(path)
242+
243+
a = torch.randn((256, 64), dtype=torch.float32, device="xpu")
244+
245+
addr = int(a.data_ptr())
246+
247+
import ctypes
248+
addr64 = ctypes.c_int64(addr).value
249+
250+
print(kernel.asm['spvdis'])
251+
252+
# TODO catch the assert from __assert_fail
253+
kernel[(1,1,1)](addr64, 64, 64, 1, 0, 0)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: env TRITON_INTEL_2DBLOCK_ASSERT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
2+
3+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
4+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
5+
// CHECK: "llvm.intr.trap"() : () -> ()
6+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi16>
7+
llvm.return
8+
}
9+
}
10+
11+
// -----
12+
13+
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
14+
// CHECK: "llvm.intr.trap"() : () -> ()
15+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
16+
llvm.return
17+
}
18+
19+
// -----
20+
21+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
22+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
23+
// CHECK: "llvm.intr.trap"() : () -> ()
24+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
25+
llvm.return
26+
}
27+
}
28+
29+
// TODO: change checks to use __assert_fail

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,82 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
503503
intel::noUnwindWillReturnAttrs);
504504
}
505505

506+
template <typename OpTy>
507+
static void validateMatrix2DBlockParameters(OpTy op,
508+
mlir::ConversionPatternRewriter &rewriter) {
509+
using namespace mlir;
510+
using namespace mlir::LLVM;
511+
512+
Location loc = op->getLoc();
513+
auto b = TritonLLVMOpBuilder(loc, rewriter);
514+
MLIRContext *ctx = rewriter.getContext();
515+
516+
Value baseWidth = op.getBaseWidth();
517+
Value baseHeight = op.getBaseHeight();
518+
Value basePitch = op.getBasePitch();
519+
Value x = op.getX();
520+
unsigned elemSize = op.getElemSizeInBits() / 8;
521+
522+
if (!baseWidth.getType().isInteger(32))
523+
baseWidth = rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), baseWidth);
524+
if (!baseHeight.getType().isInteger(32))
525+
baseHeight = rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), baseHeight);
526+
if (!basePitch.getType().isInteger(32))
527+
basePitch = rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), basePitch);
528+
if (!x.getType().isInteger(32))
529+
x = rewriter.create<ZExtOp>(loc, rewriter.getI32Type(), x);
530+
531+
Value c0 = b.i32_val(0);
532+
Value c4 = b.i32_val(4);
533+
Value c64 = b.i32_val(64);
534+
Value c24m1 = b.i32_val((1u<<24) - 1); // 2^24 - 1
535+
Value cElemSize = b.i32_val(elemSize);
536+
537+
// ===== validation predicates =====
538+
539+
// width!=0 && width<2^24 && width%4==0
540+
Value wZero = rewriter.create<ICmpOp>(loc, ICmpPredicate::eq, baseWidth, c0);
541+
Value wTooLarge = rewriter.create<ICmpOp>(loc, ICmpPredicate::ugt, baseWidth, c24m1);
542+
Value wRem = rewriter.create<URemOp>(loc, baseWidth, c4);
543+
Value wNotAligned = rewriter.create<ICmpOp>(loc, ICmpPredicate::ne, wRem, c0);
544+
Value badWidth = rewriter.create<OrOp>(loc, wZero,
545+
rewriter.create<OrOp>(loc, wTooLarge, wNotAligned));
546+
547+
// height!=0 && height<2^24
548+
Value hZero = rewriter.create<ICmpOp>(loc, ICmpPredicate::eq, baseHeight, c0);
549+
Value hTooLarge = rewriter.create<ICmpOp>(loc, ICmpPredicate::ugt, baseHeight, c24m1);
550+
Value badHeight = rewriter.create<OrOp>(loc, hZero, hTooLarge);
551+
552+
// pitch >= 64
553+
Value badPitch = rewriter.create<ICmpOp>(loc, ICmpPredicate::ult, basePitch, c64);
554+
555+
// x*elemSize % 4 == 0
556+
Value offsetBytes = rewriter.create<MulOp>(loc, x, cElemSize);
557+
Value offsetRem = rewriter.create<URemOp>(loc, offsetBytes, c4);
558+
Value badOffset = rewriter.create<ICmpOp>(loc, ICmpPredicate::ne, offsetRem, c0);
559+
560+
// assert on any
561+
Value anyBad = rewriter.create<OrOp>(loc, badWidth,
562+
rewriter.create<OrOp>(loc, badHeight,
563+
rewriter.create<OrOp>(loc, badPitch, badOffset)));
564+
565+
Block *curBlock = rewriter.getBlock();
566+
auto ip = rewriter.getInsertionPoint();
567+
Block *contBlock = rewriter.splitBlock(curBlock, ip);
568+
Region *region = contBlock->getParent();
569+
Block *trapBlock = rewriter.createBlock(region, Region::iterator(contBlock));
570+
571+
// TODO: use __assert_fail instead of llvm.intr.trap
572+
rewriter.setInsertionPointToStart(trapBlock);
573+
rewriter.create<Trap>(loc);
574+
rewriter.create<UnreachableOp>(loc);
575+
576+
rewriter.setInsertionPointToEnd(curBlock);
577+
rewriter.create<CondBrOp>(loc, anyBad, trapBlock, ValueRange{}, contBlock, ValueRange{});
578+
579+
rewriter.setInsertionPointToStart(contBlock);
580+
}
581+
506582
namespace {
507583

508584
//===----------------------------------------------------------------------===//
@@ -636,6 +712,8 @@ struct TritonMatrix2DBlockLoadLowering
636712
LogicalResult
637713
matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
638714
ConversionPatternRewriter &rewriter) const override {
715+
validateMatrix2DBlockParameters(op, rewriter);
716+
639717
if (!isSPVBuiltinAvailable(op)) {
640718
// Fallback to GenISA interface.
641719
rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter));
@@ -711,6 +789,8 @@ struct TritonMatrix2DBlockStoreLowering
711789
LogicalResult
712790
matchAndRewrite(TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
713791
ConversionPatternRewriter &rewriter) const override {
792+
validateMatrix2DBlockParameters(op, rewriter);
793+
714794
if (!isSPVBuiltinAvailable(op)) {
715795
// Fallback to GenISA interface.
716796
rewriter.replaceOp(op, createGenISA2DBlockWrite(op, rewriter));
@@ -785,6 +865,8 @@ struct TritonMatrix2DBlockPrefetchLowering
785865
LogicalResult
786866
matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
787867
ConversionPatternRewriter &rewriter) const override {
868+
validateMatrix2DBlockParameters(op, rewriter);
869+
788870
if (!isSPVBuiltinAvailable(op)) {
789871
// Fallback to GenISA interface.
790872
rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter));

0 commit comments

Comments
 (0)