@@ -741,121 +741,70 @@ def strided_slice_converter(network, paddle_op, inputs):
741
741
ends_op = paddle_op .operands ()[2 ].source ().get_defining_op ()
742
742
strides_op = paddle_op .operands ()[3 ].source ().get_defining_op ()
743
743
744
- starts = (
745
- starts_op .attrs ()["value" ]
746
- if starts_op .name () == "pd_op.full_int_array"
747
- else inputs [1 ]
748
- )
749
- ends = (
750
- ends_op .attrs ()["value" ]
751
- if ends_op .name () == "pd_op.full_int_array"
752
- else inputs [2 ]
753
- )
754
- strides = (
755
- strides_op .attrs ()["value" ]
756
- if strides_op .name () == "pd_op.full_int_array"
757
- else inputs [3 ]
758
- )
744
+ if starts_op .name () == "pd_op.full_int_array" :
745
+ starts = starts_op .attrs ()["value" ]
746
+
747
+ if ends_op .name () == "pd_op.full_int_array" :
748
+ ends = ends_op .attrs ()["value" ]
749
+
750
+ if strides_op .name () == "pd_op.full_int_array" :
751
+ strides = strides_op .attrs ()["value" ]
759
752
760
753
input_shape = paddle_op .operands ()[0 ].source ().shape
761
754
nchw_input_dims = len (input_shape )
762
755
763
756
trt_start_dims = [0 ] * nchw_input_dims
764
- trt_size_dims = [input_shape [i ] for i in range (nchw_input_dims )]
757
+ trt_end_dims = [0 ] * nchw_input_dims
758
+ trt_size_dims = [0 ] * nchw_input_dims
765
759
trt_step_dims = [1 ] * nchw_input_dims
766
760
767
761
has_neg_indices = False
768
- trt_start_tensors = []
769
- trt_end_tensors = []
770
- trt_stride_tensors = []
771
-
772
- for i , axis in enumerate (axes ):
773
- if isinstance (starts , trt .ITensor ):
774
- start_tensor = get_shape_tensor_element (network , starts , i )
775
- else :
776
- start_tensor = add_1D_constant_layer (network , [starts [i ]])
777
-
778
- if isinstance (ends , trt .ITensor ):
779
- end_tensor = get_shape_tensor_element (network , ends , i )
780
- else :
781
- end_tensor = add_1D_constant_layer (network , [ends [i ]])
782
-
783
- if isinstance (strides , trt .ITensor ):
784
- stride_tensor = get_shape_tensor_element (network , strides , i )
785
- else :
786
- stride_tensor = add_1D_constant_layer (network , [strides [i ]])
787
762
788
- zero_tensor = add_1D_constant_layer (network , [0 ])
763
+ for i , trt_axis in enumerate (axes ):
764
+ trt_start_dims [trt_axis ] = starts [i ]
765
+ trt_end_dims [trt_axis ] = ends [i ]
766
+ trt_step_dims [trt_axis ] = strides [i ]
767
+ if starts [i ] < 0 or ends [i ] < 0 :
768
+ has_neg_indices = True
789
769
790
- if isinstance (starts , trt .ITensor ) or isinstance (ends , trt .ITensor ):
791
- is_start_neg = trt_less (network , start_tensor , zero_tensor )
792
- is_end_neg = trt_less (network , end_tensor , zero_tensor )
793
- temp_has_neg = network .add_elementwise (
794
- is_start_neg , is_end_neg , trt .ElementWiseOperation .OR
795
- ).get_output (0 )
796
- if not has_neg_indices :
797
- has_neg_indices = temp_has_neg
798
- else :
799
- has_neg_indices = network .add_elementwise (
800
- has_neg_indices , temp_has_neg , trt .ElementWiseOperation .OR
801
- ).get_output (0 )
802
- else :
803
- if starts [i ] < 0 or ends [i ] < 0 :
804
- has_neg_indices = True
805
-
806
- trt_start_tensors .append (start_tensor )
807
- trt_end_tensors .append (end_tensor )
808
- trt_stride_tensors .append (stride_tensor )
809
-
810
- # Concatenate the tensors for start, end, and strides
811
- start_tensor = network .add_concatenation (trt_start_tensors ).get_output (0 )
812
- end_tensor = network .add_concatenation (trt_end_tensors ).get_output (0 )
813
- step_tensor = network .add_concatenation (trt_stride_tensors ).get_output (0 )
814
-
815
- shape_tensor = network .add_shape (input_tensor ).get_output (0 )
816
-
817
- if has_neg_indices is True :
770
+ shape_tensor = trt_shape (network , input_tensor )
771
+ start_tensor = add_1D_constant_layer (network , trt_start_dims )
772
+ if has_neg_indices :
818
773
start_tensor = fix_negative_indices (network , shape_tensor , start_tensor )
819
- elif isinstance (has_neg_indices , trt .ITensor ):
820
- fixed_start_tensor = fix_negative_indices (
821
- network , shape_tensor , start_tensor
822
- )
823
- start_tensor = network .add_select (
824
- condition = has_neg_indices ,
825
- then_input = fixed_start_tensor ,
826
- else_input = start_tensor ,
827
- ).get_output (0 )
828
774
829
- # Process end_tensor similarly to handle negative indices
830
- if has_neg_indices is True :
831
- end_tensor = fix_negative_indices (network , shape_tensor , end_tensor )
832
- elif isinstance (has_neg_indices , trt .ITensor ):
833
- fixed_end_tensor = fix_negative_indices (
834
- network , shape_tensor , end_tensor
775
+ end_vec_tensor = []
776
+ for i in range (len (trt_end_dims )):
777
+ end_vec_tensor .append (
778
+ get_shape_tensor_element (network , shape_tensor , i )
835
779
)
836
- end_tensor = network .add_select (
837
- condition = has_neg_indices ,
838
- then_input = fixed_end_tensor ,
839
- else_input = end_tensor ,
840
- ).get_output (0 )
841
780
842
- # Compute min_tensor
843
- min_tensor = trt_min (network , end_tensor , shape_tensor )
844
- # Correct size_tensor calculation
845
- size_tensor = trt_sub (network , start_tensor , min_tensor )
781
+ for i , trt_axis in enumerate (axes ):
782
+ if ends [i ] >= 0 :
783
+ end_vec_tensor [trt_axis ] = add_1D_constant_layer (network , ends [i ])
784
+ else :
785
+ end_vec_tensor [trt_axis ] = trt_sum (
786
+ network ,
787
+ end_vec_tensor [trt_axis ],
788
+ add_1D_constant_layer (network , ends [i ]),
789
+ )
846
790
847
- # floor_div_tensor computation
848
- floor_div_tensor = trt_floor_div (network , size_tensor , step_tensor )
849
- size_tensor = trt_sub (network , zero_tensor , floor_div_tensor )
791
+ size_tensor = trt_sub (
792
+ network ,
793
+ start_tensor ,
794
+ trt_min (network , trt_concat (network , end_vec_tensor ), shape_tensor ),
795
+ )
796
+ zero_t = add_1D_constant_layer (network , 0 )
797
+ step_tensor = add_1D_constant_layer (network , trt_step_dims )
798
+ size_tensor = trt_sub (
799
+ network , zero_t , trt_floor_div (network , size_tensor , step_tensor )
800
+ )
850
801
851
- # Create the slice layer
852
802
layer = network .add_slice (
853
803
input_tensor , trt_start_dims , trt_size_dims , trt_step_dims
854
804
)
855
805
layer .set_input (1 , start_tensor )
856
806
layer .set_input (2 , size_tensor )
857
807
layer .set_input (3 , step_tensor )
858
-
859
808
return layer .get_output (0 )
860
809
861
810
0 commit comments