@@ -57,7 +57,6 @@ def _is_static_shape(shape):
57
57
return False
58
58
return True
59
59
60
-
61
60
def _get_same_padding (in_size , kernel_size , stride ):
62
61
new_size = int (math .ceil (in_size * 1.0 / stride ))
63
62
pad_size = (new_size - 1 ) * stride + kernel_size - in_size
@@ -104,14 +103,6 @@ class OpSet9():
104
103
105
104
default_op_mapping = {
106
105
'Shape' : ['shape' , ['X' ], ['Out' ]],
107
- 'Clip' : [
108
- 'clip' , ['X' ], ['Out' ], dict (), dict (
109
- min = (np .asarray (
110
- [255 , 255 , 127 , 255 ], dtype = np .uint8 ).view (np .float32 )[0 ]),
111
- max = (np .asarray (
112
- [255 , 255 , 127 , 127 ], dtype = np .uint8 ).view (np .float32 )[0 ]),
113
- )
114
- ],
115
106
'Erf' : ['erf' , ['X' ], ['Out' ]],
116
107
'Ceil' : ['ceil' , ['X' ], ['Out' ]],
117
108
'ReduceMean' : [
@@ -831,27 +822,31 @@ def Slice(self, node):
831
822
if len (node .inputs ) > 1 :
832
823
starts = self .graph .get_input_node (node , idx = 1 , copy = True )
833
824
ends = self .graph .get_input_node (node , idx = 2 , copy = True )
825
+ starts_value = _const_weight_or_none (starts )
826
+ ends_value = _const_weight_or_none (ends )
827
+
834
828
if len (node .inputs ) > 3 :
835
829
axes = self .graph .get_input_node (node , idx = 3 , copy = True )
836
830
axes = _const_weight_or_none (axes , necessary = True )
837
831
if len (node .inputs ) > 4 :
838
832
steps = self .graph .get_input_node (node , idx = 4 , copy = True )
839
833
steps = _const_weight_or_none (steps )
840
- if steps is not None :
841
- assert steps == 1 , "Only support convert op:Slice, which attribute:steps == 1"
842
834
attr = {
843
835
"axes" : axes ,
844
836
"starts" : starts .layer_name ,
845
837
"ends" : ends .layer_name
846
838
}
847
- starts_value = _const_weight_or_none (starts )
848
- ends_value = _const_weight_or_none (ends )
849
839
if starts_value is not None and ends_value is not None :
850
840
self .omit_nodes .append (starts .layer_name )
851
841
self .omit_nodes .append (ends .layer_name )
842
+ starts_value = starts_value .copy ()
852
843
ends_value = ends_value .copy ()
853
844
for idx in range (len (ends_value )):
854
- if ends_value [idx ] > 2 ** 31 - 1 :
845
+ if starts_value [idx ] > val_x .out_shapes [0 ][axes [idx ]]:
846
+ starts_value [idx ] = val_x .out_shapes [0 ][axes [idx ]]- 1
847
+ ends_value [idx ] = val_x .out_shapes [0 ][axes [idx ]]
848
+ starts_value [idx ] = val_x .out_shapes [0 ][axes [idx ]]- 1
849
+ elif ends_value [idx ] > 2 ** 31 - 1 :
855
850
ends_value [idx ] = 2 ** 31 - 1
856
851
attr = {
857
852
"axes" : axes ,
@@ -869,12 +864,12 @@ def Slice(self, node):
869
864
attr ['starts' ] = starts_cast
870
865
if ends .dtype != 'int32' :
871
866
ends_cast = ends .layer_name + '_cast'
872
- node .fluid_code .add_layer (
873
- 'cast' ,
874
- inputs = ends ,
875
- output = ends_cast ,
876
- param_attr = {'dtype' : string ('int32' )})
877
- attr ['ends' ] = ends_cast
867
+ node .fluid_code .add_layer (
868
+ 'cast' ,
869
+ inputs = ends ,
870
+ output = ends_cast ,
871
+ param_attr = {'dtype' : string ('int32' )})
872
+ attr ['ends' ] = ends_cast
878
873
else :
879
874
starts = node .get_attr ('starts' )
880
875
ends = node .get_attr ('ends' )
@@ -884,7 +879,12 @@ def Slice(self, node):
884
879
ends [idx ] = 2 ** 31 - 1
885
880
attr = {"axes" : axes , "starts" : starts , "ends" : ends }
886
881
887
- node .fluid_code .add_layer (
882
+ if steps is not None :
883
+ attr ['strides' ] = steps
884
+ node .fluid_code .add_layer (
885
+ 'strided_slice' , inputs = val_x , output = node , param_attr = attr )
886
+ else :
887
+ node .fluid_code .add_layer (
888
888
'slice' , inputs = val_x , output = node , param_attr = attr )
889
889
890
890
@print_mapping_info
@@ -907,6 +907,41 @@ def ConstantOfShape(self, node):
907
907
node .fluid_code .add_layer (
908
908
'fill_constant' , inputs = None , output = node , param_attr = attr )
909
909
910
+ @print_mapping_info
911
+ def Clip (self , node ):
912
+ val_x = self .graph .get_input_node (node , idx = 0 , copy = True )
913
+ val_y = self .graph .get_node (node .layer .output [0 ], copy = True )
914
+ max_value , min_value = None , None
915
+ if len (node .inputs ) == 1 :
916
+ max_value = node .get_attr ('max' )
917
+ min_value = node .get_attr ('min' )
918
+ attr = {
919
+ 'max' : max_value ,
920
+ 'min' : min_value ,
921
+ }
922
+ node .fluid_code .add_layer (
923
+ 'clip' , inputs = val_x , output = node , param_attr = attr )
924
+ else :
925
+ max_ipt = self .graph .get_input_node (node , idx = 1 , copy = True )
926
+ min_ipt = self .graph .get_input_node (node , idx = 2 , copy = True )
927
+ max_value = _const_weight_or_none (max_ipt )
928
+ min_value = _const_weight_or_none (min_ipt )
929
+ self .omit_nodes .append (max_ipt .layer_name )
930
+ self .omit_nodes .append (min_ipt .layer_name )
931
+ if max_value .shape == (1 ,):
932
+ max_value = max_value [0 ]
933
+ if min_value .shape == (1 ,):
934
+ min_value = min_value [0 ]
935
+ if max_value is not None and min_value is not None :
936
+ attr = {
937
+ 'max' : max_value ,
938
+ 'min' : min_value
939
+ }
940
+ node .fluid_code .add_layer (
941
+ 'clip' , inputs = val_x , output = node , param_attr = attr )
942
+ else :
943
+ raise
944
+
910
945
@print_mapping_info
911
946
def Split (self , node ):
912
947
val_x = self .graph .get_input_node (node , idx = 0 , copy = True )
0 commit comments