Skip to content

Commit 6fd300b

Browse files
authored
fix (#69671)
1 parent fa26020 commit 6fd300b

File tree

3 files changed

+72
-110
lines changed

3 files changed

+72
-110
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

+10
Original file line numberDiff line numberDiff line change
@@ -1800,6 +1800,16 @@ class StridedSliceOpPattern
18001800
VLOG(3) << "pd_op.strided_slice must has starts,ends and strides input";
18011801
return false;
18021802
}
1803+
if (!pir::GetDefiningOpForInput(op, 1)
1804+
->isa<paddle::dialect::FullIntArrayOp>() ||
1805+
!pir::GetDefiningOpForInput(op, 2)
1806+
->isa<paddle::dialect::FullIntArrayOp>() ||
1807+
!pir::GetDefiningOpForInput(op, 3)
1808+
->isa<paddle::dialect::FullIntArrayOp>()) {
1809+
VLOG(3) << "pd_op.strided_slice's starts/ends/strides input must be "
1810+
"constant value";
1811+
return false;
1812+
}
18031813
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
18041814
return true;
18051815
}

python/paddle/tensorrt/impls/manipulation.py

+42-93
Original file line numberDiff line numberDiff line change
@@ -741,121 +741,70 @@ def strided_slice_converter(network, paddle_op, inputs):
741741
ends_op = paddle_op.operands()[2].source().get_defining_op()
742742
strides_op = paddle_op.operands()[3].source().get_defining_op()
743743

744-
starts = (
745-
starts_op.attrs()["value"]
746-
if starts_op.name() == "pd_op.full_int_array"
747-
else inputs[1]
748-
)
749-
ends = (
750-
ends_op.attrs()["value"]
751-
if ends_op.name() == "pd_op.full_int_array"
752-
else inputs[2]
753-
)
754-
strides = (
755-
strides_op.attrs()["value"]
756-
if strides_op.name() == "pd_op.full_int_array"
757-
else inputs[3]
758-
)
744+
if starts_op.name() == "pd_op.full_int_array":
745+
starts = starts_op.attrs()["value"]
746+
747+
if ends_op.name() == "pd_op.full_int_array":
748+
ends = ends_op.attrs()["value"]
749+
750+
if strides_op.name() == "pd_op.full_int_array":
751+
strides = strides_op.attrs()["value"]
759752

760753
input_shape = paddle_op.operands()[0].source().shape
761754
nchw_input_dims = len(input_shape)
762755

763756
trt_start_dims = [0] * nchw_input_dims
764-
trt_size_dims = [input_shape[i] for i in range(nchw_input_dims)]
757+
trt_end_dims = [0] * nchw_input_dims
758+
trt_size_dims = [0] * nchw_input_dims
765759
trt_step_dims = [1] * nchw_input_dims
766760

767761
has_neg_indices = False
768-
trt_start_tensors = []
769-
trt_end_tensors = []
770-
trt_stride_tensors = []
771-
772-
for i, axis in enumerate(axes):
773-
if isinstance(starts, trt.ITensor):
774-
start_tensor = get_shape_tensor_element(network, starts, i)
775-
else:
776-
start_tensor = add_1D_constant_layer(network, [starts[i]])
777-
778-
if isinstance(ends, trt.ITensor):
779-
end_tensor = get_shape_tensor_element(network, ends, i)
780-
else:
781-
end_tensor = add_1D_constant_layer(network, [ends[i]])
782-
783-
if isinstance(strides, trt.ITensor):
784-
stride_tensor = get_shape_tensor_element(network, strides, i)
785-
else:
786-
stride_tensor = add_1D_constant_layer(network, [strides[i]])
787762

788-
zero_tensor = add_1D_constant_layer(network, [0])
763+
for i, trt_axis in enumerate(axes):
764+
trt_start_dims[trt_axis] = starts[i]
765+
trt_end_dims[trt_axis] = ends[i]
766+
trt_step_dims[trt_axis] = strides[i]
767+
if starts[i] < 0 or ends[i] < 0:
768+
has_neg_indices = True
789769

