1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -44,6 +44,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_F32_DEFAULT",
"TRITON_PREFER_TMEM_16x256_LAYOUT",
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
"TRITON_INTEL_2DBLOCK_ASSERT",
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
"TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
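For context, variables listed in `CACHE_INVALIDATING_ENV_VARS` are presumably folded into Triton's compilation cache key, so flipping `TRITON_INTEL_2DBLOCK_ASSERT` forces a recompile rather than reusing a binary built without the asserts. A minimal sketch of the idea, not Triton's actual cache code:

```python
# Sketch only: illustrates why the variable belongs in the cache-invalidating
# set. Triton's real cache-key computation differs in detail.
import hashlib
import os


def cache_key(src: str, env_vars: list[str]) -> str:
    # Fold the current values of the cache-invalidating env vars into the key,
    # so toggling TRITON_INTEL_2DBLOCK_ASSERT yields a different cache entry.
    env_state = "|".join(f"{name}={os.getenv(name, '')}" for name in sorted(env_vars))
    return hashlib.sha256((src + env_state).encode()).hexdigest()
```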
55 changes: 55 additions & 0 deletions python/test/unit/intel/block_load_helper.py
@@ -0,0 +1,55 @@
import ctypes
import sys

import torch
import triton


def run_load_ir(temp_file, elem_size, *args):
    # An 8x8 tile of elem_size-bit elements spread across 16 lanes leaves
    # 4 * elem_size bits of payload per lane.
    out_type = f"i{int(elem_size) * 4}"
    ir = f"""
    module attributes {{
      ttg.target = "xpu",
      "ttg.num-warps" = 32 : i32,
      "ttg.num-ctas" = 1 : i32,
      "ttg.threads-per-warp" = 16 : i32
    }} {{
      tt.func @dyn_block(
          %iptr : i64, %base_width : i32,
          %base_height : i32, %base_pitch : i32,
          %x : i32, %y : i32) {{
        %p0 = llvm.inttoptr %iptr : i64 to !llvm.ptr

        %v = triton_gen.2Dblockload %p0, %base_width, %base_height,
            %base_pitch, %x, %y
            {{ elem_size_in_bits = {elem_size}, tile_width = 8, tile_height = 8,
               v_blocks = 1, transpose = false,
               vnni_transform = false, cache_control = Default }}
            : (!llvm.ptr, i32, i32, i32, i32, i32)
            -> vector<1x{out_type}>

        // Consume the result via inline asm so the load is not optimized away.
        %v_cast = llvm.bitcast %v : vector<1x{out_type}> to {out_type}
        llvm.inline_asm has_side_effects asm_dialect = att
            "", "r" %v_cast : ({out_type}) -> ()

        tt.return
      }}
    }}
    """

    with open(temp_file, "w", encoding="utf-8") as f:
        f.write(ir)

    kernel = triton.compile(temp_file)

    # Backing buffer large enough for the tested shapes.
    a = torch.zeros((256, 64), dtype=torch.float32, device="xpu")

    addr = ctypes.c_int64(a.data_ptr()).value

    # Forward the CLI-supplied shape parameters; %y is fixed at 0.
    kernel[(1, 1, 1)](addr, *map(int, args), 0)


if __name__ == "__main__":
    # Dispatch: python block_load_helper.py <function-name> <args...>
    fn = globals()[sys.argv[1]]
    fn(*sys.argv[2:])
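For reference, the `__main__` dispatch means the helper can be driven directly from another process. A hypothetical manual invocation (assumes an XPU device; pitch 32 violates the `pitch >= 64` constraint from the test table below, so with the assert enabled the child should die with SIGABRT):

```python
import os
import signal
import subprocess
import sys

env = dict(os.environ, TRITON_INTEL_2DBLOCK_ASSERT="1")
proc = subprocess.run(
    [sys.executable, "block_load_helper.py", "run_load_ir",
     "/tmp/block_load.ttgir",       # scratch file for the generated IR
     "8", "128", "64", "32", "0"],  # elem_size, width, height, pitch, x
    env=env, capture_output=True,
)
assert proc.returncode == -signal.SIGABRT
```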
47 changes: 47 additions & 0 deletions python/test/unit/intel/test_block_load.py
@@ -1,5 +1,10 @@
import pytest
import torch

import os
import pathlib
import signal
import subprocess
import sys
from functools import partial

@@ -207,3 +212,45 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
    result_tor = fn_tor()
    result_tri = fn_tri()
    torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)


# Reviewer note (Contributor Author): Runtime checks to see if asserts actually work.
# Each case violates the constraint named in its comment.
@pytest.mark.parametrize("elem_size, width, height, pitch, x",
                         [[8, 16777216, 64, 16777216, 0],  # violates width <= 24 bits
                          [8, 32, 64, 128, 0],  # violates width >= 64
                          [8, 66, 64, 128, 0],  # violates width % max(4, elemSize) == 0
                          [8, 128, 16777216, 128, 0],  # violates height <= 24 bits
                          [8, 128, 64, 16777216, 0],  # violates pitch <= 24 bits
                          [8, 128, 64, 32, 0],  # violates pitch >= 64
                          [8, 128, 64, 70, 0],  # violates pitch % 16 == 0
                          [8, 128, 64, 120, 0],  # violates pitch >= width
                          [8, 128, 64, 128, 1],  # violates x*elemSize % 4 == 0 (8-bit)
                          [16, 128, 64, 128, 1],  # violates x*elemSize % 4 == 0 (16-bit)
                          ])
@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
@pytest.mark.xfail(
    not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
         and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
    reason="Block loads and/or DPAS not supported on this architecture", run=False)
def test_block_load_asserts(elem_size, width, height, pitch, x, monkeypatch, tmp_path: pathlib.Path):
    # monkeypatch.setenv modifies os.environ, so the child process inherits it.
    monkeypatch.setenv("TRITON_INTEL_2DBLOCK_ASSERT", "1")

    dir_path = os.path.dirname(os.path.realpath(__file__))
    helper_path = os.path.join(dir_path, "block_load_helper.py")

    temp_file = tmp_path / "test_block_load_asserts.ttgir"

    # Run the helper in a child process: a failed 2D-block assert aborts the
    # whole process, which would otherwise take the test runner down with it.
    proc = subprocess.run(
        [
            sys.executable, helper_path, "run_load_ir",
            str(temp_file),
            str(elem_size),
            str(width),
            str(height),
            str(pitch),
            str(x),
        ],
        capture_output=True,
    )

    rc = proc.returncode
    assert rc == -signal.SIGABRT, proc.stderr
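The `-signal.SIGABRT` check relies on POSIX `subprocess` semantics: a child terminated by a signal reports the negated signal number as its return code. A self-contained illustration:

```python
import signal
import subprocess
import sys

# The child deliberately raises SIGABRT against itself.
proc = subprocess.run(
    [sys.executable, "-c", "import os, signal; os.kill(os.getpid(), signal.SIGABRT)"])
print(proc.returncode)  # -6 on Linux, i.e. -signal.SIGABRT
assert proc.returncode == -signal.SIGABRT
```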
33 changes: 33 additions & 0 deletions test/TritonGEN/tritongen-2Dblockload-to-llvm-asserts.mlir
@@ -0,0 +1,33 @@
// RUN: env TRITON_INTEL_2DBLOCK_ASSERT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefix=ASSERT
// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefix=NOASSERT

module attributes {"ttg.threads-per-warp" = 16 : i32} {
  llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
    // ASSERT: llvm.call spir_funccc @__assert_fail
    // NOASSERT-NOT: __assert_fail
    %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi16>
    llvm.return
  }
}

// -----

module attributes {"ttg.threads-per-warp" = 16 : i32} {
  llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
    // ASSERT: llvm.call spir_funccc @__assert_fail
    // NOASSERT-NOT: __assert_fail
    triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
    llvm.return
  }
}

// -----

module attributes {"ttg.threads-per-warp" = 16 : i32} {
  llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
    // ASSERT: llvm.call spir_funccc @__assert_fail
    // NOASSERT-NOT: __assert_fail
    triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
    llvm.return
  }
}
@@ -21,8 +21,13 @@ namespace triton {
#define GEN_PASS_DECL
#include "intel/include/TritonGENToLLVM/Passes.h.inc"

void populateTritonGENToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns);
namespace gpu::intel {
class LibCallEmitter;
} // namespace gpu::intel

void populateTritonGENToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    const mlir::triton::gpu::intel::LibCallEmitter &emitter);

void registerConvertTritonGENToLLVMInterface(DialectRegistry &registry);