Skip to content

Commit 3aa8e57

Browse files
committed
fix bug of shape infer
1 parent 09d3558 commit 3aa8e57

File tree

2 files changed

+61
-26
lines changed

2 files changed

+61
-26
lines changed

x2paddle/decoder/onnx_shape_inference.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,7 +1419,7 @@ def _infer_impl(self, in_mp, start_sympy_data={}):
14191419
if self.verbose_ > 2:
14201420
print(node.op_type + ': ' + node.name)
14211421
for i, name in enumerate(node.input):
1422-
print(' Input {}: {} {}€5€5€5€5€5'.format(
1422+
print(' Input {}: {} {}'.format(
14231423
i, name, 'initializer'
14241424
if name in self.initializers_ else ''))
14251425

@@ -1544,7 +1544,7 @@ def _infer_impl(self, in_mp, start_sympy_data={}):
15441544
continue # continue the inference after guess, no need to stop as no merge is needed
15451545

15461546
if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
1547-
print('Stopping at incomplete shape inference at ' +
1547+
print('Stopping at incomplete symbolic shape inference at ' +
15481548
node.op_type + ': ' + node.name)
15491549
print('node inputs:')
15501550
for i in node.input:
@@ -1579,6 +1579,7 @@ def infer_shapes(in_mp,
15791579
all_shapes_inferred = False
15801580
symbolic_shape_inference._preprocess(
15811581
in_mp, input_shapes=fixed_input_shape)
1582+
15821583
try:
15831584
while symbolic_shape_inference.run_:
15841585
all_shapes_inferred = symbolic_shape_inference._infer_impl(
@@ -1588,9 +1589,8 @@ def infer_shapes(in_mp,
15881589
print('!' * 10)
15891590
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
15901591
symbolic_shape_inference.out_mp_)
1591-
#onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
15921592
except:
1593-
print('Stopping at incomplete shape inference')
1593+
print('Stopping at incomplete symbolic shape inference')
15941594
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
1595-
symbolic_shape_inference.out_mp_)
1595+
in_mp)
15961596
return symbolic_shape_inference.out_mp_.graph

x2paddle/op_mapper/onnx2paddle/opset9/opset.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def _is_static_shape(shape):
5757
return False
5858
return True
5959

60-
6160
def _get_same_padding(in_size, kernel_size, stride):
6261
new_size = int(math.ceil(in_size * 1.0 / stride))
6362
pad_size = (new_size - 1) * stride + kernel_size - in_size
@@ -104,14 +103,6 @@ class OpSet9():
104103

105104
default_op_mapping = {
106105
'Shape': ['shape', ['X'], ['Out']],
107-
'Clip': [
108-
'clip', ['X'], ['Out'], dict(), dict(
109-
min=(np.asarray(
110-
[255, 255, 127, 255], dtype=np.uint8).view(np.float32)[0]),
111-
max=(np.asarray(
112-
[255, 255, 127, 127], dtype=np.uint8).view(np.float32)[0]),
113-
)
114-
],
115106
'Erf': ['erf', ['X'], ['Out']],
116107
'Ceil': ['ceil', ['X'], ['Out']],
117108
'ReduceMean': [
@@ -831,27 +822,31 @@ def Slice(self, node):
831822
if len(node.inputs) > 1:
832823
starts = self.graph.get_input_node(node, idx=1, copy=True)
833824
ends = self.graph.get_input_node(node, idx=2, copy=True)
825+
starts_value = _const_weight_or_none(starts)
826+
ends_value = _const_weight_or_none(ends)
827+
834828
if len(node.inputs) > 3:
835829
axes = self.graph.get_input_node(node, idx=3, copy=True)
836830
axes = _const_weight_or_none(axes, necessary=True)
837831
if len(node.inputs) > 4:
838832
steps = self.graph.get_input_node(node, idx=4, copy=True)
839833
steps = _const_weight_or_none(steps)
840-
if steps is not None:
841-
assert steps == 1, "Only support convert op:Slice, which attribute:steps == 1"
842834
attr = {
843835
"axes": axes,
844836
"starts": starts.layer_name,
845837
"ends": ends.layer_name
846838
}
847-
starts_value = _const_weight_or_none(starts)
848-
ends_value = _const_weight_or_none(ends)
849839
if starts_value is not None and ends_value is not None:
850840
self.omit_nodes.append(starts.layer_name)
851841
self.omit_nodes.append(ends.layer_name)
842+
starts_value = starts_value.copy()
852843
ends_value = ends_value.copy()
853844
for idx in range(len(ends_value)):
854-
if ends_value[idx] > 2**31 - 1:
845+
if starts_value[idx] > val_x.out_shapes[0][axes[idx]]:
846+
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1
847+
ends_value[idx] = val_x.out_shapes[0][axes[idx]]
848+
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1
849+
elif ends_value[idx] > 2**31 - 1:
855850
ends_value[idx] = 2**31 - 1
856851
attr = {
857852
"axes": axes,
@@ -869,12 +864,12 @@ def Slice(self, node):
869864
attr['starts'] = starts_cast
870865
if ends.dtype != 'int32':
871866
ends_cast = ends.layer_name + '_cast'
872-
node.fluid_code.add_layer(
873-
'cast',
874-
inputs=ends,
875-
output=ends_cast,
876-
param_attr={'dtype': string('int32')})
877-
attr['ends'] = ends_cast
867+
node.fluid_code.add_layer(
868+
'cast',
869+
inputs=ends,
870+
output=ends_cast,
871+
param_attr={'dtype': string('int32')})
872+
attr['ends'] = ends_cast
878873
else:
879874
starts = node.get_attr('starts')
880875
ends = node.get_attr('ends')
@@ -884,7 +879,12 @@ def Slice(self, node):
884879
ends[idx] = 2**31 - 1
885880
attr = {"axes": axes, "starts": starts, "ends": ends}
886881

887-
node.fluid_code.add_layer(
882+
if steps is not None:
883+
attr['strides'] = steps
884+
node.fluid_code.add_layer(
885+
'strided_slice', inputs=val_x, output=node, param_attr=attr)
886+
else:
887+
node.fluid_code.add_layer(
888888
'slice', inputs=val_x, output=node, param_attr=attr)
889889

890890
@print_mapping_info
@@ -907,6 +907,41 @@ def ConstantOfShape(self, node):
907907
node.fluid_code.add_layer(
908908
'fill_constant', inputs=None, output=node, param_attr=attr)
909909

910+
@print_mapping_info
911+
def Clip(self, node):
912+
val_x = self.graph.get_input_node(node, idx=0, copy=True)
913+
val_y = self.graph.get_node(node.layer.output[0], copy=True)
914+
max_value, min_value = None, None
915+
if len(node.inputs) == 1:
916+
max_value = node.get_attr('max')
917+
min_value = node.get_attr('min')
918+
attr = {
919+
'max': max_value,
920+
'min': min_value,
921+
}
922+
node.fluid_code.add_layer(
923+
'clip', inputs=val_x, output=node, param_attr=attr)
924+
else:
925+
max_ipt = self.graph.get_input_node(node, idx=1, copy=True)
926+
min_ipt = self.graph.get_input_node(node, idx=2, copy=True)
927+
max_value = _const_weight_or_none(max_ipt)
928+
min_value = _const_weight_or_none(min_ipt)
929+
self.omit_nodes.append(max_ipt.layer_name)
930+
self.omit_nodes.append(min_ipt.layer_name)
931+
if max_value.shape == (1,):
932+
max_value = max_value[0]
933+
if min_value.shape == (1,):
934+
min_value = min_value[0]
935+
if max_value is not None and min_value is not None:
936+
attr = {
937+
'max': max_value,
938+
'min': min_value
939+
}
940+
node.fluid_code.add_layer(
941+
'clip', inputs=val_x, output=node, param_attr=attr)
942+
else:
943+
raise
944+
910945
@print_mapping_info
911946
def Split(self, node):
912947
val_x = self.graph.get_input_node(node, idx=0, copy=True)

0 commit comments

Comments
 (0)