790-
if isinstance(starts, trt.ITensor) or isinstance(ends, trt.ITensor):
791-
is_start_neg = trt_less(network, start_tensor, zero_tensor)
792-
is_end_neg = trt_less(network, end_tensor, zero_tensor)
793-
temp_has_neg = network.add_elementwise(
794-
is_start_neg, is_end_neg, trt.ElementWiseOperation.OR
795-
).get_output(0)
796-
if not has_neg_indices:
797-
has_neg_indices = temp_has_neg
798-
else:
799-
has_neg_indices = network.add_elementwise(
800-
has_neg_indices, temp_has_neg, trt.ElementWiseOperation.OR
801-
).get_output(0)
802-
else:
803-
if starts[i] < 0 or ends[i] < 0:
804-
has_neg_indices = True
805-
806-
trt_start_tensors.append(start_tensor)
807-
trt_end_tensors.append(end_tensor)
808-
trt_stride_tensors.append(stride_tensor)
809-
810-
# Concatenate the tensors for start, end, and strides
811-
start_tensor = network.add_concatenation(trt_start_tensors).get_output(0)
812-
end_tensor = network.add_concatenation(trt_end_tensors).get_output(0)
813-
step_tensor = network.add_concatenation(trt_stride_tensors).get_output(0)
814-
815-
shape_tensor = network.add_shape(input_tensor).get_output(0)
816-
817-
if has_neg_indices is True:
770+
shape_tensor = trt_shape(network, input_tensor)
771+
start_tensor = add_1D_constant_layer(network, trt_start_dims)
772+
if has_neg_indices:
818773
start_tensor = fix_negative_indices(network, shape_tensor, start_tensor)
819-
elif isinstance(has_neg_indices, trt.ITensor):
820-
fixed_start_tensor = fix_negative_indices(
821-
network, shape_tensor, start_tensor
822-
)
823-
start_tensor = network.add_select(
824-
condition=has_neg_indices,
825-
then_input=fixed_start_tensor,
826-
else_input=start_tensor,
827-
).get_output(0)
828774

