Changes from all commits (22 commits)
7dfc4b8
[Lint] rerun linter, fix errors
maltanar Sep 25, 2025
7456919
[Core] add get_opset_imports utility fxn to ModelWrapper
maltanar Sep 25, 2025
89396cd
[Core] return dict from ModelWrapper.get_opset_imports
maltanar Sep 25, 2025
db2994f
[Core] add versioned op to getCustomOp with fallback to old style
maltanar Sep 25, 2025
8a2db22
[Core] introduce ModelWrapper.get_customop_wrapper
maltanar Sep 25, 2025
402a580
[Test] add basic unit tests for versioned custom op fetching
maltanar Sep 25, 2025
ad80561
Merge branch 'main' into feature/op_version
maltanar Oct 2, 2025
407fb13
[Test] extend test_customop_version for default v handler
maltanar Oct 2, 2025
feac9f0
[Core] opset ver. fallback for ModelWrapper.get_customop_wrapper
maltanar Oct 2, 2025
89eea4c
[Core] getCustomOp: default v to None, fetch highest available v.
maltanar Oct 2, 2025
ec517b5
[Test] cover newly added opset ver behavior in test_customop_version
maltanar Oct 2, 2025
7406dcf
Merge branch 'main' into feature/op_version
maltanar Oct 2, 2025
aeeff58
[Core, Util] distinguish preferred onnx opset from qonnx opset
maltanar Oct 2, 2025
5801504
[Core] respect selected opsets during execution
maltanar Oct 2, 2025
35b8b12
[CustomOp] alias all qonnx.custom_op.general as v1
maltanar Oct 2, 2025
d190a69
[ChanLast] alias existing channels_last ops as v1
maltanar Oct 2, 2025
5f58f49
[Test] add opsets for test_custom_onnx_exec
maltanar Oct 2, 2025
db0b15a
[ChanLast] emulate op ver agnostic dict for channels last ops
maltanar Oct 3, 2025
83c53ae
[Core] use isinstance instead of type check for custom_op
maltanar Oct 3, 2025
6bfc2a1
[ChanLast] derive fake custom_op from dict, ensure domain import
maltanar Oct 3, 2025
c9811c5
[QuantAvgPool2d] use preferred ONNX opset for exec_node() impl
maltanar Oct 3, 2025
073985d
[ChanLast] implement __contains__ for op registration
maltanar Oct 3, 2025
3 changes: 1 addition & 2 deletions src/qonnx/core/execute_custom_node.py
@@ -27,10 +27,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import qonnx.custom_op.registry as registry
from qonnx.util.basic import get_preferred_onnx_opset


def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()):
def execute_custom_node(node, context, graph, onnx_opset_version):
"""Call custom implementation to execute a single custom node.
Input/output provided via context."""
op_type = node.op_type
22 changes: 22 additions & 0 deletions src/qonnx/core/modelwrapper.py
@@ -39,6 +39,7 @@
import qonnx.util.basic as util
import qonnx.util.onnx as onnxutil
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.general import (
RemoveStaticGraphInputs,
@@ -737,3 +738,24 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)

def get_opset_imports(self):
"""Returns a list of imported opsets as a {domain, version} dictionary."""
return {opset.domain: opset.version for opset in self._model_proto.opset_import}

def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()):
"""Return CustomOp instance for given node, respecting the
imported opset version in the model protobuf. If the node's domain
is not found in the model's opset imports, fallback_customop_version
will be used."""
opset_imports = self.get_opset_imports()
try:
opset_import = opset_imports[node.domain]
return getCustomOp(node, onnx_opset_version=opset_import)
except KeyError:
# domain not found in imports, use fallback version
warnings.warn(
f"Domain {node.domain} not found in model opset imports, "
f"using fallback_customop_version={fallback_customop_version}"
)
return getCustomOp(node, onnx_opset_version=fallback_customop_version)
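A minimal usage sketch of the two new helpers (hypothetical model path; assumes the model imports the domains of its custom op nodes):

from qonnx.core.modelwrapper import ModelWrapper

# hypothetical model file; any QONNX model with custom op nodes works here
model = ModelWrapper("model.onnx")

# e.g. {"": 11, "qonnx.custom_op.general": 1}
print(model.get_opset_imports())

# wrapper construction respects the opset version imported for the node's domain;
# if the domain is missing from the imports, fallback_customop_version is used
node = model.graph.node[0]
op_inst = model.get_customop_wrapper(node, fallback_customop_version=1)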
12 changes: 8 additions & 4 deletions src/qonnx/core/onnx_exec.py
@@ -36,15 +36,15 @@
import qonnx.analysis.topology as ta
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.util.basic import (
get_preferred_onnx_opset,
get_preferred_qonnx_opset,
get_sanitize_quant_tensors,
is_finn_op,
qonnx_make_model,
sanitize_quant_values,
)


def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()):
def execute_node(node, context, graph, opset_version, return_full_exec_context=False):
"""Executes a single node by using onnxruntime or with a custom function.

Input/output provided via context."""
@@ -158,7 +158,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
model_exec_mode = model.get_metadata_prop("exec_mode")
if (model_exec_mode is None) or (model_exec_mode == ""):
# extract opset version for node-by-node execution
opset_version = model.model.opset_import[0].version
opset_imports = model.get_opset_imports()
# execute the model node by node
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
@@ -176,7 +176,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
if get_sanitize_quant_tensors() != 0:
# round input values to match quantization annotation
execution_context = sanitize_quant_values(model, node.input, execution_context)
execute_node(node, execution_context, graph, return_full_exec_context, opset_version)
if node.domain in opset_imports:
opset_version = opset_imports[node.domain]
else:
opset_version = get_preferred_qonnx_opset()
execute_node(node, execution_context, graph, opset_version, return_full_exec_context)
if get_sanitize_quant_tensors() != 0:
# round output values to quantization annotation
execution_context = sanitize_quant_values(model, node.output, execution_context)
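The per-node resolution above is equivalent to a dictionary lookup with a QONNX-opset default; a sketch of the same logic using dict.get (assumes model is a ModelWrapper):

# sketch: per-node opset resolution as done in execute_onnx
opset_imports = model.get_opset_imports()
for node in model.graph.node:
    opset_version = opset_imports.get(node.domain, get_preferred_qonnx_opset())
    # ... node is then executed with opset_version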
28 changes: 24 additions & 4 deletions src/qonnx/custom_op/channels_last/__init__.py
@@ -2,8 +2,28 @@
from qonnx.custom_op.channels_last.conv import Conv
from qonnx.custom_op.channels_last.max_pool import MaxPool

custom_op = dict()
# channels-last ops are defined by the underlying ONNX standard op,
# so they can be provided for any version of the original op.
# we therefore emulate a custom op dictionary that accepts any
# {ChannelsLastOp}_vX key instead of hardcoding the supported versions

custom_op["Conv"] = Conv
custom_op["MaxPool"] = MaxPool
custom_op["BatchNormalization"] = BatchNormalization

class ChannelsLastCustomOpDict(dict):
def __init__(self):
self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization}

def __getitem__(self, key):
base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13)
if base_key in self._custom_ops:
return self._custom_ops[base_key]
raise KeyError(f"Channels-last CustomOp '{key}' not found.")

def __contains__(self, key):
base_key = key.split("_v")[0]
return base_key in self._custom_ops

def keys(self):
return self._custom_ops.keys()


custom_op = ChannelsLastCustomOpDict()
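A quick sketch of how the emulated dictionary behaves (the version suffixes are arbitrary, since every _vX key resolves to the same class):

from qonnx.custom_op.channels_last import custom_op

# any version suffix resolves to the same wrapper class
assert custom_op["Conv_v11"] is custom_op["Conv_v13"]
assert "MaxPool_v12" in custom_op
# ops without a channels-last wrapper still raise KeyError
try:
    custom_op["Gemm_v13"]
except KeyError as err:
    print(err)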
13 changes: 13 additions & 0 deletions src/qonnx/custom_op/general/__init__.py
@@ -52,3 +52,16 @@
custom_op["Trunc"] = Trunc
custom_op["BipolarQuant"] = BipolarQuant
custom_op["FloatQuant"] = FloatQuant

custom_op["DebugMarker_v1"] = DebugMarker
custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d
custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC
custom_op["GenericPartition_v1"] = GenericPartition
custom_op["MultiThreshold_v1"] = MultiThreshold
custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul
custom_op["Im2Col_v1"] = Im2Col
custom_op["IntQuant_v1"] = IntQuant
custom_op["Quant_v1"] = IntQuant
custom_op["Trunc_v1"] = Trunc
custom_op["BipolarQuant_v1"] = BipolarQuant
custom_op["FloatQuant_v1"] = FloatQuant
4 changes: 2 additions & 2 deletions src/qonnx/custom_op/general/quantavgpool2d.py
@@ -33,7 +33,7 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim
from qonnx.util.basic import qonnx_make_model
from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model


class QuantAvgPool2d(CustomOp):
@@ -132,7 +132,7 @@ def execute_node(self, context, graph):
outputs=[outp],
)

opset_version = self.onnx_opset_version
opset_version = get_preferred_onnx_opset()
opset_imports = [helper.make_opsetid("", opset_version)]
onnx_kwargs = {"opset_imports": opset_imports}
model_avgpool = qonnx_make_model(graph_avgpool, **onnx_kwargs)
35 changes: 28 additions & 7 deletions src/qonnx/custom_op/registry.py
@@ -28,23 +28,44 @@

import importlib

from qonnx.util.basic import get_preferred_onnx_opset


def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True):
"Return a QONNX CustomOp instance for the given ONNX node, if it exists."
def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True):
"Return a QONNX CustomOp wrapper for the given ONNX node and given opset version,"
"if it exists. If opset version is None, the default handler for the op type will be used. "
"If version is specified but the exact version match isn't available, the highest available version "
"smaller than the requested version will be used."
op_type = node.op_type
domain = node.domain
if brevitas_exception:
# transparently resolve Brevitas domain ops to qonnx ones
domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general")
try:
opset_module = importlib.import_module(domain)
assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain
inst_wrapper = opset_module.custom_op[op_type]
assert isinstance(opset_module.custom_op, dict), "custom_op dict not found in Python module %s" % domain
if onnx_opset_version is None:
inst_wrapper = opset_module.custom_op[op_type]
else:
op_type_with_version = op_type + "_v" + str(onnx_opset_version)
if op_type_with_version in opset_module.custom_op:
# priority: if it exists, load the versioned CustomOp wrapper
inst_wrapper = opset_module.custom_op[op_type_with_version]
else:
# no exact version match: fall back to the highest
# available version smaller than the requested one
available_versions = [
int(k.split("_v")[-1]) for k in opset_module.custom_op.keys() if k.startswith(op_type + "_v")
]
suitable_versions = [v for v in available_versions if v <= onnx_opset_version]
if suitable_versions:
highest_version = max(suitable_versions)
inst_wrapper = opset_module.custom_op[f"{op_type}_v{highest_version}"]
else:
raise Exception(
"Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain)
)
inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version)
return inst
except ModuleNotFoundError:
raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain)
except KeyError:
raise Exception("Op %s not found in custom opset %s" % (op_type, domain))
raise Exception("Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain))
7 changes: 6 additions & 1 deletion src/qonnx/transformation/channels_last.py
@@ -270,8 +270,13 @@ def apply(self, model):
# Attach to original node
n.output[i] = outp_trans_in

# Modify domain
# Modify node domain
n.domain = "qonnx.custom_op.channels_last"
opset_imports = model.get_opset_imports()
# Ensure channels_last domain is imported in model
if "qonnx.custom_op.channels_last" not in opset_imports:
onnx_opset = opset_imports[""]
model.model.opset_import.append(helper.make_opsetid("qonnx.custom_op.channels_last", onnx_opset))
# Set modified flag
graph_modified = True

3 changes: 2 additions & 1 deletion src/qonnx/transformation/fixedpt_quantize.py
@@ -48,7 +48,8 @@ class FixedPointQuantizeParamsFromDict(Transformation):
data type or its canonical name
rounding_mode: Rounding mode used for conversion into fixed point.
Default is "ROUND",
possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"]
possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN",
"HALF_UP", "HALF_DOWN"]
"""

def __init__(self, fixedpt_dict, rounding_mode="ROUND"):
10 changes: 9 additions & 1 deletion src/qonnx/util/basic.py
@@ -51,11 +51,19 @@ def get_preferred_onnx_opset():
return 11


def get_preferred_qonnx_opset():
"Return preferred ONNX opset version for QONNX"
return 1


def qonnx_make_model(graph_proto, **kwargs):
"Wrapper around ONNX make_model with preferred qonnx opset version"
opset_imports = kwargs.pop("opset_imports", None)
if opset_imports is None:
opset_imports = [make_opsetid("", get_preferred_onnx_opset())]
opset_imports = [
make_opsetid("", get_preferred_onnx_opset()),
make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()),
]
kwargs["opset_imports"] = opset_imports
else:
kwargs["opset_imports"] = opset_imports
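With this change, a model built through qonnx_make_model without explicit opset_imports should carry both the standard ONNX domain and the qonnx.custom_op.general domain; a sketch (expected version numbers follow the two helpers above):

from onnx import helper
from qonnx.util.basic import qonnx_make_model

graph = helper.make_graph(nodes=[], name="empty", inputs=[], outputs=[])
model = qonnx_make_model(graph)
# expected with the defaults above: {"": 11, "qonnx.custom_op.general": 1}
print({op.domain: op.version for op in model.opset_import})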
8 changes: 5 additions & 3 deletions tests/core/test_custom_onnx_exec.py
@@ -32,6 +32,8 @@
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.custom_op.registry import getCustomOp

mt_node_version = 1


def test_execute_custom_node_multithreshold():
inputs = np.ndarray(
@@ -155,7 +157,7 @@ def test_execute_custom_node_multithreshold():
execution_context["v"] = inputs
execution_context["thresholds"] = threshold_values

ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)

outputs = np.ndarray(
shape=(6, 3, 2, 2),
@@ -250,7 +252,7 @@ def test_execute_custom_node_multithreshold():
)

graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
outputs_scaled = 2.0 * outputs - 1.0
assert (execution_context["out"] == outputs_scaled).all()

@@ -270,7 +272,7 @@ def test_execute_custom_node_multithreshold():
execution_context["v"] = inputs_nhwc

graph_def = helper.make_graph([node_def], "test_model", [v_nhwc, thresholds], [out_nhwc])
ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
assert (execution_context["out"] == outputs_nhwc).all()
# check the set of allowed values
op_inst = getCustomOp(node_def)
1 change: 1 addition & 0 deletions tests/core/test_modelwrapper.py
@@ -68,6 +68,7 @@ def test_modelwrapper():
inp_sparsity = {"dw": {"kernel_shape": [3, 3]}}
model.set_tensor_sparsity(first_conv_iname, inp_sparsity)
assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity
assert model.get_opset_imports() == {"": 8}


def test_modelwrapper_set_get_rm_initializer():