From 1bffba68600215800e77e8dc19c0bb6b23fe0ee2 Mon Sep 17 00:00:00 2001
From: bingyanghuang
Date: Tue, 25 Jun 2019 16:20:47 +0800
Subject: [PATCH] modify the qat pass to add mul, test=develop

---
 .../quantization/quantization_mkldnn_pass.py | 123 +++++++++++-------
 .../tests/test_quantization_mkldnn_pass.py   |  14 +-
 2 files changed, 83 insertions(+), 54 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
index 2fc9dfac8e7bbd..293951200678bf 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
@@ -27,9 +27,9 @@ class TransformForMkldnnPass(object):
     1. Convert int8 range weights with float32 data type, which are generated by
     the QuantizationFreezePass, to float32 range weights with float32 data type
     by using the corresponding scales. This conversion is because MKL-DNN INT8
-    conv2d kernel now only supports float32 weights input, will do weights
-    quantization inside the conv2d kernel.
-    2. Create the new conv2d op with the converted weights and link its output
+    conv2d and mul kernels now only support float32 weights input, hence
+    weight quantization will happen inside the conv2d and mul INT8 kernels.
+    2. Create the new conv2d or mul op with the converted weights and link its output
     to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32
     _output" as true
     3. Transform fake_quantize_xx op to quantize op
@@ -73,13 +73,8 @@ def __init__(self, scope=None, place=None):
 
         self.InScale = {}
         self.max_range = {}
-        self.conv_new_output = {}
+        self.new_output = {}
         self.s8_max = 127
-        # Temporary code for keeping the mul op as fake quantization
-        #TODO Intel: Remove the following code when mul int8 mkldnn
-        # kernel enabled
-        self.mul_input_id = []
-        self.mul_output_id = []
 
     def apply(self, graph):
         """
@@ -97,7 +92,7 @@ def apply(self, graph):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
 
         # Collect the InScales and max_range to calculate the new scales for MKL-DNN
-        # INT8 conv2d
+        # INT8 conv2d and mul
         for op_node in ops:
             if op_node.name() in self.dequantize_type:
                 input_name = op_node.input("X")[0]
@@ -105,20 +100,14 @@ def apply(self, graph):
                 self.InScale[input_name] = self._load_param(self._scope,
                                                             scale_name)[0]
                 self.max_range[input_name] = op_node.op().attr("max_range")
-                self.conv_new_output[input_name] = op_node.output("Out")[0]
-            # Temporary graph transform on keeping the mul op
-            # TODO Intel: Remove following code
-            elif op_node.name() in ['mul']:
-                input_node = graph._find_node_by_name(op_node.inputs,
-                                                      op_node.input('X')[0])
-                output_node = graph._find_node_by_name(op_node.outputs,
-                                                       op_node.output('Out')[0])
-                self.mul_input_id.append(input_node.id())
-                self.mul_output_id.append(output_node.id())
+                self.new_output[input_name] = op_node.output("Out")[0]
 
         for op_node in ops:
-            if op_node.name() in self._conv_ops:
-                self._transform_to_conv_mkldnn(graph, op_node)
+            if op_node.name() in self._quantizable_ops:
+                if op_node.name() in self._conv_ops:
+                    self._transform_to_conv_mkldnn(graph, op_node)
+                else:
+                    self._transform_to_mul_mkldnn(graph, op_node)
             elif op_node.name() in self.quantize_type:
                 self._transform_to_quantize_mkldnn(graph, op_node)
             elif op_node.name() in self.dequantize_type:
@@ -132,7 +121,7 @@ def _transform_to_conv_mkldnn(self, graph, op_node):
         # Convert int8 range weights to fp32 range weights
         weight = self._load_param(self._scope, weight_name)
         w_fp32 = np.divide(
-            np.multiply(weight, 127), self.max_range[output_name])
+            np.multiply(weight, self.s8_max), self.max_range[output_name])
         w_fp32 = w_fp32.reshape(weight.shape)
         self._restore_var(weight_name, w_fp32)
         input_var_node = graph._find_node_by_name(op_node.inputs,
@@ -140,8 +129,8 @@ def _transform_to_conv_mkldnn(self, graph, op_node):
         weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
 
         # Set fake_dequantize_abs_max's output as new output of conv2d
-        output_var_node = graph._find_node_by_name(
-            graph.all_var_nodes(), self.conv_new_output[output_name])
+        output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
+                                                   self.new_output[output_name])
         attrs = {
             name: op_node.op().attr(name)
             for name in op_node.op().attr_names()
@@ -157,7 +146,7 @@ def _transform_to_conv_mkldnn(self, graph, op_node):
         # Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
         scale_in = self.s8_max / self.InScale[output_name]
         scale_w = []
-        scale_w.append(self.max_range[output_name] / self.s8_max)
+        scale_w = [self.max_range[output_name] / self.s8_max]
 
         conv_op_node.set_attr("Scale_weights", scale_w)
         conv_op_node.set_attr("Scale_in", scale_in)
@@ -169,6 +158,50 @@ def _transform_to_conv_mkldnn(self, graph, op_node):
         graph.link_to(conv_op_node, output_var_node)
         graph.safe_remove_nodes(op_node)
 
+    def _transform_to_mul_mkldnn(self, graph, op_node):
+        # For MKL-DNN INT8 mul, input Y should be the weights
+        weight_name = op_node.input("Y")[0]
+        output_name = op_node.output("Out")[0]
+        # Convert int8 range weights to fp32 range weights
+        weight = self._load_param(self._scope, weight_name)
+        w_fp32 = np.divide(
+            np.multiply(weight, self.s8_max), self.max_range[output_name])
+        w_fp32 = w_fp32.reshape(weight.shape)
+        self._restore_var(weight_name, w_fp32)
+        input_var_node = graph._find_node_by_name(op_node.inputs,
+                                                  op_node.input("X")[0])
+        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
+
+        # Set fake_dequantize_abs_max's output as new output of mul
+        output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
+                                                   self.new_output[output_name])
+        attrs = {
+            name: op_node.op().attr(name)
+            for name in op_node.op().attr_names()
+        }
+
+        mul_op_node = graph.create_op_node(
+            op_type='mul',
+            attrs=attrs,
+            inputs={'X': input_var_node,
+                    'Y': weight_var_node},
+            outputs={'Out': output_var_node})
+
+        # Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
+        scale_in = self.s8_max / self.InScale[output_name]
+        scale_w = []
+        scale_w = [self.max_range[output_name] / self.s8_max]
+
+        mul_op_node.set_attr("scale_y", scale_w)
+        mul_op_node.set_attr("scale_x", scale_in)
+        mul_op_node.set_attr("scale_out", 1.0)
+        mul_op_node.set_attr("use_mkldnn", 1)
+        mul_op_node.set_attr("force_fp32_output", 1)
+        graph.link_to(input_var_node, mul_op_node)
+        graph.link_to(weight_var_node, mul_op_node)
+        graph.link_to(mul_op_node, output_var_node)
+        graph.safe_remove_nodes(op_node)
+
     def _transform_to_quantize_mkldnn(self, graph, op_node):
         """
         Transform fake_quantize_xx op to quantize mkldnn op in the graph.
@@ -177,32 +210,26 @@ def _transform_to_quantize_mkldnn(self, graph, op_node):
                                                  op_node.input("X")[0])
         output_var_node = graph._find_node_by_name(op_node.outputs,
                                                    op_node.output("Out")[0])
-        if output_var_node.id() in self.mul_input_id:
-            return
-        else:
-            scale_in = self.s8_max / self._load_param(
-                self._scope, op_node.input("InScale")[0])[0]
-            quant_op_node = graph.create_op_node(
-                op_type='quantize',
-                attrs={
-                    'data_format': 'MKLDNNLAYOUT',
-                    'use_mkldnn': 1,
-                    'Scale': scale_in,
-                    'is_negative_input': 1
-                },
-                inputs={'Input': input_var_node},
-                outputs={'Output': output_var_node})
-            graph.link_to(input_var_node, quant_op_node)
-            graph.link_to(quant_op_node, output_var_node)
-            graph.safe_remove_nodes(op_node)
+        scale_in = self.s8_max / self._load_param(
+            self._scope, op_node.input("InScale")[0])[0]
+        quant_op_node = graph.create_op_node(
+            op_type='quantize',
+            attrs={
+                'data_format': 'MKLDNNLAYOUT',
+                'use_mkldnn': 1,
+                'Scale': scale_in,
+                'is_negative_input': 1
+            },
+            inputs={'Input': input_var_node},
+            outputs={'Output': output_var_node})
+        graph.link_to(input_var_node, quant_op_node)
+        graph.link_to(quant_op_node, output_var_node)
+        graph.safe_remove_nodes(op_node)
 
     def _remove_fake_dequantize_op(self, graph, op_node):
         input_var_node = graph._find_node_by_name(op_node.inputs,
                                                   op_node.input("X")[0])
-        if input_var_node.id() in self.mul_output_id:
-            return
-        else:
-            graph.safe_remove_nodes(op_node)
+        graph.safe_remove_nodes(op_node)
 
     def _load_param(self, scope, param_name):
         return np.array(scope.find_var(param_name).get_tensor())
diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py
index 90cc28b3aaf95f..81a31ba7d2efb3 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py
@@ -55,9 +55,7 @@ def setUp(self):
         self.quantizable_op_and_inputs = {
             'conv2d': ['Input', 'Filter'],
             'depthwise_conv2d': ['Input', 'Filter'],
-            # Mul int8 op is under internal test
-            # TODO Update this when mul op is merged
-            #'mul': ['X', 'Y']
+            'mul': ['X', 'Y']
         }
 
     def check_program(self, program):
@@ -162,11 +160,15 @@ def mkldnn_based_freeze_graph(self,
                 activation_quant_type + '_' + weight_quant_type, marked_nodes)
         mkldnn_program = test_graph.to_program()
-        w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
+
+        # Check the transformed weights of conv2d and mul
+        conv_w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
+        mul_w_mkldnn = np.array(scope.find_var('fc_0.w_0').get_tensor())
         # Check if weights are still integer
-        self.assertFalse(self.isinteger(np.sum(w_mkldnn)))
+        self.assertFalse(self.isinteger(np.sum(conv_w_mkldnn)))
+        self.assertFalse(self.isinteger(np.sum(mul_w_mkldnn)))
 
-        # Check if the conv2d output is rightly linked to fake_dequantize's
+        # Check if the conv2d output and mul output are correctly linked to fake_dequantize's
         # output
         self.check_program(mkldnn_program)
         if not for_ci:
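
For reference, the arithmetic that _transform_to_conv_mkldnn and _transform_to_mul_mkldnn apply can be hard to follow inside the diff, so below is a minimal standalone NumPy sketch of the same formulas. The concrete numbers (in_scale, max_range, the weight array) are hypothetical illustration values and are not taken from the patch.

    import numpy as np

    S8_MAX = 127.0  # corresponds to self.s8_max in the pass

    # Hypothetical values as a QAT freeze pass might record them.
    in_scale = 6.4        # InScale read from the scales variable of fake_dequantize
    max_range = 2520.3    # "max_range" attribute of fake_dequantize_abs_max
    weight_int8_range = np.array([[-90., 35.], [12., -127.]], dtype=np.float32)

    # 1. Convert int8-range weights (stored as float32) back to float32-range
    #    weights, mirroring:
    #    w_fp32 = np.divide(np.multiply(weight, self.s8_max), self.max_range[output_name])
    w_fp32 = np.divide(np.multiply(weight_int8_range, S8_MAX), max_range)

    # 2. Derive the scales handed to the MKL-DNN INT8 kernels.
    scale_in = S8_MAX / in_scale    # "Scale_in" for conv2d, "scale_x" for mul
    scale_w = [max_range / S8_MAX]  # "Scale_weights" for conv2d, "scale_y" for mul
    scale_out = 1.0                 # mul's "scale_out"; force_fp32_output keeps fp32 output

    print(w_fp32)
    print(scale_in, scale_w, scale_out)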