From 1372a6d758c24aeef4dfe77355b079023f7715f8 Mon Sep 17 00:00:00 2001 From: Channingss Date: Mon, 12 Oct 2020 09:42:27 +0000 Subject: [PATCH 1/2] [ONNX] fix bug of prelu --- x2paddle/convert.py | 1 + x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index d5c78e761..86f62104e 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -190,6 +190,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False): mapper = ONNXOpMapper(model) print("Model optimizing ...") optimizer = ONNXOptimizer(mapper) + optimizer.delete_redundance_code() print("Model optimized.") print("Paddle model and code generating ...") diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 38087558f..b36560c18 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1256,10 +1256,19 @@ def PRelu(self, node): mode = 'channel' shape_slope = val_slope.out_shapes[0] - if len(shape_slope) == 1: + + if shape_slope == [1]: mode = 'all' elif len(shape_slope) > 2: mode = 'element' + + if mode == 'channel' and len(shape_slope) == 1: + # paddle params shape need be [1, channel] + slope_data = _const_weight_or_none(val_slope) + slope_data = np.reshape(slope_data, [1] + shape_slope) + self.weights[val_slope.layer_name] = slope_data + + self.omit_nodes.append(val_slope.layer_name) attr = { "param_attr": string(val_slope.layer_name), 'mode': string(mode) From 683dc396513847cceba522fbf126cf7a244d9aeb Mon Sep 17 00:00:00 2001 From: Channingss Date: Tue, 13 Oct 2020 02:21:21 +0000 Subject: [PATCH 2/2] delete redundant code --- x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index b36560c18..ead2370ef 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -550,8 +550,6 @@ def InstanceNormalization(self, node): def Expand(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True) - if len(val_shape.outputs) == 1: - self.omit_nodes.append(val_shape.layer_name) val_x_dtype = val_x.dtype name_ones = node.layer_name + '_ones' attr_ones = {