-
Notifications
You must be signed in to change notification settings - Fork 74
[Draft] Add a new stage to generate zebin
to align CUDA stages in triton.compile
#5189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -120,7 +120,7 @@ def __init__(self, target: tuple) -> None: | |
mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils") | ||
self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0)) | ||
self.properties = self.parse_target(target.arch) | ||
self.binary_ext = "spv" | ||
self.binary_ext = "zebin" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we changing from spv to zebin? We still want to generate an spv file; will this affect the generation of the SPIRV binary? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The SPIRV file is still generated and saved in the Triton cache directory. |
||
|
||
def get_target_name(self, options) -> str: | ||
return f"xpu:{self.device_arch}" | ||
|
@@ -374,6 +374,10 @@ def make_llir(src, metadata, options): | |
def make_spv(src, metadata, options, device_arch): | ||
spirv, name = intel.translate_to_spirv(src) | ||
metadata["name"] = name | ||
return spirv | ||
|
||
@staticmethod | ||
def make_zebin(src, metadata, options, device_arch): | ||
if options.grf_mode == 'small': | ||
metadata["build_flags"] = "-cl-intel-128-GRF-per-thread" | ||
elif options.grf_mode == 'large': | ||
|
@@ -392,50 +396,49 @@ def make_spv(src, metadata, options, device_arch): | |
if knobs.intel.dump_shader_info: | ||
# The IGC (Intel Graphic Compiler) only parses the options at first time in JIT-ing the binary per process. | ||
# Have to use the `ocloc` to generate the binary in sub-process to work around the limitation. | ||
assert options.generate_native_code, "Only support native code generation with shader dump" | ||
# assert options.generate_native_code, "Only support native code generation with shader dump" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove the commented-out code. |
||
shader_dump_opt = f" -igc_opts ',DumpToCustomDir={metadata['cache_dir']},ShaderDumpEnable=1'" | ||
|
||
metadata["generate_native_code"] = options.generate_native_code | ||
|
||
if options.generate_native_code: | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc: | ||
fsrc.write(spirv) | ||
fbin = fsrc.name + '.o' | ||
|
||
ocloc_cmd = [ | ||
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, | ||
'-options', metadata["build_flags"] + shader_dump_opt | ||
] | ||
|
||
try: | ||
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) | ||
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1: | ||
""" | ||
The exact message is something like: | ||
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217 | ||
is "spilled" enough for now? | ||
""" | ||
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread" | ||
# re-run with new build flags | ||
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt | ||
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) | ||
except subprocess.CalledProcessError as e: | ||
if e.returncode == 255: | ||
error = 'Internal Triton ZEBIN codegen error' | ||
elif e.returncode == 128 + signal.SIGSEGV: | ||
error = '`ocloc` raised SIGSEGV' | ||
else: | ||
error = f'`ocloc` failed with error code {e.returncode}' | ||
|
||
raise RuntimeError(f'{error}\n' | ||
f'`ocloc` stderr:\n{e.output}\n' | ||
f'Repro command: {ocloc_cmd}\n') from e | ||
|
||
with open(fbin, 'rb') as f: | ||
zebin = f.read() | ||
return zebin | ||
return spirv | ||
# metadata["generate_native_code"] = options.generate_native_code | ||
|
||
# if options.generate_native_code: | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc: | ||
Comment on lines +399 to +406
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These commented-out lines should be removed rather than left as commented code. If this logic is no longer needed, clean it up completely. Copilot uses AI. Check for mistakes. Positive Feedback / Negative Feedback |
||
fsrc.write(src) | ||
fbin = fsrc.name + '.o' | ||
|
||
ocloc_cmd = [ | ||
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options', | ||
metadata["build_flags"] + shader_dump_opt | ||
] | ||
|
||
try: | ||
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) | ||
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1: | ||
""" | ||
The exact message is something like: | ||
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217 | ||
is "spilled" enough for now? | ||
""" | ||
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread" | ||
# re-run with new build flags | ||
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt | ||
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True) | ||
except subprocess.CalledProcessError as e: | ||
if e.returncode == 255: | ||
error = 'Internal Triton ZEBIN codegen error' | ||
elif e.returncode == 128 + signal.SIGSEGV: | ||
error = '`ocloc` raised SIGSEGV' | ||
else: | ||
error = f'`ocloc` failed with error code {e.returncode}' | ||
|
||
raise RuntimeError(f'{error}\n' | ||
f'`ocloc` stderr:\n{e.output}\n' | ||
f'Repro command: {ocloc_cmd}\n') from e | ||
|
||
with open(fbin, 'rb') as f: | ||
zebin = f.read() | ||
return zebin | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The unreachable [... remainder of this review comment, including its inline code, was lost in extraction] Copilot uses AI. Check for mistakes. Positive Feedback / Negative Feedback |
||
|
||
def add_stages(self, stages, options, language): | ||
if language == Language.TRITON: | ||
|
@@ -445,6 +448,7 @@ def add_stages(self, stages, options, language): | |
stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options) | ||
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) | ||
stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch) | ||
stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can't make this step mandatory yet (due to #5153 (comment)), but if we make it optional using [... remainder of this review comment was lost in extraction] |
||
|
||
@functools.lru_cache() | ||
def hash(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The hardcoded `False` value replaces `not self.metadata.generate_native_code`. This magic boolean should be documented or use a named constant to clarify its purpose.
Copilot uses AI. Check for mistakes.