diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index d33d3f5b48..4c0cc63d9c 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -138,7 +138,7 @@ def parse(full_name, ext, context): return module if ext == "llir" or ext == "ptx" or ext == "amdgcn": return Path(full_name).read_text() - if ext == "cubin" or ext == "hsaco": + if ext == "cubin" or ext == "hsaco" or ext == "zebin": return Path(full_name).read_bytes() if ext == "spv": return Path(full_name).read_bytes() @@ -332,7 +332,7 @@ def compile(src, target=None, options=None, _env_vars=None): print(f"\nOverriding kernel with file {full_name}") next_module = parse(full_name, ext, context) # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json - if (not store_only_binary) or (ext in ("cubin", "hsaco", "json", "spv")): + if (not store_only_binary) or (ext in ("cubin", "hsaco", "zebin", "json", "spv")): metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) if fn_dump_manager is not None: fn_dump_manager.put(next_module, ir_filename) @@ -433,11 +433,15 @@ def __init__(self, src, metadata_group, hash): self.name = self.metadata.name # stores the text of each level of IR that was generated during compilation asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + + def read_file(path): + try: + return path.read_text() + except UnicodeDecodeError: + return path.read_bytes() + + self.asm = AsmDict({file.suffix[1:]: read_file(file) for file in asm_files}) binary_ext = backend.binary_ext - self.asm = AsmDict({ - file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() - for file in asm_files - }) self.metadata_group = metadata_group self.kernel = self.asm[binary_ext] # binaries are lazily initialized @@ -477,8 +481,7 @@ def raise_(err): knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash) # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary( - self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, - not self.metadata.generate_native_code, device) + self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, False, device) if hasattr(self.metadata, "threads_per_warp"): warp_size = self.metadata.threads_per_warp else: diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 397553519b..3e9f4045df 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -120,7 +120,7 @@ def __init__(self, target: tuple) -> None: mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils") self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0)) self.properties = self.parse_target(target.arch) - self.binary_ext = "spv" + self.binary_ext = "zebin" def get_target_name(self, options) -> str: return f"xpu:{self.device_arch}" @@ -374,6 +374,10 @@ def make_llir(src, metadata, options): def make_spv(src, metadata, options, device_arch): spirv, name = intel.translate_to_spirv(src) metadata["name"] = name + return spirv + + @staticmethod + def make_zebin(src, metadata, options, device_arch): if options.grf_mode == 'small': metadata["build_flags"] = "-cl-intel-128-GRF-per-thread" elif options.grf_mode == 'large': @@ -392,50 +396,49 @@ def make_spv(src, metadata, options, device_arch): if knobs.intel.dump_shader_info: # The IGC (Intel Graphic Compiler) only parses the options at first time in JIT-ing the binary per process. # Have to use the `ocloc` to generate the binary in sub-process to work around the limitation. - assert options.generate_native_code, "Only support native code generation with shader dump" + # assert options.generate_native_code, "Only support native code generation with shader dump" shader_dump_opt = f" -igc_opts ',DumpToCustomDir={metadata['cache_dir']},ShaderDumpEnable=1'" - metadata["generate_native_code"] = options.generate_native_code - - if options.generate_native_code: - with tempfile.TemporaryDirectory() as temp_dir: - with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc: - fsrc.write(spirv) - fbin = fsrc.name + '.o' - - ocloc_cmd = [ - 'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, - '-options', metadata["build_flags"] + shader_dump_opt - ] - - try: - output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) - if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1: - """ - The exact message is something like: - warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217 - is "spilled" enough for now? - """ - metadata["build_flags"] += " -cl-intel-256-GRF-per-thread" - # re-run with new build flags - ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt - subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) - except subprocess.CalledProcessError as e: - if e.returncode == 255: - error = 'Internal Triton ZEBIN codegen error' - elif e.returncode == 128 + signal.SIGSEGV: - error = '`ocloc` raised SIGSEGV' - else: - error = f'`ocloc` failed with error code {e.returncode}' - - raise RuntimeError(f'{error}\n' - f'`ocloc` stderr:\n{e.output}\n' - f'Repro command: {ocloc_cmd}\n') from e - - with open(fbin, 'rb') as f: - zebin = f.read() - return zebin - return spirv + # metadata["generate_native_code"] = options.generate_native_code + + # if options.generate_native_code: + with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc: + fsrc.write(src) + fbin = fsrc.name + '.o' + + ocloc_cmd = [ + 'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options', + metadata["build_flags"] + shader_dump_opt + ] + + try: + output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) + if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1: + """ + The exact message is something like: + warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217 + is "spilled" enough for now? + """ + metadata["build_flags"] += " -cl-intel-256-GRF-per-thread" + # re-run with new build flags + ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt + subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) + except subprocess.CalledProcessError as e: + if e.returncode == 255: + error = 'Internal Triton ZEBIN codegen error' + elif e.returncode == 128 + signal.SIGSEGV: + error = '`ocloc` raised SIGSEGV' + else: + error = f'`ocloc` failed with error code {e.returncode}' + + raise RuntimeError(f'{error}\n' + f'`ocloc` stderr:\n{e.output}\n' + f'Repro command: {ocloc_cmd}\n') from e + + with open(fbin, 'rb') as f: + zebin = f.read() + return zebin def add_stages(self, stages, options, language): if language == Language.TRITON: @@ -445,6 +448,7 @@ def add_stages(self, stages, options, language): stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch) + stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch) @functools.lru_cache() def hash(self):