Skip to content

Commit d33a367

Browse files
authored
【Paddle TensorRT】Resolved the precision issue of pd_op.slice (#71655) (#71715)
* Resolved the precision issue of pd_op.slice
1 parent 9931609 commit d33a367

File tree

2 files changed

+13
-39
lines changed

2 files changed

+13
-39
lines changed

paddle/fluid/inference/tensorrt/convert/slice_op.cc

+2-7
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,8 @@ class SliceOpConverter : public OpConverter {
9999
if (slice_inputs.find("EndsTensor") != slice_inputs.end() &&
100100
!op_desc.Input("EndsTensor").empty()) { // has EndsTensor input
101101
for (size_t i = 0; i < axes.size(); ++i) {
102-
auto axis = axes[i];
103-
auto input_dim = GetEleTensorOfShape(shape_tensor, axis);
104-
ends_tensor[axes[i]] =
105-
Min(Max(GetEleTensorOfShape(
106-
engine_->GetITensor(op_desc.Input("EndsTensor")[0]), i),
107-
Add1DConstantLayer(0)),
108-
input_dim);
102+
ends_tensor[axes[i]] = GetEleTensorOfShape(
103+
engine_->GetITensor(op_desc.Input("EndsTensor")[0]), i);
109104
}
110105
} else if (slice_inputs.find("EndsTensorList") != slice_inputs.end() &&
111106
!op_desc.Input("EndsTensorList").empty()) {

python/paddle/tensorrt/impls/manipulation.py

+11-32
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,6 @@ def slice_converter(network, paddle_op, inputs):
517517
idx,
518518
name=[paddle_op.name(), f'starts_tensor_{idx}'],
519519
)
520-
start_tensor = trt_concat(network, starts_tensor)
521520

522521
ends = get_input_constant_value(paddle_op, inputs, 2)
523522
if ends is not None:
@@ -569,36 +568,21 @@ def slice_converter(network, paddle_op, inputs):
569568
else:
570569
ends = inputs[2]
571570
for idx in range(len(axes)):
572-
axis = axes[idx]
573-
input_dim = get_shape_tensor_element(
574-
network,
575-
input_shape_tensor,
576-
axis,
577-
name=[paddle_op.name(), f'input_dim_{idx}'],
578-
)
579-
end_element = get_shape_tensor_element(
571+
ends_tensor[axes[idx]] = get_shape_tensor_element(
580572
network,
581573
ends,
582574
idx,
583-
name=[paddle_op.name(), f'end_element_{idx}'],
575+
name=[paddle_op.name(), f'ends_tensor_{idx}'],
584576
)
585577

586-
ends_tensor[axes[idx]] = trt_min(
587-
network,
588-
trt_max(
589-
network,
590-
end_element,
591-
add_1D_constant_layer(
592-
network, 0, name=[paddle_op.name(), 'zero_tensor_{idx}']
593-
),
594-
name=[paddle_op.name(), 'trt_max_{idx}'],
595-
),
596-
input_dim,
597-
name=[paddle_op.name(), 'trt_min_{idx}'],
598-
)
599-
end_tensor = trt_concat(
600-
network, ends_tensor, name=[paddle_op.name(), 'end_tensor']
601-
)
578+
start_tensor_layer = network.add_concatenation(starts_tensor)
579+
start_tensor_layer.axis = 0
580+
set_layer_name(start_tensor_layer, paddle_op)
581+
start_tensor = start_tensor_layer.get_output(0)
582+
end_tensor_layer = network.add_concatenation(ends_tensor)
583+
end_tensor_layer.axis = 0
584+
set_layer_name(end_tensor_layer, paddle_op)
585+
end_tensor = end_tensor_layer.get_output(0)
602586
size_tensor = trt_sub(
603587
network,
604588
end_tensor,
@@ -629,12 +613,7 @@ def slice_converter(network, paddle_op, inputs):
629613
shuffle_layer = network.add_shuffle(output_tensor)
630614
shuffle_layer.reshape_dims = ()
631615
else:
632-
real_size_tensor = trt_gather(
633-
network,
634-
size_tensor,
635-
gather_indices,
636-
name=[paddle_op.name(), 'real_size_tensor'],
637-
)
616+
real_size_tensor = trt_gather(network, size_tensor, gather_indices)
638617
shuffle_layer = network.add_shuffle(output_tensor)
639618
shuffle_layer.set_input(1, real_size_tensor)
640619

0 commit comments

Comments
 (0)