Skip to content

Commit f6f8749

Browse files
authored
Fix bugs in the Clip and PReLU op mappers (#680)
* fix bug of clip and prelu * update * update code * add get_input_index
1 parent 57786e3 commit f6f8749

File tree

2 files changed

+53
-14
lines changed

2 files changed

+53
-14
lines changed

x2paddle/decoder/onnx_decoder.py

file mode changed: 100644 → 100755
+12-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ def __init__(self, layer, layer_name=None):
4848
self.dtype = None
4949
self.which_child = {}
5050

51+
def get_input_index(self, input_name):
    """Return the position of ``input_name`` in ``self.layer.input``.

    Args:
        input_name (str): name of the tensor to look up among this
            node's ONNX inputs.

    Returns:
        int: zero-based index of the first matching input, or -1 when
        ``input_name`` is not an input of this layer.
    """
    # enumerate + early return replaces the original
    # range(len(...)) loop with a sentinel variable and break.
    for idx, name in enumerate(self.layer.input):
        if name == input_name:
            return idx
    return -1
62+
5163
def get_attr_map(self):
5264
"""
5365
convert ONNX node attributes to dict
@@ -294,7 +306,6 @@ def build(self):
294306
for layer_name, node in self.node_map.items():
295307
if isinstance(node, ONNXGraphNode):
296308
self.build_connection(layer_name, node)
297-
298309
#generate topo
299310
super(ONNXGraph, self).build()
300311

x2paddle/op_mapper/onnx2paddle/opset9/opset.py

+41-13
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,17 @@ def ConstantOfShape(self, node):
11501150
outputs=[node.name],
11511151
**layer_attrs)
11521152

1153+
@print_mapping_info
def GatherND(self, node):
    """Map the ONNX GatherND op to ``paddle.gather_nd``.

    Reads the data tensor (input 0) and the index tensor (input 1)
    and emits a single ``paddle.gather_nd`` layer writing to
    ``node.name``.
    """
    # NOTE(review): removed leftover debug statement
    # ``print(len(node.inputs), node.inputs)`` from the original.
    val_x = self.graph.get_input_node(node, idx=0, copy=True)
    val_index = self.graph.get_input_node(node, idx=1, copy=True)
    self.paddle_graph.add_layer(
        "paddle.gather_nd",
        inputs={"x": val_x.name,
                "index": val_index.name},
        outputs=[node.name])
1163+
11531164
@print_mapping_info
11541165
def Clip(self, node):
11551166
val_x = self.graph.get_input_node(node, idx=0, copy=True)
@@ -1169,23 +1180,40 @@ def Clip(self, node):
11691180
outputs=[node.name],
11701181
**layer_attrs)
11711182
else:
1172-
min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
1173-
max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
1174-
min_value = _const_weight_or_none(min_ipt)
1175-
max_value = _const_weight_or_none(max_ipt)
1176-
if max_value.shape == (1, ):
1177-
max_value = max_value[0]
1178-
if min_value.shape == (1, ):
1179-
min_value = min_value[0]
1180-
if max_value is not None and min_value is not None:
1181-
layer_attrs = {'max': max_value, 'min': min_value}
1183+
if len(node.inputs) == 2:
1184+
val_ipt = self.graph.get_input_node(node, idx=1, copy=True)
1185+
1186+
index = node.get_input_index(val_ipt.name)
1187+
1188+
val_value = _const_weight_or_none(val_ipt)
1189+
if val_value.shape == (1, ):
1190+
val_value = val_value[0]
1191+
1192+
if index == 1:
1193+
layer_attrs = {'min': val_value}
1194+
1195+
if index == 2:
1196+
layer_attrs = {'max': val_value}
1197+
11821198
self.paddle_graph.add_layer(
11831199
'paddle.clip',
11841200
inputs={"x": val_x.name},
11851201
outputs=[node.name],
11861202
**layer_attrs)
11871203
else:
1188-
raise Exception("max_value or min_value can't be None")
1204+
if len(node.inputs) == 3:
1205+
min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
1206+
max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
1207+
self.paddle_graph.add_layer(
1208+
'paddle.clip',
1209+
inputs={
1210+
"x": val_x.name,
1211+
"min": min_ipt.name,
1212+
"max": max_ipt.name
1213+
},
1214+
outputs=[node.name])
1215+
else:
1216+
raise Exception("max_value or min_value can't be None")
11891217

11901218
@print_mapping_info
11911219
def ReduceSum(self, node):
@@ -1681,9 +1709,9 @@ def PRelu(self, node):
16811709
num_parameters = val_x.out_shapes[0][1]
16821710
else:
16831711
num_parameters = 1
1712+
slope_data = self.weights[val_slope.name]
16841713
_rename_or_remove_weight(self.weights, val_slope.name)
1685-
self.weights[op_name + '._weight'] = np.reshape(
1686-
self.weights[val_slope.name], [1])
1714+
self.weights[op_name + '._weight'] = np.reshape(slope_data, [1])
16871715
self.paddle_graph.add_layer(
16881716
"paddle.nn.PReLU",
16891717
inputs={"x": val_x.name},

0 commit comments

Comments
 (0)