829-
# Process end_tensor similarly to handle negative indices
830-
if has_neg_indices is True:
831-
end_tensor = fix_negative_indices(network, shape_tensor, end_tensor)
832-
elif isinstance(has_neg_indices, trt.ITensor):
833-
fixed_end_tensor = fix_negative_indices(
834-
network, shape_tensor, end_tensor
775+
end_vec_tensor = []
776+
for i in range(len(trt_end_dims)):
777+
end_vec_tensor.append(
778+
get_shape_tensor_element(network, shape_tensor, i)
835779
)
836-
end_tensor = network.add_select(
837-
condition=has_neg_indices,
838-
then_input=fixed_end_tensor,
839-
else_input=end_tensor,
840-
).get_output(0)
841780

842-
# Compute min_tensor
843-
min_tensor = trt_min(network, end_tensor, shape_tensor)
844-
# Correct size_tensor calculation
845-
size_tensor = trt_sub(network, start_tensor, min_tensor)
781+
for i, trt_axis in enumerate(axes):
782+
if ends[i] >= 0:
783+
end_vec_tensor[trt_axis] = add_1D_constant_layer(network, ends[i])
784+
else:
785+
end_vec_tensor[trt_axis] = trt_sum(
786+
network,
787+
end_vec_tensor[trt_axis],
788+
add_1D_constant_layer(network, ends[i]),
789+
)
846790

847-
# floor_div_tensor computation
848-
floor_div_tensor = trt_floor_div(network, size_tensor, step_tensor)
849-
size_tensor = trt_sub(network, zero_tensor, floor_div_tensor)
791+
size_tensor = trt_sub(
792+
network,
793+
start_tensor,
794+
trt_min(network, trt_concat(network, end_vec_tensor), shape_tensor),
795+
)
796+
zero_t = add_1D_constant_layer(network, 0)
797+
step_tensor = add_1D_constant_layer(network, trt_step_dims)
798+
size_tensor = trt_sub(
799+
network, zero_t, trt_floor_div(network, size_tensor, step_tensor)
800+
)
850801

851-
# Create the slice layer
852802
layer = network.add_slice(
853803
input_tensor, trt_start_dims, trt_size_dims, trt_step_dims
854804
)
855805
layer.set_input(1, start_tensor)
856806
layer.set_input(2, size_tensor)
857807
layer.set_input(3, step_tensor)
858-
859808
return layer.get_output(0)
860809

861810

test/tensorrt/test_converter_manipulation.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ class TestStrideSliceCase2TRTPattern(TensorRTBaseTest):
448448
def setUp(self):
449449
self.python_api = paddle.strided_slice
450450
self.api_args = {
451-
"x": np.random.random([3, 4, 10]).astype("int32"),
451+
"x": np.random.random([3, 4, 10]).astype("int64"),
452452
"axes": [0, 1, 2],
453453
"starts": [1, 0, 2],
454454
"ends": [2, 3, 4],
@@ -484,15 +484,15 @@ class TestStrideSliceCase4TRTPattern(TensorRTBaseTest):
484484
def setUp(self):
485485
self.python_api = paddle.strided_slice
486486
self.api_args = {
487-
"x": np.random.random([5, 5, 5]).astype("float32"),
488-
"axes": [0, 1, 2],
489-
"starts": np.array([1, 0, 0]).astype("int32"),
490-
"ends": np.array([2, 1, 3]).astype("int32"),
491-
"strides": np.array([1, 1, 1]).astype("int32"),
487+
"x": np.random.random([1, 56, 56, 128]).astype("float32"),
488+
"axes": [1, 2],
489+
"starts": [0, 0],
490+
"ends": [6, 6],
491+
"strides": [2, 2],
492492
}
493-
self.program_config = {"feed_list": ["x", "starts", "ends", "strides"]}
494-
self.min_shape = {"x": [1, 5, 5]}
495-
self.max_shape = {"x": [6, 5, 5]}
493+
self.program_config = {"feed_list": ["x"]}
494+
self.min_shape = {"x": [1, 56, 56, 128]}
495+
self.max_shape = {"x": [1, 56, 56, 128]}
496496

497497
def test_trt_result(self):
498498
self.check_trt_result()
@@ -502,15 +502,18 @@ class TestStrideSliceCase5TRTPattern(TensorRTBaseTest):
502502
def setUp(self):
503503
self.python_api = paddle.strided_slice
504504
self.api_args = {
505-
"x": np.random.random([3, 4, 10]).astype("float32"),
506-
"axes": [0, 1, 2],
507-
"starts": np.array([0, -1, 0]).astype("int32"),
508-
"ends": np.array([2, -3, 5]).astype("int32"),
509-
"strides": np.array([1, -1, 1]).astype("int32"),
505+
"x": np.random.random([1, 56, 56, 128]).astype("float32"),
506+
"axes": [1, 2],
507+
"starts": [
508+
1,
509+
1,
510+
],
511+
"ends": [10000, 10000],
512+
"strides": [2, 2],
510513
}
511-
self.program_config = {"feed_list": ["x", "starts", "ends", "strides"]}
512-
self.min_shape = {"x": [1, 4, 10]}
513-
self.max_shape = {"x": [5, 4, 10]}
514+
self.program_config = {"feed_list": ["x"]}
515+
self.min_shape = {"x": [1, 56, 56, 128]}
516+
self.max_shape = {"x": [1, 56, 56, 128]}
514517

515518
def test_trt_result(self):
516519
self.check_trt_result()

0 commit comments

Comments
 (0)