QAT int8 MKL-DNN transformation pass with MUL #18322

Merged
File 1: the TransformForMkldnnPass implementation
@@ -27,9 +27,9 @@ class TransformForMkldnnPass(object):
1. Convert int8 range weights with float32 data type, which are generated by
the QuantizationFreezePass, to float32 range weights with float32 data type
by using the corresponding scales. This conversion is because MKL-DNN INT8
-    conv2d kernel now only supports float32 weights input, will do weights
-    quantization inside the conv2d kernel.
-    2. Create the new conv2d op with the converted weights and link its output
+    conv2d kernel and mul kernel now only support float32 weights input, hence
+    weights quantization will happen inside the conv2d and mul INT8 kernels.
+    2. Create the new conv2d or mul op with the converted weights and link its output
to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32
_output" as true
3. Transform fake_quantize_xx op to quantize op
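As a concrete illustration of step 1, the rescaling that both _transform_to_conv_mkldnn and _transform_to_mul_mkldnn perform on the int8-range weights can be sketched in a few lines of numpy; the values below are made up for the example and are not taken from the PR:

    import numpy as np

    s8_max = 127.0
    max_range = 508.0  # hypothetical max_range recorded by QuantizationFreezePass
    w_int8_range = np.array([[-127.0, 64.0], [32.0, -8.0]], dtype=np.float32)

    # Weights leave the pass in float32 range; the MKL-DNN INT8 kernel
    # re-quantizes them internally using the weight scale attribute.
    w_fp32 = np.multiply(w_int8_range, s8_max) / max_range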
@@ -73,13 +73,8 @@ def __init__(self, scope=None, place=None):

self.InScale = {}
self.max_range = {}
-        self.conv_new_output = {}
+        self.new_output = {}
self.s8_max = 127
-        # Temporary code for keeping the mul op as fake quantization
-        #TODO Intel: Remove the following code when mul int8 mkldnn
-        # kernel enabled
-        self.mul_input_id = []
-        self.mul_output_id = []

def apply(self, graph):
"""
@@ -97,28 +92,22 @@ def apply(self, graph):

persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
# Collect the InScales and max_range to calculate the new scales for MKL-DNN
-        # INT8 conv2d
+        # INT8 conv2d and mul
for op_node in ops:
if op_node.name() in self.dequantize_type:
input_name = op_node.input("X")[0]
scale_name = op_node.input("Scale")[0]
self.InScale[input_name] = self._load_param(self._scope,
scale_name)[0]
self.max_range[input_name] = op_node.op().attr("max_range")
-                self.conv_new_output[input_name] = op_node.output("Out")[0]
-            # Temporary graph transform on keeping the mul op
-            # TODO Intel: Remove following code
-            elif op_node.name() in ['mul']:
-                input_node = graph._find_node_by_name(op_node.inputs,
-                                                      op_node.input('X')[0])
-                output_node = graph._find_node_by_name(op_node.outputs,
-                                                       op_node.output('Out')[0])
-                self.mul_input_id.append(input_node.id())
-                self.mul_output_id.append(output_node.id())
+                self.new_output[input_name] = op_node.output("Out")[0]

for op_node in ops:
-            if op_node.name() in self._conv_ops:
-                self._transform_to_conv_mkldnn(graph, op_node)
+            if op_node.name() in self._quantizable_ops:
+                if op_node.name() in self._conv_ops:
+                    self._transform_to_conv_mkldnn(graph, op_node)
+                else:
+                    self._transform_to_mul_mkldnn(graph, op_node)
elif op_node.name() in self.quantize_type:
self._transform_to_quantize_mkldnn(graph, op_node)
elif op_node.name() in self.dequantize_type:
@@ -132,16 +121,16 @@ def _transform_to_conv_mkldnn(self, graph, op_node):
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
-            np.multiply(weight, 127), self.max_range[output_name])
+            np.multiply(weight, self.s8_max), self.max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("Input")[0])
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)

# Set fake_dequantize_abs_max's output as new output of conv2d
-        output_var_node = graph._find_node_by_name(
-            graph.all_var_nodes(), self.conv_new_output[output_name])
+        output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
+                                                   self.new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
@@ -157,7 +146,7 @@ def _transform_to_conv_mkldnn(self, graph, op_node):
# Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
scale_in = self.s8_max / self.InScale[output_name]
scale_w = []
-        scale_w.append(self.max_range[output_name] / self.s8_max)
+        scale_w = [self.max_range[output_name] / self.s8_max]

conv_op_node.set_attr("Scale_weights", scale_w)
conv_op_node.set_attr("Scale_in", scale_in)
@@ -169,6 +158,50 @@ def _transform_to_conv_mkldnn(self, graph, op_node):
graph.link_to(conv_op_node, output_var_node)
graph.safe_remove_nodes(op_node)

+    def _transform_to_mul_mkldnn(self, graph, op_node):
+        # For MKL-DNN INT8 mul, input Y should be the weights
+        weight_name = op_node.input("Y")[0]
+        output_name = op_node.output("Out")[0]
+        # Convert int8 range weights to fp32 range weights
+        weight = self._load_param(self._scope, weight_name)
+        w_fp32 = np.divide(
+            np.multiply(weight, self.s8_max), self.max_range[output_name])
+        w_fp32 = w_fp32.reshape(weight.shape)
+        self._restore_var(weight_name, w_fp32)
+        input_var_node = graph._find_node_by_name(op_node.inputs,
+                                                  op_node.input("X")[0])
+        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
+
+        # Set fake_dequantize_abs_max's output as new output of mul
+        output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
+                                                   self.new_output[output_name])
+        attrs = {
+            name: op_node.op().attr(name)
+            for name in op_node.op().attr_names()
+        }
+
+        mul_op_node = graph.create_op_node(
+            op_type='mul',
+            attrs=attrs,
+            inputs={'X': input_var_node,
+                    'Y': weight_var_node},
+            outputs={'Out': output_var_node})
+
+        # Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
+        scale_in = self.s8_max / self.InScale[output_name]
+        scale_w = []
+        scale_w = [self.max_range[output_name] / self.s8_max]
+
+        mul_op_node.set_attr("scale_y", scale_w)
+        mul_op_node.set_attr("scale_x", scale_in)
+        mul_op_node.set_attr("scale_out", 1.0)
+        mul_op_node.set_attr("use_mkldnn", 1)
+        mul_op_node.set_attr("force_fp32_output", 1)
+        graph.link_to(input_var_node, mul_op_node)
+        graph.link_to(weight_var_node, mul_op_node)
+        graph.link_to(mul_op_node, output_var_node)
+        graph.safe_remove_nodes(op_node)
+
def _transform_to_quantize_mkldnn(self, graph, op_node):
"""
Transform fake_quantize_xx op to quantize mkldnn op in the graph.
@@ -177,32 +210,26 @@ def _transform_to_quantize_mkldnn(self, graph, op_node):
op_node.input("X")[0])
output_var_node = graph._find_node_by_name(op_node.outputs,
op_node.output("Out")[0])
-        if output_var_node.id() in self.mul_input_id:
-            return
-        else:
-            scale_in = self.s8_max / self._load_param(
-                self._scope, op_node.input("InScale")[0])[0]
-            quant_op_node = graph.create_op_node(
-                op_type='quantize',
-                attrs={
-                    'data_format': 'MKLDNNLAYOUT',
-                    'use_mkldnn': 1,
-                    'Scale': scale_in,
-                    'is_negative_input': 1
-                },
-                inputs={'Input': input_var_node},
-                outputs={'Output': output_var_node})
-            graph.link_to(input_var_node, quant_op_node)
-            graph.link_to(quant_op_node, output_var_node)
-            graph.safe_remove_nodes(op_node)
+        scale_in = self.s8_max / self._load_param(
+            self._scope, op_node.input("InScale")[0])[0]
+        quant_op_node = graph.create_op_node(
+            op_type='quantize',
+            attrs={
+                'data_format': 'MKLDNNLAYOUT',
+                'use_mkldnn': 1,
+                'Scale': scale_in,
+                'is_negative_input': 1
+            },
+            inputs={'Input': input_var_node},
+            outputs={'Output': output_var_node})
+        graph.link_to(input_var_node, quant_op_node)
+        graph.link_to(quant_op_node, output_var_node)
+        graph.safe_remove_nodes(op_node)

def _remove_fake_dequantize_op(self, graph, op_node):
input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("X")[0])
-        if input_var_node.id() in self.mul_output_id:
-            return
-        else:
-            graph.safe_remove_nodes(op_node)
+        graph.safe_remove_nodes(op_node)

def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
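Taken together, the pass is meant to run on an IrGraph that has already been through the QAT transform and freeze passes; a minimal usage sketch, assuming scope, place and graph are prepared as in the unit test below (the import path is an assumption and is not shown in this diff):

    from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass

    # graph: IrGraph produced after QuantizationTransformPass + QuantizationFreezePass,
    # so it still contains fake_quantize_* / fake_dequantize_* ops.
    mkldnn_pass = TransformForMkldnnPass(scope=scope, place=place)
    mkldnn_pass.apply(graph)          # rewrite conv2d/mul + fake quant ops in place
    mkldnn_program = graph.to_program()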
File 2: the quantization MKL-DNN pass unit test
@@ -55,9 +55,7 @@ def setUp(self):
self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'],
-            # Mul int8 op is under internal test
-            # TODO Update this when mul op is merged
-            #'mul': ['X', 'Y']
+            'mul': ['X', 'Y']
}

def check_program(self, program):
@@ -162,11 +160,15 @@ def mkldnn_based_freeze_graph(self,
activation_quant_type + '_' + weight_quant_type,
marked_nodes)
mkldnn_program = test_graph.to_program()
-        w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
+
+        # Check the transformation weights of conv2d and mul
+        conv_w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
+        mul_w_mkldnn = np.array(scope.find_var('fc_0.w_0').get_tensor())
Reviewer comment (Contributor), on the hard-coded variable names above:
I still think it's unsafe to refer directly to the variable names here, but currently I don't possess enough knowledge of how this could be done differently. We should probably think about how to adjust it in the future.

# Check if weights are still integer
-        self.assertFalse(self.isinteger(np.sum(w_mkldnn)))
+        self.assertFalse(self.isinteger(np.sum(conv_w_mkldnn)))
+        self.assertFalse(self.isinteger(np.sum(mul_w_mkldnn)))

-        # Check if the conv2d output is rightly linked to fake_dequantize's
+        # Check if the conv2d output and mul output are correctly linked to fake_dequantize's
# output
self.check_program(mkldnn_program)
if not for_ci:
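The isinteger helper used in the assertions above is defined elsewhere in this test file; a hypothetical equivalent of the check it expresses (name and body assumed here for illustration only):

    import numpy as np

    def isinteger(x):
        # True when x has no fractional part, i.e. the weights were left in
        # the int8 range instead of being converted back to fp32 range.
        return np.equal(np.mod(x, 1), 0)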