From 7dfc4b8676b1fc7a7af122be7f97e108dd07e98f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 11:31:20 +0200 Subject: [PATCH 01/20] [Lint] rerun linter, fix errors --- src/qonnx/core/datatype.py | 2 +- src/qonnx/core/modelwrapper.py | 58 +-- src/qonnx/transformation/fixedpt_quantize.py | 20 +- src/qonnx/transformation/general.py | 3 +- tests/core/test_datatypes.py | 378 ++++++++---------- tests/core/test_subgraph_traversal.py | 62 ++- tests/transformation/test_fixedpt_quantize.py | 69 +--- tests/transformation/test_sort_graph.py | 1 + 8 files changed, 292 insertions(+), 301 deletions(-) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index 5b8d0459..e32d30c0 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -288,7 +288,7 @@ def max(self): return signed_max if self._signed else unsigned_max def allowed(self, value): - value_is_integer = (np.round(value) == value) + value_is_integer = np.round(value) == value value_is_bounded = np.logical_and(self.min() <= value, value <= self.max()) return np.logical_and(value_is_integer, value_is_bounded) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 7566ca07..77866f4c 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -128,13 +128,17 @@ def save(self, filename): def analysis(self, analysis_fxn, apply_to_subgraphs=False): """Runs given anaylsis_fxn on this model and return resulting dict.""" - if apply_to_subgraphs == True: - assert "apply_to_subgraphs" in inspect.signature(analysis_fxn), "analysis_fxn must have 'apply_to_subgraphs' argument when apply_to_subgraphs == True" + if apply_to_subgraphs: + assert "apply_to_subgraphs" in inspect.signature( + analysis_fxn + ), "analysis_fxn must have 'apply_to_subgraphs' argument when apply_to_subgraphs == True" return analysis_fxn(self, apply_to_subgraphs) else: return analysis_fxn(self) - def transform_subgraphs(self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True): + def transform_subgraphs( + self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True + ): """Applies given Transformation to all subgraphs of this ModelWrapper instance. - make_deepcopy : operates on a new (deep)copy of model. @@ -144,23 +148,27 @@ def transform_subgraphs(self, transformation, make_deepcopy=True, cleanup=True, otherwise postorder traversal is used. """ for node in self.model.graph.node: - transformed_subgraph_attrs = [] - for idx, attr in enumerate(node.attribute): - if attr.type == onnx.AttributeProto.GRAPH: - # this is a subgraph, add it to the list - subgraph = self.make_subgraph_modelwrapper(attr.g) - # apply the transformation to the subgraph - subgraph = subgraph.transform(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal) - # update the new subgraph in the attrubute - transformed_subgraph_attrs.append((idx, onnx.helper.make_attribute(attr.name, subgraph.model.graph))) - # replace the attributes in the node with the transformed subgraph attributes - for idx, new_attr in transformed_subgraph_attrs: - # remove the old attribute - node.attribute.pop(idx) - # add the new attribute - node.attribute.insert(idx, new_attr) - - def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True): + transformed_subgraph_attrs = [] + for idx, attr in enumerate(node.attribute): + if attr.type == onnx.AttributeProto.GRAPH: + # this is a subgraph, add it to the list + subgraph = self.make_subgraph_modelwrapper(attr.g) + # apply the transformation to the subgraph + subgraph = subgraph.transform( + transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal + ) + # update the new subgraph in the attrubute + transformed_subgraph_attrs.append((idx, onnx.helper.make_attribute(attr.name, subgraph.model.graph))) + # replace the attributes in the node with the transformed subgraph attributes + for idx, new_attr in transformed_subgraph_attrs: + # remove the old attribute + node.attribute.pop(idx) + # add the new attribute + node.attribute.insert(idx, new_attr) + + def transform( + self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True + ): """Applies given Transformation repeatedly until no more changes can be made and returns a transformed ModelWrapper instance. @@ -174,8 +182,10 @@ def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_s if self.fix_float64: (transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model) - if apply_to_subgraphs and use_preorder_traversal == False: - transformed_model.transform_subgraphs(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal) + if apply_to_subgraphs and (use_preorder_traversal is False): + transformed_model.transform_subgraphs( + transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal + ) model_was_changed = True while model_was_changed: @@ -184,7 +194,9 @@ def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_s transformed_model.cleanup() if apply_to_subgraphs and use_preorder_traversal: - transformed_model.transform_subgraphs(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal) + transformed_model.transform_subgraphs( + transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal + ) return transformed_model diff --git a/src/qonnx/transformation/fixedpt_quantize.py b/src/qonnx/transformation/fixedpt_quantize.py index 0b21c591..f9225719 100644 --- a/src/qonnx/transformation/fixedpt_quantize.py +++ b/src/qonnx/transformation/fixedpt_quantize.py @@ -29,10 +29,10 @@ import numpy as np from warnings import warn +from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper -from qonnx.transformation.base import Transformation from qonnx.custom_op.general.intquant import resolve_rounding_mode -from qonnx.core.datatype import DataType +from qonnx.transformation.base import Transformation def default_op_filter(op): @@ -44,10 +44,12 @@ class FixedPointQuantizeParamsFromDict(Transformation): Quantize model parameters to a given fixed-point representation. The self.max_err dictionary stores the maximum error for each quantized input after calling. Parameters: - fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point data type or its canonical name + fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point + data type or its canonical name rounding_mode: Rounding mode used for conversion into fixed point. Default is "ROUND", - possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"] + possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", + "HALF_UP", "HALF_DOWN"] """ def __init__(self, fixedpt_dict, rounding_mode="ROUND"): @@ -63,13 +65,17 @@ def apply(self, model: ModelWrapper): tdtype = DataType[tdtype] current_dtype = model.get_tensor_datatype(tname) if current_dtype.is_fixed_point(): - warn(f"Tensor {tname} is already a {current_dtype.get_canonical_name()} type. Recasting to {tdtype.get_canonical_name()}") + warn( + f"Tensor {tname} is already a {current_dtype.get_canonical_name()} type. " + "Recasting to {tdtype.get_canonical_name()}" + ) in1_t_new = self.round_func(in1_t.astype(np.float32) / tdtype.scale_factor()) * tdtype.scale_factor() if (in1_t_new.max() > tdtype.max()) or (in1_t_new.min() < tdtype.min()): warn( f"Range of {tname} [{in1_t_new.min():.3f}, {in1_t_new.max():.3f}] greater than" - f"{tdtype.get_canonical_name()} [{tdtype.min():.3f}, {tdtype:.max():.3f}], clipping.") + f"{tdtype.get_canonical_name()} [{tdtype.min():.3f}, {tdtype:.max():.3f}], clipping." + ) in1_t_new = np.clip(in1_t_new, tdtype.min(), tdtype.max()) model.set_tensor_datatype(tname, tdtype) model.set_initializer(tname, in1_t_new) @@ -78,6 +84,7 @@ def apply(self, model: ModelWrapper): return (model, False) + class FixedPointQuantizeParams(Transformation): """ Quantize model parameters to a given fixed-point representation. @@ -93,6 +100,7 @@ class FixedPointQuantizeParams(Transformation): Default is "ROUND", possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"] """ + def __init__(self, fixedpt_dtype, op_filter=default_op_filter, rounding_mode="ROUND"): super().__init__() if isinstance(fixedpt_dtype, str): diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py index 654bee4e..5126bf27 100644 --- a/src/qonnx/transformation/general.py +++ b/src/qonnx/transformation/general.py @@ -261,8 +261,7 @@ def apply(self, model): # check if node inputs are connected to graph inputs or initializers # if so, we can keep the node in the graph for name in n.input: - if util.get_by_name(model.graph.initializer, name) or \ - util.get_by_name(model.graph.input, name): + if util.get_by_name(model.graph.initializer, name) or util.get_by_name(model.graph.input, name): # this node is connected to graph inputs or initializers # so we can keep it in the graph graph_dependencies[node_idx] = set() diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index 0fbd0dea..452c0611 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pytest + import numpy as np from qonnx.core.datatype import DataType, resolve_datatype @@ -187,249 +188,218 @@ def test_resolve_datatype(input): test_resolve_datatype(DataType["INT32"]) test_resolve_datatype(DataType["FLOAT32"]) + vectorize_details = { "BIPOLAR": [ - np.array([ - [-1, +1, 0], - [ 0, +1, -1], - [+1, 0, -1] - ]), - np.array([ - [True, True, False], - [False, True, True], - [True, False, True] - ], dtype=bool) + np.array([[-1, +1, 0], [0, +1, -1], [+1, 0, -1]]), + np.array([[True, True, False], [False, True, True], [True, False, True]], dtype=bool), ], "BINARY": [ - np.array([ - [-1, +1, 0], - [ 0, +1, -1], - [+1, 0, -1] - ]), - np.array([ - [False, True, True], - [True, True, False], - [True, True, False] - ], dtype=bool) + np.array([[-1, +1, 0], [0, +1, -1], [+1, 0, -1]]), + np.array([[False, True, True], [True, True, False], [True, True, False]], dtype=bool), ], "TERNARY": [ - np.array([ - [-1, +2, +1, 0], - [ 0, +1, +2, -1], - [+2, +1, 0, -1] - ]), - np.array([ - [True, False, True, True], - [True, True, False, True], - [False, True, True, True] - ], dtype=bool) + np.array([[-1, +2, +1, 0], [0, +1, +2, -1], [+2, +1, 0, -1]]), + np.array([[True, False, True, True], [True, True, False, True], [False, True, True, True]], dtype=bool), ], "UINT2": [ - np.array([ - [[-1, +2, +1, 0], - [ 0, +1, +2, -1]], - [[+2, +1, 0, -1], - [+4, -1, -2, +3]], - ]), - np.array([ - [[False, True, True, True], - [True, True, True, False]], - [[True, True, True, False], - [False, False, False, True]], - ], dtype=bool) + np.array( + [ + [[-1, +2, +1, 0], [0, +1, +2, -1]], + [[+2, +1, 0, -1], [+4, -1, -2, +3]], + ] + ), + np.array( + [ + [[False, True, True, True], [True, True, True, False]], + [[True, True, True, False], [False, False, False, True]], + ], + dtype=bool, + ), ], "UINT3": [ - np.array([ - [[+9, -6, +3, 0], - [-4, +4, 0, +1]], - [[-1, +3, +10, +4], - [+2, +6, +7, +8]], - ]), - np.array([ - [[False, False, True, True], - [False, True, True, True]], - [[False, True, False, True], - [True, True, True, False]], - ], dtype=bool) + np.array( + [ + [[+9, -6, +3, 0], [-4, +4, 0, +1]], + [[-1, +3, +10, +4], [+2, +6, +7, +8]], + ] + ), + np.array( + [ + [[False, False, True, True], [False, True, True, True]], + [[False, True, False, True], [True, True, True, False]], + ], + dtype=bool, + ), ], "UINT4": [ - np.array([ - [[-10, -4, +9, +13], - [+1, +14, +11, +4]], - [[+18, -7, +1, +9], - [-1, -7, +1, -2]], - ]), - np.array([ - [[False, False, True, True], - [True, True, True, True]], - [[False, False, True, True], - [False, False, True, False]], - ], dtype=bool) + np.array( + [ + [[-10, -4, +9, +13], [+1, +14, +11, +4]], + [[+18, -7, +1, +9], [-1, -7, +1, -2]], + ] + ), + np.array( + [ + [[False, False, True, True], [True, True, True, True]], + [[False, False, True, True], [False, False, True, False]], + ], + dtype=bool, + ), ], "UINT8": [ - np.array([ - [[148, 61, 70, 29], - [244, 213, 10, 135]], - [[18, 25, 246, 137], - [236, -31, 220, 359]], - ]), - np.array([ - [[True, True, True, True], - [True, True, True, True]], - [[True, True, True, True], - [True, False, True, False]], - ], dtype=bool) + np.array( + [ + [[148, 61, 70, 29], [244, 213, 10, 135]], + [[18, 25, 246, 137], [236, -31, 220, 359]], + ] + ), + np.array( + [ + [[True, True, True, True], [True, True, True, True]], + [[True, True, True, True], [True, False, True, False]], + ], + dtype=bool, + ), ], "UINT16": [ - np.array([ - [[35261, 129491, 9136, 18643], - [128532, -597, 34768, 248]], - [[21646, 30778, 71076, 21224], - [60657, 52854, -5994, 17295]], - ]), - np.array([ - [[True, False, True, True], - [False, False, True, True]], - [[True, True, False, True], - [True, True, False, True]], - ], dtype=bool) + np.array( + [ + [[35261, 129491, 9136, 18643], [128532, -597, 34768, 248]], + [[21646, 30778, 71076, 21224], [60657, 52854, -5994, 17295]], + ] + ), + np.array( + [ + [[True, False, True, True], [False, False, True, True]], + [[True, True, False, True], [True, True, False, True]], + ], + dtype=bool, + ), ], "UINT32": [ - np.array([ - [[-417565331, 3488834022, -1757218812, 591311876], - [1842515574, 4131239283, 2022242400, 1240578991]], - [[609779043, 574774725, 4188472937, 3109757181], - [-767760560, -2100731532, 3794040092, 3223013612]], - ]), - np.array([ - [[False, True, False, True], - [True, True, True, True]], - [[True, True, True, True], - [False, False, True, True]], - ], dtype=bool) + np.array( + [ + [[-417565331, 3488834022, -1757218812, 591311876], [1842515574, 4131239283, 2022242400, 1240578991]], + [[609779043, 574774725, 4188472937, 3109757181], [-767760560, -2100731532, 3794040092, 3223013612]], + ] + ), + np.array( + [ + [[False, True, False, True], [True, True, True, True]], + [[True, True, True, True], [False, False, True, True]], + ], + dtype=bool, + ), ], "INT2": [ - np.array([ - [[ 0, 2, 2, 3], - [-4, 2, -1, 2]], - [[ 1, 2, -4, -1], - [ 2, -1, -1, -2]], - ]), - np.array([ - [[True, False, False, False], - [False, False, True, False]], - [[True, False, False, True], - [False, True, True, True]], - ], dtype=bool) + np.array( + [ + [[0, 2, 2, 3], [-4, 2, -1, 2]], + [[1, 2, -4, -1], [2, -1, -1, -2]], + ] + ), + np.array( + [ + [[True, False, False, False], [False, False, True, False]], + [[True, False, False, True], [False, True, True, True]], + ], + dtype=bool, + ), ], "INT3": [ - np.array([ - [[-4, -6, -7, 3], - [ 2, -8, -7, 3]], - [[-4, -4, 4, -4], - [ 1, -4, 1, -5]], - ]), - np.array([ - [[True, False, False, True], - [True, False, False, True]], - [[True, True, False, True], - [True, True, True, False]], - ], dtype=bool) + np.array( + [ + [[-4, -6, -7, 3], [2, -8, -7, 3]], + [[-4, -4, 4, -4], [1, -4, 1, -5]], + ] + ), + np.array( + [ + [[True, False, False, True], [True, False, False, True]], + [[True, True, False, True], [True, True, True, False]], + ], + dtype=bool, + ), ], "INT4": [ - np.array([ - [[ 5, 9, 3, -6], - [ 1, 5, 9, 10]], - [[ 10, 10, -3, 0], - [ -8, -5, -12, -5]], - ]), - np.array([ - [[True, False, True, True], - [True, True, False, False]], - [[False, False, True, True], - [True, True, False, True]], - ], dtype=bool) + np.array( + [ + [[5, 9, 3, -6], [1, 5, 9, 10]], + [[10, 10, -3, 0], [-8, -5, -12, -5]], + ] + ), + np.array( + [ + [[True, False, True, True], [True, True, False, False]], + [[False, False, True, True], [True, True, False, True]], + ], + dtype=bool, + ), ], "INT8": [ - np.array([ - [[-143, 140, 54, -217], - [ 22, 186, 72, -175]], - [[-126, -6, 115, 240], - [-87, -159, 128, -178]], - ]), - np.array([ - [[False, False, True, False], - [True, False, True, False]], - [[True, True, True, False], - [True, False, False, False]], - ], dtype=bool) + np.array( + [ + [[-143, 140, 54, -217], [22, 186, 72, -175]], + [[-126, -6, 115, 240], [-87, -159, 128, -178]], + ] + ), + np.array( + [ + [[False, False, True, False], [True, False, True, False]], + [[True, True, True, False], [True, False, False, False]], + ], + dtype=bool, + ), ], "INT16": [ - np.array([ - [[ 36863, 2676, 2728, -61500], - [ 24314, 18040, -39438, 64013]], - [[ 28824, -38855, 46308, -50728], - [-50275, -48853, -42034, -44384]], - ]), - np.array([ - [[False, True, True, False], - [True, True, False, False]], - [[True, False, False, False], - [False, False, False, False]], - ], dtype=bool) + np.array( + [ + [[36863, 2676, 2728, -61500], [24314, 18040, -39438, 64013]], + [[28824, -38855, 46308, -50728], [-50275, -48853, -42034, -44384]], + ] + ), + np.array( + [ + [[False, True, True, False], [True, True, False, False]], + [[True, False, False, False], [False, False, False, False]], + ], + dtype=bool, + ), ], "FIXED<4,2>": [ - np.array([ - [[1.8, 1.5, -0.25, 0], - [-1.1, -2, 1.75, 0.1]], - [[-1.5, 1.6, 0.5, 0.1], - [0.4, 0.001, 3.03, 1.75]], - ]), - np.array([ - [[False, True, True, True], - [False, True, True, False]], - [[True, False, True, False], - [False, False, False, True]], - ], dtype=bool) + np.array( + [ + [[1.8, 1.5, -0.25, 0], [-1.1, -2, 1.75, 0.1]], + [[-1.5, 1.6, 0.5, 0.1], [0.4, 0.001, 3.03, 1.75]], + ] + ), + np.array( + [ + [[False, True, True, True], [False, True, True, False]], + [[True, False, True, False], [False, False, False, True]], + ], + dtype=bool, + ), ], "FLOAT<4,3>": [ - np.array([ - [0.0, 0.5, 1.875, -1.5], - [1.8, -512.0, 0.013671875, 0.0087890625], - [0.001953125, 0.0009765625, 2.0, 1.25] - ]), - np.array([ - [True, True, True, True], - [False, False, True, False], - [True, False, True, True] - ]) + np.array( + [[0.0, 0.5, 1.875, -1.5], [1.8, -512.0, 0.013671875, 0.0087890625], [0.001953125, 0.0009765625, 2.0, 1.25]] + ), + np.array([[True, True, True, True], [False, False, True, False], [True, False, True, True]]), ], "FLOAT<4,0>": [ - np.array([ - [0.0, 0.5, 0.75], - [0.015625, 0.0078125, 0.0625] - ]), - np.array([ - [True, True, False], - [True, False, True] - ]) + np.array([[0.0, 0.5, 0.75], [0.015625, 0.0078125, 0.0625]]), + np.array([[True, True, False], [True, False, True]]), ], "FLOAT<4,3,5>": [ - np.array([ - [0.0, 0.5, 1.875], - [-1.5, 1.8, -512.0] - ]), - np.array([ - [True, True, True], - [True, False, True] - ]) + np.array([[0.0, 0.5, 1.875], [-1.5, 1.8, -512.0]]), + np.array([[True, True, True], [True, False, True]]), ], - "FLOAT<4,0,5>": [ - np.array([0.0, 0.0625, 0.03125]), - np.array([True, True, False]) - ] + "FLOAT<4,0,5>": [np.array([0.0, 0.0625, 0.03125]), np.array([True, True, False])], } + @pytest.mark.parametrize("datatype", vectorize_details.keys()) def test_vectorized_allowed(datatype): input_values, golden_out = vectorize_details[datatype] diff --git a/tests/core/test_subgraph_traversal.py b/tests/core/test_subgraph_traversal.py index 3e8121c2..15d7b1a5 100644 --- a/tests/core/test_subgraph_traversal.py +++ b/tests/core/test_subgraph_traversal.py @@ -1,12 +1,13 @@ import pytest + +import onnx from collections import Counter +from onnx import helper from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation +from qonnx.util.basic import get_by_name, qonnx_make_model -from qonnx.util.basic import qonnx_make_model, get_by_name -import onnx -from onnx import helper # Helper to recursively build a graph with subgraphs attached to nodes def make_graph(tree): @@ -49,6 +50,7 @@ def make_graph(tree): return graph + def make_subgraph_model(tree): """ Build a ModelWrapper with a graph structure based on the provided tree. @@ -73,7 +75,9 @@ def apply(self, model_wrapper): dummy_name_in = f"{graph_name}_dummy_in" dummy_name_out = f"{graph_name}_dummy_out" model_wrapper.model.graph.input.append(helper.make_tensor_value_info(dummy_name_in, onnx.TensorProto.FLOAT, [4, 4])) - model_wrapper.model.graph.output.append(helper.make_tensor_value_info(dummy_name_out, onnx.TensorProto.FLOAT, [4, 4])) + model_wrapper.model.graph.output.append( + helper.make_tensor_value_info(dummy_name_out, onnx.TensorProto.FLOAT, [4, 4]) + ) model_wrapper.model.graph.node.append( helper.make_node( "DummyNode", # dummy op_type @@ -85,15 +89,18 @@ def apply(self, model_wrapper): # collect the name of the graph being transformed to check how many times each graph was visited self.visited.append(graph_name) - #import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() return model_wrapper, False + class NestedTransform(Transformation): def __init__(self): self.dummy_transform = DummyTransform() + def apply(self, model_wrapper): return model_wrapper.transform(self.dummy_transform), False + def get_subgraph_names(tree): """ Recursively collect the names of all subgraphs in the tree structure. @@ -115,10 +122,11 @@ def check_all_visted_once(tree, transform): """ Check that all subgraphs in the tree structure were visited exactly once. """ - visited = transform.visited + visited = transform.visited expected = get_subgraph_names(tree) assert Counter(visited) == Counter(expected), f"Visited: {visited}, Expected: {expected}" + def check_visit_order(tree, transform, order): """ Check that the order of visited subgraphs matches the expected preorder or postorder traversal. @@ -127,6 +135,7 @@ def check_visit_order(tree, transform, order): expected = order(tree) assert visited == expected, f"Visited: {visited}, Expected: {expected}" + def check_all_subgraphs_transformed(graph): """ Check that all subgraphs in the tree structure have been transformed. @@ -149,20 +158,20 @@ def get_metadata_props(graph, key): else: return metadata_prop.value - assert(get_metadata_props(graph, graph.name) == "visited"), f"Metadata for {graph.name} not set correctly" - assert(get_metadata_props(graph, "opset_id") == "10"), "Metadata for opset_id not set correctly" + assert get_metadata_props(graph, graph.name) == "visited", f"Metadata for {graph.name} not set correctly" + assert get_metadata_props(graph, "opset_id") == "10", "Metadata for opset_id not set correctly" # recursively check all subgraphs for node in graph.node: - for attr in node.attribute: + for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: check_all_subgraphs_transformed(attr.g) + @pytest.mark.parametrize("cleanup", [False, True]) @pytest.mark.parametrize("make_deepcopy", [False, True]) -@pytest.mark.parametrize("tree, apply_to_subgraphs", - [(("top", []), True), - (("top", []), False), - (("top", [("sub1", [])]), False)]) +@pytest.mark.parametrize( + "tree, apply_to_subgraphs", [(("top", []), True), (("top", []), False), (("top", [("sub1", [])]), False)] +) def test_no_traversal(tree, cleanup, make_deepcopy, apply_to_subgraphs): # Check that the top-level model is transformed exactly once when there are no subgraphs. # Check that the top-level model is transformed exactly once when there are subgraphs, but apply_to_subgraphs is False. @@ -175,6 +184,7 @@ def test_no_traversal(tree, cleanup, make_deepcopy, apply_to_subgraphs): assert transform.visited == ["top"] assert t_model.get_metadata_prop("top") == "visited" + def build_preorder_traversal(tree): """ Build a preorder traversal of the tree structure. @@ -190,6 +200,7 @@ def traverse(node): traverse(tree) return traversal + def build_postorder_traversal(tree): """ Build a postorder traversal of the tree structure. @@ -205,10 +216,16 @@ def traverse(node): traverse(tree) return traversal + @pytest.mark.parametrize("cleanup", [False, True]) @pytest.mark.parametrize("make_deepcopy", [False, True]) -@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]), - ("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])]) +@pytest.mark.parametrize( + "tree", + [ + ("top", [("sub1", []), ("sub2", [])]), + ("top", [("sub1", [("sub1_1", []), ("sub1_2", [])]), ("sub2", [("sub2_1", [])])]), + ], +) @pytest.mark.parametrize("use_preorder_traversal", [True, False]) def test_traversal(tree, cleanup, make_deepcopy, use_preorder_traversal): # Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True. @@ -216,7 +233,9 @@ def test_traversal(tree, cleanup, make_deepcopy, use_preorder_traversal): print(f"Testing tree: {tree}, cleanup: {cleanup}, make_deepcopy: {make_deepcopy}") model = make_subgraph_model(tree) transform = DummyTransform() - t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs=True, use_preorder_traversal=use_preorder_traversal) + t_model = model.transform( + transform, cleanup, make_deepcopy, apply_to_subgraphs=True, use_preorder_traversal=use_preorder_traversal + ) check_all_visted_once(tree, transform) check_all_subgraphs_transformed(t_model.model.graph) @@ -230,8 +249,13 @@ def test_traversal(tree, cleanup, make_deepcopy, use_preorder_traversal): @pytest.mark.parametrize("cleanup", [False, True]) @pytest.mark.parametrize("make_deepcopy", [False, True]) -@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]), - ("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])]) +@pytest.mark.parametrize( + "tree", + [ + ("top", [("sub1", []), ("sub2", [])]), + ("top", [("sub1", [("sub1_1", []), ("sub1_2", [])]), ("sub2", [("sub2_1", [])])]), + ], +) def test_traversal_nested(tree, cleanup, make_deepcopy): # Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True. # This should always be done correctly regardless of cleanup and make_deepcopy. @@ -242,6 +266,7 @@ def test_traversal_nested(tree, cleanup, make_deepcopy): check_all_visted_once(tree, transform.dummy_transform) check_all_subgraphs_transformed(t_model.model.graph) + def dummy_analysis_fxn(model_wrapper): """ A dummy analysis function that simply returns the model wrapper. @@ -250,6 +275,7 @@ def dummy_analysis_fxn(model_wrapper): d = {} return d + @pytest.mark.xfail(reason="Analysis functions require apply_to_subgraphs when traversing subgraphs") def test_analysis_fxn_without_apply_to_subgraphs_fails(): # Check that an analysis function fails when apply_to_subgraphs is False diff --git a/tests/transformation/test_fixedpt_quantize.py b/tests/transformation/test_fixedpt_quantize.py index 285e87f8..2b60c735 100644 --- a/tests/transformation/test_fixedpt_quantize.py +++ b/tests/transformation/test_fixedpt_quantize.py @@ -28,16 +28,14 @@ import pytest -import numpy as np import os +from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.fixedpt_quantize import FixedPointQuantizeParams, FixedPointQuantizeParamsFromDict -from qonnx.core.datatype import DataType from qonnx.util.cleanup import cleanup_model from qonnx.util.test import download_model - fixedpt_dict_details = { "Conv_bias_example_round": { "test_model": "Conv_bias_example", @@ -47,9 +45,9 @@ "Conv_1_param0": "FIXED<8,1>", "Conv_1_param1": "FIXED<8,1>", "Gemm_0_param0": "FIXED<12,1>", - "Gemm_0_param1": "FIXED<12,1>" + "Gemm_0_param1": "FIXED<12,1>", }, - "rounding_mode": "ROUND" + "rounding_mode": "ROUND", }, "Conv_bias_example_floor": { "test_model": "Conv_bias_example", @@ -59,9 +57,9 @@ "Conv_1_param0": "FIXED<8,1>", "Conv_1_param1": "FIXED<8,1>", "Gemm_0_param0": "FIXED<12,1>", - "Gemm_0_param1": "FIXED<12,1>" + "Gemm_0_param1": "FIXED<12,1>", }, - "rounding_mode": "FLOOR" + "rounding_mode": "FLOOR", }, "FINN-CNV_W2A2_round": { "test_model": "FINN-CNV_W2A2", @@ -97,9 +95,9 @@ "BatchNormalization_7_param0": "FIXED<9,4>", "BatchNormalization_7_param1": "FIXED<10,3>", "BatchNormalization_7_param2": "FIXED<12,8>", - "BatchNormalization_7_param3": "FIXED<14,13>" + "BatchNormalization_7_param3": "FIXED<14,13>", }, - "rounding_mode": "ROUND" + "rounding_mode": "ROUND", }, "FINN-CNV_W2A2_floor": { "test_model": "FINN-CNV_W2A2", @@ -135,9 +133,9 @@ "BatchNormalization_7_param0": "FIXED<9,4>", "BatchNormalization_7_param1": "FIXED<10,3>", "BatchNormalization_7_param2": "FIXED<12,8>", - "BatchNormalization_7_param3": "FIXED<14,13>" + "BatchNormalization_7_param3": "FIXED<14,13>", }, - "rounding_mode": "FLOOR" + "rounding_mode": "FLOOR", }, "MobileNetv1-w4a4_round": { "test_model": "MobileNetv1-w4a4", @@ -249,9 +247,9 @@ "BatchNormalization_26_param0": "FIXED<10,3>", "BatchNormalization_26_param1": "FIXED<5,2>", "BatchNormalization_26_param2": "FIXED<4,2>", - "BatchNormalization_26_param3": "FIXED<11,1>" + "BatchNormalization_26_param3": "FIXED<11,1>", }, - "rounding_mode": "ROUND" + "rounding_mode": "ROUND", }, "MobileNetv1-w4a4_floor": { "test_model": "MobileNetv1-w4a4", @@ -363,10 +361,10 @@ "BatchNormalization_26_param0": "FIXED<10,3>", "BatchNormalization_26_param1": "FIXED<5,2>", "BatchNormalization_26_param2": "FIXED<4,2>", - "BatchNormalization_26_param3": "FIXED<11,1>" + "BatchNormalization_26_param3": "FIXED<11,1>", }, - "rounding_mode": "FLOOR" - } + "rounding_mode": "FLOOR", + }, } @@ -401,67 +399,44 @@ def test_fixedpt_quantize_from_dict(test_case): os.unlink(dl_file) + fixedpt_details = { "FINN-CNV_W2A2_round_0": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<8,3>", "rounding_mode": "ROUND", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_floor_0": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<8,3>", "rounding_mode": "FLOOR", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_round_1": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<4,3>", "rounding_mode": "ROUND", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_floor_1": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<4,3>", "rounding_mode": "FLOOR", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_round_2": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<12,3>", "rounding_mode": "ROUND", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_floor_2": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<12,3>", "rounding_mode": "FLOOR", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] - } + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], + }, } diff --git a/tests/transformation/test_sort_graph.py b/tests/transformation/test_sort_graph.py index 876e2a4b..cb9fd072 100644 --- a/tests/transformation/test_sort_graph.py +++ b/tests/transformation/test_sort_graph.py @@ -167,6 +167,7 @@ def test_sort_nonlinear_graph(): # plt.plot(sizes,times,"--o") # plt.grid(True) + def test_sort_graph_node_only_connected_to_graphio(): """ Test that SortGraph does not remove nodes that are only connected to graph inputs/outputs. From 7456919c4e614919aa3003bab6ce1e9c55f1300f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 11:32:08 +0200 Subject: [PATCH 02/20] [Core] add get_opset_imports utility fxn to ModelWrapper --- src/qonnx/core/modelwrapper.py | 4 ++++ tests/core/test_modelwrapper.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 77866f4c..a85a1cf0 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -737,3 +737,7 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): qa.tensor_name = tensor_name qa.quant_parameter_tensor_names.append(dt) qnt_annotations.append(qa) + + def get_opset_imports(self): + """Returns a list of imported opsets as (domain, version) tuples.""" + return [(opset.domain, opset.version) for opset in self._model_proto.opset_import] diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 722f0fb1..5ffabb3c 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -68,6 +68,7 @@ def test_modelwrapper(): inp_sparsity = {"dw": {"kernel_shape": [3, 3]}} model.set_tensor_sparsity(first_conv_iname, inp_sparsity) assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity + assert model.get_opset_imports() == [("", 8)] def test_modelwrapper_set_get_rm_initializer(): From 89396cde144cb09fab87b5e8f5f93e599c6e4bf3 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:08:03 +0200 Subject: [PATCH 03/20] [Core] return dict from ModelWrapper.get_opset_imports --- src/qonnx/core/modelwrapper.py | 4 ++-- tests/core/test_modelwrapper.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index a85a1cf0..6248ac6f 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -739,5 +739,5 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): qnt_annotations.append(qa) def get_opset_imports(self): - """Returns a list of imported opsets as (domain, version) tuples.""" - return [(opset.domain, opset.version) for opset in self._model_proto.opset_import] + """Returns a list of imported opsets as a {domain, version} dictionary.""" + return {opset.domain: opset.version for opset in self._model_proto.opset_import} diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 5ffabb3c..995bcb17 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -68,7 +68,7 @@ def test_modelwrapper(): inp_sparsity = {"dw": {"kernel_shape": [3, 3]}} model.set_tensor_sparsity(first_conv_iname, inp_sparsity) assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity - assert model.get_opset_imports() == [("", 8)] + assert model.get_opset_imports() == {"": 8} def test_modelwrapper_set_get_rm_initializer(): From db2994f82227f252f4470e4dbdcbd53bdf579fed Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:12:31 +0200 Subject: [PATCH 04/20] [Core] add versioned op to getCustomOp with fallback to old style --- src/qonnx/custom_op/registry.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 3540bb5a..5d1a52ca 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -41,7 +41,15 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex try: opset_module = importlib.import_module(domain) assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain - inst_wrapper = opset_module.custom_op[op_type] + op_type_with_version = op_type + "_v" + str(onnx_opset_version) + # TODO version handling: use highest available version smaller than requested version + # when the exact match is not found + if op_type_with_version in opset_module.custom_op: + # priority: if it exists, load the versioned CustomOp wrapper + inst_wrapper = opset_module.custom_op[op_type_with_version] + else: + # otherwise use the default (non-suffixed) CustomOp wrapper + inst_wrapper = opset_module.custom_op[op_type] inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) return inst except ModuleNotFoundError: From 8a2db226a8a3efed8f07c41c9da0de8b943e0f7d Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:13:46 +0200 Subject: [PATCH 05/20] [Core] inrtoduce ModelWrapper.get_customop_wrapper grabs CustomOp instance with the right opset version from protobuf imported opsets --- src/qonnx/core/modelwrapper.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 6248ac6f..b2308a06 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -39,6 +39,7 @@ import qonnx.util.basic as util import qonnx.util.onnx as onnxutil from qonnx.core.datatype import DataType +from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.double_to_single_float import DoubleToSingleFloat from qonnx.transformation.general import ( RemoveStaticGraphInputs, @@ -741,3 +742,10 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): def get_opset_imports(self): """Returns a list of imported opsets as a {domain, version} dictionary.""" return {opset.domain: opset.version for opset in self._model_proto.opset_import} + + def get_customop_wrapper(self, node): + """Return CustomOp instance for given node, respecting the + imported opset version in the model protobuf.""" + opset_imports = self.get_opset_imports() + opset_import = opset_imports[node.domain] + return getCustomOp(node, onnx_opset_version=opset_import) From 402a58056f8af7784065c74cab7b6e58f0e44b4f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:14:48 +0200 Subject: [PATCH 06/20] [Test] add basic unit tests for versioned custom op fetching --- tests/custom_op/test_customop_version.py | 106 +++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 tests/custom_op/test_customop_version.py diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py new file mode 100644 index 00000000..bdc660f9 --- /dev/null +++ b/tests/custom_op/test_customop_version.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import onnx.parser as oprs + +import qonnx.custom_op.general as general +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import getCustomOp + + +class VerTestOp_v1(CustomOp): + def get_nodeattr_types(self): + my_attrs = {"v1_attr": ("i", True, 0)} + return my_attrs + + def make_shape_compatible_op(self, model): + ishape = model.get_tensor_shape(self.onnx_node.input[0]) + return super().make_const_shape_op(ishape) + + def infer_node_datatype(self, model): + node = self.onnx_node + # data type stays the same + dtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], dtype) + + def execute_node(self, context, graph): + node = self.onnx_node + context[node.output[0]] = context[node.input[0]] + + def verify_node(self): + pass + + +class VerTestOp_v2(VerTestOp_v1): + def get_nodeattr_types(self): + my_attrs = {"v2_attr": ("i", True, 0)} + return my_attrs + + +class VerTestOp_v3(VerTestOp_v2): + def get_nodeattr_types(self): + my_attrs = {"v3_attr": ("i", True, 0)} + return my_attrs + + +def make_vertest_model(vertest_ver): + ishp = (1, 10) + oshp = ishp + ishp_str = str(list(ishp)) + oshp_str = str(list(oshp)) + input = f""" + < + ir_version: 7, + opset_import: ["" : 9, "qonnx.custom_op.general" : {vertest_ver}] + > + agraph (float{ishp_str} in0) => (float{oshp_str} out0) + {{ + out0 = qonnx.custom_op.general.VerTestOp< + v{vertest_ver}_attr={vertest_ver} + >(in0) + }} + """ + model = oprs.parse_model(input) + model = ModelWrapper(model) + return model + + +def test_customop_version(): + general.custom_op["VerTestOp"] = VerTestOp_v1 + general.custom_op["VerTestOp_v2"] = VerTestOp_v2 + general.custom_op["VerTestOp_v3"] = VerTestOp_v3 + for ver in [1, 2, 3]: + model = make_vertest_model(ver) + # explicitly specify onnx_opset_version in getCustomOp + inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) + assert inst.get_nodeattr(f"v{ver}_attr") == ver + # now use ModelWrapper.get_customop_wrapper for implicit + # fetching of op version + inst = model.get_customop_wrapper(model.graph.node[0]) + assert inst.get_nodeattr(f"v{ver}_attr") == ver From 407fb13c438b47b95205f77101fec9455faaebaf Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 13:30:11 +0200 Subject: [PATCH 07/20] [Test] extend test_customop_version for default v handler --- tests/custom_op/test_customop_version.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py index bdc660f9..79974037 100644 --- a/tests/custom_op/test_customop_version.py +++ b/tests/custom_op/test_customop_version.py @@ -92,7 +92,10 @@ def make_vertest_model(vertest_ver): def test_customop_version(): + # unspecified version defaults to v1 implementation general.custom_op["VerTestOp"] = VerTestOp_v1 + # v1 version is also explicitly registered + general.custom_op["VerTestOp_v1"] = VerTestOp_v1 general.custom_op["VerTestOp_v2"] = VerTestOp_v2 general.custom_op["VerTestOp_v3"] = VerTestOp_v3 for ver in [1, 2, 3]: @@ -104,3 +107,7 @@ def test_customop_version(): # fetching of op version inst = model.get_customop_wrapper(model.graph.node[0]) assert inst.get_nodeattr(f"v{ver}_attr") == ver + # unspecified version getCustomOp should default to v1 handler + # (even though the node is actually v3 in this case) + inst = getCustomOp(model.graph.node[0]) + assert isinstance(inst, VerTestOp_v1) From feac9f09bcc5dce9137d64d6309504d3e8fd46d4 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 14:27:10 +0200 Subject: [PATCH 08/20] [Core] opset ver. fallback for ModelWrapper.get_customop_wrapper --- src/qonnx/core/modelwrapper.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index b2308a06..3d3ba0e9 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -743,9 +743,19 @@ def get_opset_imports(self): """Returns a list of imported opsets as a {domain, version} dictionary.""" return {opset.domain: opset.version for opset in self._model_proto.opset_import} - def get_customop_wrapper(self, node): + def get_customop_wrapper(self, node, fallback_customop_version=1): """Return CustomOp instance for given node, respecting the - imported opset version in the model protobuf.""" + imported opset version in the model protobuf. If the node's domain + is not found in the model's opset imports, fallback_customop_version + will be used.""" opset_imports = self.get_opset_imports() - opset_import = opset_imports[node.domain] - return getCustomOp(node, onnx_opset_version=opset_import) + try: + opset_import = opset_imports[node.domain] + return getCustomOp(node, onnx_opset_version=opset_import) + except KeyError: + # domain not found in imports, use fallback version + warnings.warn( + f"Domain {node.domain} not found in model opset imports, " + f"using fallback_customop_version={fallback_customop_version}" + ) + return getCustomOp(node, onnx_opset_version=fallback_customop_version) From 89eea4cfb7a8e421ed5c22ba814fc360f4da4b48 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 14:28:01 +0200 Subject: [PATCH 09/20] [Core] getCustomOp: default v to None, fetch highest available v. --- src/qonnx/custom_op/registry.py | 39 ++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 5d1a52ca..825c7566 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,11 +28,12 @@ import importlib -from qonnx.util.basic import get_preferred_onnx_opset - -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - "Return a QONNX CustomOp instance for the given ONNX node, if it exists." +def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): + "Return a QONNX CustomOp wrapper for the given ONNX node and given opset version," + "if it exists. If opset version is None, the default handler for the op type will be used. " + "If version is specified but the exact version match isn't available, the highest available version " + "smaller than the requested version will be used." op_type = node.op_type domain = node.domain if brevitas_exception: @@ -41,18 +42,30 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex try: opset_module = importlib.import_module(domain) assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain - op_type_with_version = op_type + "_v" + str(onnx_opset_version) - # TODO version handling: use highest available version smaller than requested version - # when the exact match is not found - if op_type_with_version in opset_module.custom_op: - # priority: if it exists, load the versioned CustomOp wrapper - inst_wrapper = opset_module.custom_op[op_type_with_version] - else: - # otherwise use the default (non-suffixed) CustomOp wrapper + if onnx_opset_version is None: inst_wrapper = opset_module.custom_op[op_type] + else: + op_type_with_version = op_type + "_v" + str(onnx_opset_version) + if op_type_with_version in opset_module.custom_op: + # priority: if it exists, load the versioned CustomOp wrapper + inst_wrapper = opset_module.custom_op[op_type_with_version] + else: + # when the exact version match is not found + # version handling: use highest available version smaller than requested version + available_versions = [ + int(k.split("_v")[-1]) for k in opset_module.custom_op.keys() if k.startswith(op_type + "_v") + ] + suitable_versions = [v for v in available_versions if v <= onnx_opset_version] + if suitable_versions: + highest_version = max(suitable_versions) + inst_wrapper = opset_module.custom_op[f"{op_type}_v{highest_version}"] + else: + raise Exception( + "Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain) + ) inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) return inst except ModuleNotFoundError: raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain) except KeyError: - raise Exception("Op %s not found in custom opset %s" % (op_type, domain)) + raise Exception("Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain)) From ec517b517a47db680180586c8063613431297745 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 14:31:15 +0200 Subject: [PATCH 10/20] [Test] cover newly added opset ver behavior in test_customop_version --- tests/custom_op/test_customop_version.py | 37 +++++++++++++++++++----- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py index 79974037..5364df61 100644 --- a/tests/custom_op/test_customop_version.py +++ b/tests/custom_op/test_customop_version.py @@ -69,15 +69,19 @@ def get_nodeattr_types(self): return my_attrs -def make_vertest_model(vertest_ver): +def make_vertest_model(vertest_ver, no_opset_import): ishp = (1, 10) oshp = ishp ishp_str = str(list(ishp)) oshp_str = str(list(oshp)) + if no_opset_import: + opset_import = "" + else: + opset_import = f', "qonnx.custom_op.general" : {vertest_ver}' input = f""" < ir_version: 7, - opset_import: ["" : 9, "qonnx.custom_op.general" : {vertest_ver}] + opset_import: ["" : 9{opset_import}] > agraph (float{ishp_str} in0) => (float{oshp_str} out0) {{ @@ -98,16 +102,33 @@ def test_customop_version(): general.custom_op["VerTestOp_v1"] = VerTestOp_v1 general.custom_op["VerTestOp_v2"] = VerTestOp_v2 general.custom_op["VerTestOp_v3"] = VerTestOp_v3 + + # if onnx is lacking the opset import, should default to v1 handler + # (since we set custom_op["VerTestOp"] = VerTestOp_v1) + model = make_vertest_model(1, True) + inst = getCustomOp(model.graph.node[0]) + assert isinstance(inst, VerTestOp_v1) + # alternatively, when using ModelWrapper.get_customop_wrapper and onnx is + # lacking the opset import, should fall back to the specified version + inst = model.get_customop_wrapper(model.graph.node[0], fallback_customop_version=2) + assert isinstance(inst, VerTestOp_v2) + for ver in [1, 2, 3]: - model = make_vertest_model(ver) - # explicitly specify onnx_opset_version in getCustomOp - inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) - assert inst.get_nodeattr(f"v{ver}_attr") == ver - # now use ModelWrapper.get_customop_wrapper for implicit + model = make_vertest_model(ver, False) + # use ModelWrapper.get_customop_wrapper for implicit # fetching of op version inst = model.get_customop_wrapper(model.graph.node[0]) assert inst.get_nodeattr(f"v{ver}_attr") == ver + # explicitly specify onnx_opset_version in getCustomOp + # note: new code should avoid calling getCustomOp directly like this + # and instead use ModelWrapper.get_customop_wrapper + inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) + assert inst.get_nodeattr(f"v{ver}_attr") == ver # unspecified version getCustomOp should default to v1 handler - # (even though the node is actually v3 in this case) + model = make_vertest_model(1, False) inst = getCustomOp(model.graph.node[0]) assert isinstance(inst, VerTestOp_v1) + # requesting v4 should return largest available version (v3 in this case) + model = make_vertest_model(3, False) + inst = getCustomOp(model.graph.node[0], onnx_opset_version=4) + assert isinstance(inst, VerTestOp_v3) From aeeff580d8bc1fa53f910474291997506bade759 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:02:17 +0200 Subject: [PATCH 11/20] [Core, Util] distinguish preferred onnx opset from qonnx opset --- src/qonnx/core/modelwrapper.py | 2 +- src/qonnx/util/basic.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index fa5a968d..2ba2984a 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -743,7 +743,7 @@ def get_opset_imports(self): """Returns a list of imported opsets as a {domain, version} dictionary.""" return {opset.domain: opset.version for opset in self._model_proto.opset_import} - def get_customop_wrapper(self, node, fallback_customop_version=1): + def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()): """Return CustomOp instance for given node, respecting the imported opset version in the model protobuf. If the node's domain is not found in the model's opset imports, fallback_customop_version diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 3a3ce2af..e756366d 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -51,11 +51,19 @@ def get_preferred_onnx_opset(): return 11 +def get_preferred_qonnx_opset(): + "Return preferred ONNX opset version for QONNX" + return 1 + + def qonnx_make_model(graph_proto, **kwargs): "Wrapper around ONNX make_model with preferred qonnx opset version" opset_imports = kwargs.pop("opset_imports", None) if opset_imports is None: - opset_imports = [make_opsetid("", get_preferred_onnx_opset())] + opset_imports = [ + make_opsetid("", get_preferred_onnx_opset()), + make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()), + ] kwargs["opset_imports"] = opset_imports else: kwargs["opset_imports"] = opset_imports From 580150453f16bb8aa08ff4de1eee602f375b323e Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:03:21 +0200 Subject: [PATCH 12/20] [Core] respect selected opsets during execution --- src/qonnx/core/execute_custom_node.py | 3 +-- src/qonnx/core/onnx_exec.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/qonnx/core/execute_custom_node.py b/src/qonnx/core/execute_custom_node.py index 7acf3792..cd6bb605 100644 --- a/src/qonnx/core/execute_custom_node.py +++ b/src/qonnx/core/execute_custom_node.py @@ -27,10 +27,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import qonnx.custom_op.registry as registry -from qonnx.util.basic import get_preferred_onnx_opset -def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()): +def execute_custom_node(node, context, graph, onnx_opset_version): """Call custom implementation to execute a single custom node. Input/output provided via context.""" op_type = node.op_type diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py index a8f4774c..ecb808be 100644 --- a/src/qonnx/core/onnx_exec.py +++ b/src/qonnx/core/onnx_exec.py @@ -36,7 +36,7 @@ import qonnx.analysis.topology as ta import qonnx.core.execute_custom_node as ex_cu_node from qonnx.util.basic import ( - get_preferred_onnx_opset, + get_preferred_qonnx_opset, get_sanitize_quant_tensors, is_finn_op, qonnx_make_model, @@ -44,7 +44,7 @@ ) -def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()): +def execute_node(node, context, graph, opset_version, return_full_exec_context=False): """Executes a single node by using onnxruntime or with a custom function. Input/output provided via context.""" @@ -158,7 +158,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N model_exec_mode = model.get_metadata_prop("exec_mode") if (model_exec_mode is None) or (model_exec_mode == ""): # extract opset version for node-by-node execution - opset_version = model.model.opset_import[0].version + opset_imports = model.get_opset_imports() # execute the model node by node # we can simply walk down the list since the ONNX spec guarantees that it is # topologically sorted @@ -176,7 +176,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N if get_sanitize_quant_tensors() != 0: # round input values to match quantization annotation execution_context = sanitize_quant_values(model, node.input, execution_context) - execute_node(node, execution_context, graph, return_full_exec_context, opset_version) + if node.domain in opset_imports: + opset_version = opset_imports[node.domain] + else: + opset_version = get_preferred_qonnx_opset() + execute_node(node, execution_context, graph, opset_version, return_full_exec_context) if get_sanitize_quant_tensors() != 0: # round output values to quantization annotation execution_context = sanitize_quant_values(model, node.output, execution_context) From 35b8b12d919ae37fe94c317a5585e61fa705b5b6 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:03:42 +0200 Subject: [PATCH 13/20] [CustomOp] alias all qonnx.custom_op.general as v1 --- src/qonnx/custom_op/general/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index 9b14ea8a..e125cbf8 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -52,3 +52,16 @@ custom_op["Trunc"] = Trunc custom_op["BipolarQuant"] = BipolarQuant custom_op["FloatQuant"] = FloatQuant + +custom_op["DebugMarker_v1"] = DebugMarker +custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d +custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC +custom_op["GenericPartition_v1"] = GenericPartition +custom_op["MultiThreshold_v1"] = MultiThreshold +custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul +custom_op["Im2Col_v1"] = Im2Col +custom_op["IntQuant_v1"] = IntQuant +custom_op["Quant_v1"] = IntQuant +custom_op["Trunc_v1"] = Trunc +custom_op["BipolarQuant_v1"] = BipolarQuant +custom_op["FloatQuant_v1"] = FloatQuant From d190a69d813666c1bc55e42a895b84e026630e8a Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:04:42 +0200 Subject: [PATCH 14/20] [ChanLast] alias existing channels_last ops as v1 --- src/qonnx/custom_op/channels_last/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f1d7c39b..60aac1a4 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -7,3 +7,7 @@ custom_op["Conv"] = Conv custom_op["MaxPool"] = MaxPool custom_op["BatchNormalization"] = BatchNormalization + +custom_op["Conv_v1"] = Conv +custom_op["MaxPool_v1"] = MaxPool +custom_op["BatchNormalization_v1"] = BatchNormalization From 5f58f49dbc4110b80f013b919209c758f64dae1b Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:49:05 +0200 Subject: [PATCH 15/20] [Test] add opsets for test_custom_onnx_exec --- tests/core/test_custom_onnx_exec.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py index 8eec7156..54b71754 100644 --- a/tests/core/test_custom_onnx_exec.py +++ b/tests/core/test_custom_onnx_exec.py @@ -32,6 +32,8 @@ import qonnx.core.execute_custom_node as ex_cu_node from qonnx.custom_op.registry import getCustomOp +mt_node_version = 1 + def test_execute_custom_node_multithreshold(): inputs = np.ndarray( @@ -155,7 +157,7 @@ def test_execute_custom_node_multithreshold(): execution_context["v"] = inputs execution_context["thresholds"] = threshold_values - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) outputs = np.ndarray( shape=(6, 3, 2, 2), @@ -250,7 +252,7 @@ def test_execute_custom_node_multithreshold(): ) graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out]) - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) outputs_scaled = 2.0 * outputs - 1.0 assert (execution_context["out"] == outputs_scaled).all() @@ -270,7 +272,7 @@ def test_execute_custom_node_multithreshold(): execution_context["v"] = inputs_nhwc graph_def = helper.make_graph([node_def], "test_model", [v_nhwc, thresholds], [out_nhwc]) - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) assert (execution_context["out"] == outputs_nhwc).all() # check the set of allowed values op_inst = getCustomOp(node_def) From db0b15a1f01bdddcb51eaec05131084d6ab9cc49 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 3 Oct 2025 16:34:44 +0200 Subject: [PATCH 16/20] [ChanLast] emulate op ver agnostic dict for channels last ops --- src/qonnx/custom_op/channels_last/__init__.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index 60aac1a4..f5033e9b 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -2,12 +2,24 @@ from qonnx.custom_op.channels_last.conv import Conv from qonnx.custom_op.channels_last.max_pool import MaxPool -custom_op = dict() +# channels-last ops are defined by the underlying ONNX standard op +# thus, we can define them for any version of the original op +# so we emulate a custom op dictionary that mimics the support for any +# {ChannelsLastOp}_vX instead of hardcoding what versions are supported -custom_op["Conv"] = Conv -custom_op["MaxPool"] = MaxPool -custom_op["BatchNormalization"] = BatchNormalization -custom_op["Conv_v1"] = Conv -custom_op["MaxPool_v1"] = MaxPool -custom_op["BatchNormalization_v1"] = BatchNormalization +class ChannelsLastCustomOpDict: + def __init__(self): + self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization} + + def __getitem__(self, key): + base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13) + if base_key in self._custom_ops: + return self._custom_ops[base_key] + raise KeyError(f"Channels-last CustomOp '{key}' not found.") + + def keys(self): + return self._custom_ops.keys() + + +custom_op = ChannelsLastCustomOpDict() From 83c53aef80b42ca1528120a932eba64c5d78d3e0 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 3 Oct 2025 17:04:36 +0200 Subject: [PATCH 17/20] [Core] use isinstance instead of type check for custom_op --- src/qonnx/custom_op/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 825c7566..442089c3 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -41,7 +41,7 @@ def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") try: opset_module = importlib.import_module(domain) - assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain + assert isinstance(opset_module.custom_op, dict), "custom_op dict not found in Python module %s" % domain if onnx_opset_version is None: inst_wrapper = opset_module.custom_op[op_type] else: From 6bfc2a181b5fad34bf30696a15eef4abbb9f3e06 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 3 Oct 2025 17:05:20 +0200 Subject: [PATCH 18/20] [ChanLast] derive fake custom_op from dict, ensure domain import --- src/qonnx/custom_op/channels_last/__init__.py | 2 +- src/qonnx/transformation/channels_last.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f5033e9b..02aa0d53 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -8,7 +8,7 @@ # {ChannelsLastOp}_vX instead of hardcoding what versions are supported -class ChannelsLastCustomOpDict: +class ChannelsLastCustomOpDict(dict): def __init__(self): self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization} diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..c352238c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -270,8 +270,13 @@ def apply(self, model): # Attach to original node n.output[i] = outp_trans_in - # Modify domain + # Modify node domain n.domain = "qonnx.custom_op.channels_last" + opset_imports = model.get_opset_imports() + # Ensure channels_last domain is imported in model + if "qonnx.custom_op.channels_last" not in opset_imports: + onnx_opset = opset_imports[""] + model.model.opset_import.append(helper.make_opsetid("qonnx.custom_op.channels_last", onnx_opset)) # Set modified flag graph_modified = True From c9811c5cb7aedea89d1712e00e4d4e64cfe9b1bc Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Sat, 4 Oct 2025 00:19:31 +0200 Subject: [PATCH 19/20] [QuantAvgPool2d] use preferred ONNX opset for exec_node() impl --- src/qonnx/custom_op/general/quantavgpool2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index c0e24071..00617dcf 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.util.basic import qonnx_make_model +from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model class QuantAvgPool2d(CustomOp): @@ -132,7 +132,7 @@ def execute_node(self, context, graph): outputs=[outp], ) - opset_version = self.onnx_opset_version + opset_version = get_preferred_onnx_opset() opset_imports = [helper.make_opsetid("", opset_version)] onnx_kwargs = {"opset_imports": opset_imports} model_avgpool = qonnx_make_model(graph_avgpool, **onnx_kwargs) From 073985d9c6e93ed1273cb7f089cfeaedbd17a5da Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Sat, 4 Oct 2025 00:30:20 +0200 Subject: [PATCH 20/20] [ChanLast] implement __contains__ for op registration --- src/qonnx/custom_op/channels_last/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index 02aa0d53..9ffd4e54 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -18,6 +18,10 @@ def __getitem__(self, key): return self._custom_ops[base_key] raise KeyError(f"Channels-last CustomOp '{key}' not found.") + def __contains__(self, key): + base_key = key.split("_v")[0] + return base_key in self._custom_ops + def keys(self): return self._custom_ops.keys()