From a7e1e4ef33ed4d48a6e2cf607b1885331a8e073a Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 4 Aug 2025 20:50:13 +0200 Subject: [PATCH 1/3] Support for parsing ONNX Pad node --- hls4ml/converters/onnx/reshape.py | 65 ++++++++++++++ hls4ml/converters/pytorch/reshape.py | 4 + .../passes/convert_to_channels_last.py | 5 +- test/pytest/test_pytorch_constpadmapping.py | 44 ---------- test/pytest/test_zeropadding_pytorch_onnx.py | 86 +++++++++++++++++++ 5 files changed, 158 insertions(+), 46 deletions(-) delete mode 100644 test/pytest/test_pytorch_constpadmapping.py create mode 100644 test/pytest/test_zeropadding_pytorch_onnx.py diff --git a/hls4ml/converters/onnx/reshape.py b/hls4ml/converters/onnx/reshape.py index f11796b6db..c27d40cfde 100644 --- a/hls4ml/converters/onnx/reshape.py +++ b/hls4ml/converters/onnx/reshape.py @@ -58,3 +58,68 @@ def parse_resize_layer(node, input_names, input_shapes, graph): ) return layer + + +@onnx_handler('Pad') +def parse_pad_layer(node, input_names, input_shapes, graph): + layer = {} + layer['name'] = node.name + layer['class_name'] = 'ZeroPadding' + layer['inputs'] = input_names + layer['outputs'] = list(node.output) + layer['data_format'] = ( + 'channels_last' if any(node.domain == 'qonnx.custom_op.channels_last' for node in graph.node) else 'channels_first' + ) + + mode = get_onnx_attribute(node, 'mode') + if mode is not None and mode != 'constant': + raise RuntimeError(f'Unsupported padding mode: {mode} in node {node.name}') + + pads = get_onnx_attribute(node, 'pads') + + dim = 0 + if len(input_shapes[0]) == 3: + dim = 1 # 2D input (batch, channels, width), will use ZeroPadding1D + if layer['data_format'] == 'channels_first': + _, channels, width = input_shapes[0] + pad_left, pad_right = pads[2], pads[5] + else: + _, width, channels = input_shapes[0] + pad_left, pad_right = pads[1], pads[4] + out_width = width + pad_left + pad_right + + layer['n_chan'] = channels + layer['in_width'] = width + layer['out_width'] = out_width + + layer['pad_left'] = pad_left + layer['pad_right'] = pad_right + elif len(input_shapes[0]) == 4: + dim = 2 # 3D input (batch, channels, height, width), will use ZeroPadding2D + if layer['data_format'] == 'channels_first': + _, channels, height, width = input_shapes[0] + pad_top, pad_bottom = pads[2], pads[6] + pad_left, pad_right = pads[3], pads[7] + else: + _, height, width, channels = input_shapes[0] + pad_top, pad_bottom = pads[1], pads[5] + pad_left, pad_right = pads[2], pads[6] + out_height = height + pad_top + pad_bottom + out_width = width + pad_left + pad_right + + layer['n_chan'] = channels + layer['in_height'] = height + layer['in_width'] = width + layer['out_height'] = out_height + layer['out_width'] = out_width + + layer['pad_top'] = pad_top + layer['pad_bottom'] = pad_bottom + layer['pad_left'] = pad_left + layer['pad_right'] = pad_right + else: + raise RuntimeError(f'Unsupported input shape: {input_shapes[0]} for Pad node {node.name}') + + layer['class_name'] += str(dim) + 'D' + + return layer diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py index 64b60c97b9..f0ac34c122 100644 --- a/hls4ml/converters/pytorch/reshape.py +++ b/hls4ml/converters/pytorch/reshape.py @@ -204,6 +204,8 @@ def parse_constantpad2d_layer(operation, layer_name, input_names, input_shapes, layer['out_height'] = out_height layer['out_width'] = out_width + layer['data_format'] = 'channels_first' # Default data format in PyTorch + return layer, output_shape @@ -243,4 +245,6 @@ def parse_constantpad1d_layer(operation, layer_name, input_names, input_shapes, layer['in_width'] = width layer['out_width'] = out_width + layer['data_format'] = 'channels_first' # Default data format in PyTorch + return layer, output_shape diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py index 6511a6967b..4cc7fc3844 100644 --- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py +++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py @@ -13,8 +13,9 @@ class ChannelsLastConverter(OptimizerPass): def match(self, node): # If this parameter has not been set, this model does not need to be converted - if 'ChannelsLastConversion' not in node.model.config.config['HLSConfig']['Model']: - return False # No littering of unused property + do_convert = node.model.config.config['HLSConfig']['Model'].get('ChannelsLastConversion', 'off') + if do_convert == 'off': + return False if not hasattr(node, 'channels_last_converted'): return True diff --git a/test/pytest/test_pytorch_constpadmapping.py b/test/pytest/test_pytorch_constpadmapping.py deleted file mode 100644 index b4f602d711..0000000000 --- a/test/pytest/test_pytorch_constpadmapping.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch.nn as nn - -from hls4ml.converters import convert_from_pytorch_model -from hls4ml.utils.config import config_from_pytorch_model - - -def test_pytorch_constantpad_1d_2d(): - class Pad1DModel(nn.Module): - def __init__(self): - super().__init__() - self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right - - def forward(self, x): - return self.pad(x) - - class Pad2DModel(nn.Module): - def __init__(self): - super().__init__() - self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom - - def forward(self, x): - return self.pad(x) - - # 1D test: batch=1, channels=2, width=4, values 1,2,3,4 - model_1d = Pad1DModel() - model_1d.eval() - config_1d = config_from_pytorch_model(model_1d, (2, 4)) - hls_model_1d = convert_from_pytorch_model(model_1d, hls_config=config_1d) - print("1D Padding Model Layers:") - for layer in hls_model_1d.get_layers(): - print(f"{layer.name}: {layer.class_name}") - - # 2D test: batch=1, channels=1, height=2, width=4, values 1,2,3,4,5,6,7,8 - model_2d = Pad2DModel() - model_2d.eval() - config_2d = config_from_pytorch_model(model_2d, (1, 2, 4)) - hls_model_2d = convert_from_pytorch_model(model_2d, hls_config=config_2d) - print("2D Padding Model Layers:") - for layer in hls_model_2d.get_layers(): - print(f"{layer.name}: {layer.class_name}") - - # Write the HLS projects, cannot compile on Windows - hls_model_1d.write() - hls_model_2d.write() diff --git a/test/pytest/test_zeropadding_pytorch_onnx.py b/test/pytest/test_zeropadding_pytorch_onnx.py new file mode 100644 index 0000000000..5d76274b93 --- /dev/null +++ b/test/pytest/test_zeropadding_pytorch_onnx.py @@ -0,0 +1,86 @@ +from pathlib import Path + +import numpy as np +import qonnx.util.cleanup +import torch +import torch.nn as nn +from qonnx.core.modelwrapper import ModelWrapper + +from hls4ml.converters import convert_from_onnx_model, convert_from_pytorch_model +from hls4ml.utils.config import config_from_onnx_model, config_from_pytorch_model + +test_root_path = Path(__file__).parent + + +def test_constantpad_1d(): + class Pad1DModel(nn.Module): + def __init__(self): + super().__init__() + self.pad = nn.ConstantPad1d((2, 3), 0) # pad 2 left, 3 right + + def forward(self, x): + return self.pad(x) + + model = Pad1DModel() + model.eval() + config_pytorch = config_from_pytorch_model(model, (2, 4), channels_last_conversion='off') + hls_model_pytorch = convert_from_pytorch_model( + model, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/pytorch'), hls_config=config_pytorch + ) + + hls_model_pytorch.compile() + + onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx') + torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, opset_version=10) + qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path) + pad1d_onnx = ModelWrapper(onnx_path) + + config_onnx = config_from_onnx_model(pad1d_onnx) + hls_model_onnx = convert_from_onnx_model( + pad1d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_1d/onnx'), hls_config=config_onnx + ) + + hls_model_onnx.compile() + + input_data = np.random.randn(10, 2, 4) + pred_pytorch = hls_model_pytorch.predict(input_data) + pred_onnx = hls_model_onnx.predict(input_data) + + np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5) + + +def test_constantpad_2d(): + class Pad2DModel(nn.Module): + def __init__(self): + super().__init__() + self.pad = nn.ConstantPad2d((1, 2, 3, 4), 0) # left, right, top, bottom + + def forward(self, x): + return self.pad(x) + + model = Pad2DModel() + model.eval() + config_pytorch = config_from_pytorch_model(model, (2, 3, 4), channels_last_conversion='off') + hls_model_pytorch = convert_from_pytorch_model( + model, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/pytorch'), hls_config=config_pytorch + ) + + hls_model_pytorch.compile() + + onnx_path = str(test_root_path / 'hls4mlprj_constpad_2d/pad2d.onnx') + torch.onnx.export(model, torch.randn(1, 2, 3, 4), onnx_path, opset_version=10) + qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path) + pad2d_onnx = ModelWrapper(onnx_path) + + config_onnx = config_from_onnx_model(pad2d_onnx) + hls_model_onnx = convert_from_onnx_model( + pad2d_onnx, output_dir=str(test_root_path / 'hls4mlprj_constpad_2d/onnx'), hls_config=config_onnx + ) + + hls_model_onnx.compile() + + input_data = np.random.randn(10, 2, 3, 4) + pred_pytorch = hls_model_pytorch.predict(input_data) + pred_onnx = hls_model_onnx.predict(input_data) + + np.testing.assert_allclose(pred_pytorch, pred_onnx, rtol=0, atol=1e-5) From 562fd49cf00bf5a4d0c834e6c2c00dbaa8043b3d Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 6 Aug 2025 19:02:33 +0200 Subject: [PATCH 2/3] Use dynamo onnx export --- test/pytest/test_zeropadding_pytorch_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pytest/test_zeropadding_pytorch_onnx.py b/test/pytest/test_zeropadding_pytorch_onnx.py index 5d76274b93..c084d5981f 100644 --- a/test/pytest/test_zeropadding_pytorch_onnx.py +++ b/test/pytest/test_zeropadding_pytorch_onnx.py @@ -31,7 +31,7 @@ def forward(self, x): hls_model_pytorch.compile() onnx_path = str(test_root_path / 'hls4mlprj_constpad_1d/pad1d.onnx') - torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, opset_version=10) + torch.onnx.export(model, torch.randn(1, 2, 4), onnx_path, dynamo=True) qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path) pad1d_onnx = ModelWrapper(onnx_path) @@ -68,7 +68,7 @@ def forward(self, x): hls_model_pytorch.compile() onnx_path = str(test_root_path / 'hls4mlprj_constpad_2d/pad2d.onnx') - torch.onnx.export(model, torch.randn(1, 2, 3, 4), onnx_path, opset_version=10) + torch.onnx.export(model, torch.randn(1, 2, 3, 4), onnx_path, dynamo=True) qonnx.util.cleanup.cleanup(onnx_path, out_file=onnx_path) pad2d_onnx = ModelWrapper(onnx_path) From 8c4b669b31d5369cb5eb120e871ab8acae0680f4 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 10 Sep 2025 01:04:39 +0200 Subject: [PATCH 3/3] Parse Pad node with ONNX opset >= 11 --- hls4ml/converters/onnx/reshape.py | 11 +++++-- hls4ml/model/optimizer/__init__.py | 1 + hls4ml/model/optimizer/passes/pad_const.py | 37 ++++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 hls4ml/model/optimizer/passes/pad_const.py diff --git a/hls4ml/converters/onnx/reshape.py b/hls4ml/converters/onnx/reshape.py index c27d40cfde..d9057d930a 100644 --- a/hls4ml/converters/onnx/reshape.py +++ b/hls4ml/converters/onnx/reshape.py @@ -1,4 +1,4 @@ -from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler +from hls4ml.converters.onnx_to_hls import get_constant_value, get_onnx_attribute, onnx_handler @onnx_handler('Transpose') @@ -75,7 +75,14 @@ def parse_pad_layer(node, input_names, input_shapes, graph): if mode is not None and mode != 'constant': raise RuntimeError(f'Unsupported padding mode: {mode} in node {node.name}') - pads = get_onnx_attribute(node, 'pads') + pads = get_constant_value(graph, node.input[1]) + if len(input_names) > 2: + const_val = get_constant_value(graph, node.input[2]) + if const_val != 0: + raise RuntimeError(f'Only constant value of 0 supported for Pad node {node.name}, got {const_val}') + + if len(input_names) > 3: + raise RuntimeError(f'Parsing axes input of Pad node {node.name} is not supported.') dim = 0 if len(input_shapes[0]) == 3: diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index c0a99a66c5..9b3a5dce90 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -34,6 +34,7 @@ 'parse_qonnx', [ 'reshape_constant', + 'padding_constant', 'resize_remove_constants', 'quant_constant_parameters', 'bipolar_quant_constant_parameters', diff --git a/hls4ml/model/optimizer/passes/pad_const.py b/hls4ml/model/optimizer/passes/pad_const.py new file mode 100644 index 0000000000..76f59f8ef5 --- /dev/null +++ b/hls4ml/model/optimizer/passes/pad_const.py @@ -0,0 +1,37 @@ +from hls4ml.model.layers import Constant, ZeroPadding1D, ZeroPadding2D +from hls4ml.model.optimizer import OptimizerPass + + +class PaddingConstant(OptimizerPass): + """ + ONNX has the padding come as an input, not a parameter. This removes the Constant node from the input. + The constant value was already used; this is just a cleanup uptimization. + """ + + def match(self, node): + is_match = ( + isinstance(node, (ZeroPadding1D, ZeroPadding2D)) + and len(node.inputs) > 1 + and isinstance(node.get_input_node(node.inputs[1]), Constant) + ) + + return is_match + + def transform(self, model, node): + """ + Remove Constant node(s) from the graph. Note, padding is already present in ZeroPadding node. + """ + if len(node.inputs) > 2: + const_val_node = node.get_input_node(node.inputs[2]) + if not isinstance(const_val_node, Constant): + raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})') + model.remove_node(const_val_node) + node.inputs.pop(2) + + pad_node = node.get_input_node(node.inputs[1]) + if not isinstance(pad_node, Constant): + raise RuntimeError(f'Non-constant padding inputs are not currently supported ({node.name})') + model.remove_node(pad_node) + node.inputs.pop(1) + + return True