
Commit 38f6a5a

[Inference]Improve pir-trt performance Part-2 (#71712)
* fix performance
* add meshgrid
* fix ci
* fix ci
* add coverage
* del code
* fix coverage
1 parent 65f2c27 commit 38f6a5a

11 files changed: +239 −35 lines

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc (+28 −10)

```diff
@@ -126,6 +126,7 @@ DEFINE_GENERAL_PATTERN(Asin, paddle::dialect::AsinOp)
 DEFINE_GENERAL_PATTERN(Acos, paddle::dialect::AcosOp)
 DEFINE_GENERAL_PATTERN(Atan, paddle::dialect::AtanOp)
 DEFINE_GENERAL_PATTERN(ShuffleChannel, paddle::dialect::ShuffleChannelOp)
+DEFINE_GENERAL_PATTERN(Meshgrid, paddle::dialect::MeshgridOp)
 
 #undef DEFINE_GENERAL_PATTERN
 
@@ -927,10 +928,16 @@ class UnsqueezeOpPattern
         dynamic_dims.push_back(i);
       }
     }
-    if (dynamic_dims.size() > 1) {
-      VLOG(3) << "Currently we don't support unsqueeze with more than one "
-                 "dynamic dims";
-      return false;
+    if (dynamic_dims.size() == 0) {
+      std::vector<int64_t> axes;
+      for (auto &axis_ele : axis.AsVector()) {
+        axes.push_back(axis_ele.dyn_cast<pir::Int64Attribute>().data());
+      }
+      if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
+        VLOG(3) << "Invalid unsqueeze axes. Axes containing the batch axis "
+                   "are not supported in static shape";
+        return false;
+      }
     }
 
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
@@ -967,10 +974,16 @@ class Unsqueeze_OpPattern
         dynamic_dims.push_back(i);
       }
     }
-    if (dynamic_dims.size() > 1) {
-      VLOG(3) << "Currently we don't support unsqueeze with more than one "
-                 "dynamic dims";
-      return false;
+    if (dynamic_dims.size() == 0) {
+      std::vector<int64_t> axes;
+      for (auto &axis_ele : axis.AsVector()) {
+        axes.push_back(axis_ele.dyn_cast<pir::Int64Attribute>().data());
+      }
+      if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
+        VLOG(3) << "Invalid unsqueeze axes. Axes containing the batch axis "
+                   "are not supported in static shape";
+        return false;
+      }
     }
 
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
```
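The two unsqueeze hunks invert the old rule: instead of rejecting inputs with more than one dynamic dimension, the marker now rejects only fully static inputs whose unsqueeze axes include the batch axis (0). A minimal Python sketch of the new check, with a hypothetical helper name (`can_mark_unsqueeze_for_trt` is not part of the codebase):

```python
# Illustrative sketch of the static-shape check in UnsqueezeOpPattern /
# Unsqueeze_OpPattern above; negative entries stand for dynamic dims.
def can_mark_unsqueeze_for_trt(input_shape, axes):
    dynamic_dims = [i for i, d in enumerate(input_shape) if d < 0]
    if len(dynamic_dims) == 0:
        # fully static shape: unsqueezing at the batch axis is rejected
        if 0 in axes:
            return False
    return True

assert not can_mark_unsqueeze_for_trt([5, 10], [0])  # static shape, axis 0
assert can_mark_unsqueeze_for_trt([5, 10], [1])      # static shape, no batch axis
assert can_mark_unsqueeze_for_trt([-1, 10], [0])     # dynamic batch dim: allowed
```

The remaining hunks in this file fix StackOpPattern's rank computation for vector inputs and register the new Meshgrid pattern with the pass: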
```diff
@@ -1953,8 +1966,12 @@ class StackOpPattern : public pir::OpRewritePattern<paddle::dialect::StackOp> {
     pir::Value x = op.operand_source(0);
     int rank = 1;
     auto x_type = x.type();
-    if (x_type.isa<pir::VectorType>()) {
-      rank = x_type.dyn_cast<pir::VectorType>().size();
+    if (x_type.isa<pir::VectorType>() &&
+        x_type.dyn_cast<pir::VectorType>().size() > 0) {
+      auto vec_type = x_type.dyn_cast<pir::VectorType>();
+      auto tensor_element =
+          vec_type.data()[0].dyn_cast<paddle::dialect::DenseTensorType>();
+      rank = tensor_element.dims().size();
     } else {
       auto x_shape = pir::GetShapeFromValue(x);
       rank = x_shape.size();
@@ -3004,6 +3021,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ADD_PATTERN(Acos)
     ADD_PATTERN(Atan)
     ADD_PATTERN(ShuffleChannel)
+    ADD_PATTERN(Meshgrid)
 #if IS_TRT_VERSION_GE(8600)
     ADD_PATTERN(Layer_norm)
 #endif
```
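The StackOpPattern fix changes what counts as the input rank for a vector (list-of-tensors) operand: the old code used the number of tensors in the list, the new code uses the rank of the first tensor element. A hedged NumPy sketch of the corrected logic (`stack_input_rank` is a hypothetical name):

```python
import numpy as np

# Illustrative only: mirrors the corrected rank computation in StackOpPattern.
def stack_input_rank(x):
    if isinstance(x, (list, tuple)) and len(x) > 0:
        return x[0].ndim  # rank of one tensor element, not len(x)
    return x.ndim

tensors = [np.zeros((2, 3, 4)), np.zeros((2, 3, 4))]
assert stack_input_rank(tensors) == 3  # the old logic would have yielded 2
```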

python/paddle/tensorrt/converter.py (+34)

```diff
@@ -586,8 +586,42 @@ def convert_program_to_trt(self):
                 if op.results()[0].use_empty():
                     self.program.global_block().remove_op(op)
             if op.name() == "builtin.constant":
+                # builtin.constant can't be saved/loaded, so we need to delete it
                 if op.results()[0].use_empty():
                     self.program.global_block().remove_op(op)
+                else:
+                    constant_result = op.results()[0]
+                    constant_value_name = op.attrs()["value"]
+                    out_dtype = np.dtype(
+                        paddle.pir.core._PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[
+                            constant_result.dtype
+                        ]
+                    )
+                    tensor_data = self.scope.var(
+                        constant_value_name
+                    ).get_tensor()
+                    constant_array = np.array(
+                        tensor_data, dtype=out_dtype
+                    ).tolist()
+
+                    # convert builtin.constant to pd_op.full_int_array/full, then delete it
+                    with paddle.pir.core.program_guard(self.program):
+                        paddle.base.libpaddle.pir.reset_insertion_point_to_start()
+                        if len(constant_array) == 1:
+                            full_value = paddle._C_ops.full(
+                                [1],
+                                constant_array[0],
+                                constant_result.dtype,
+                                paddle.CUDAPlace(0),
+                            )
+                        else:
+                            full_value = paddle._C_ops.full_int_array(
+                                constant_array,
+                                constant_result.dtype,
+                                paddle.CUDAPlace(0),
+                            )
+                    op.replace_all_uses_with([full_value])
+                    self.program.global_block().remove_op(op)
 
         # Call clear_shape_info to clear the previous shape information
         clear_shape_info()
```
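The scalar/vector split exists because pd_op.full materializes a single scalar while pd_op.full_int_array materializes a 1-D integer array. A hedged sketch of the dispatch, assuming the constant payload has already been read out of the scope as a Python list (`pick_full_op` is a hypothetical helper):

```python
# Illustrative only: mirrors the len(constant_array) == 1 branch above.
def pick_full_op(constant_array):
    if len(constant_array) == 1:
        return "pd_op.full", constant_array[0]      # scalar constant
    return "pd_op.full_int_array", constant_array   # vector constant

print(pick_full_op([7]))        # ('pd_op.full', 7)
print(pick_full_op([1, 2, 3]))  # ('pd_op.full_int_array', [1, 2, 3])
```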

python/paddle/tensorrt/impls/creation.py (+76)

```diff
@@ -371,3 +371,79 @@ def full_with_tensor_converter(network, paddle_op, inputs):
     set_layer_name(fill_layer, paddle_op)
     output_tensor = fill_layer.get_output(0)
     return output_tensor
+
+
+@converter_registry.register("pd_op.meshgrid", trt_version="8.x")
+def meshgrid_converter(network, paddle_op, vec_inputs):
+    inputs = vec_inputs[0]
+    n = len(inputs)
+    outputs = []
+
+    # get all input dims (every input is 1-D)
+    input_dims = [network.add_shape(inp).get_output(0) for inp in inputs]
+
+    for k in range(n):
+        # --------------------------------
+        # step 1: reshape the k-th input to [1, ..., Dk, ..., 1]
+        # --------------------------------
+        x = inputs[k]
+        reshape_dims = []  # init every dim as 1
+        for i in range(n):
+            one = add_1D_constant_layer(
+                network,
+                1,
+                dtype=np.int32,
+                is_scalar=False,
+                name=[paddle_op.name(), f'one_{k}'],
+            )
+            reshape_dims.append(one)
+        # replace the k-th dim with Dk
+        reshape_dims[k] = input_dims[k]
+
+        dim_concat = network.add_concatenation(reshape_dims)
+        set_layer_name(dim_concat, paddle_op)
+        x_reshaped = network.add_shuffle(x)
+        x_reshaped.set_input(1, dim_concat.get_output(0))
+
+        # --------------------------------
+        # step 2: build a tensor of shape [D1, ..., 1, ..., Dn] filled with 1
+        # --------------------------------
+        ones_shape = []
+        for i in range(n):
+            ones_shape.append(input_dims[i])
+        ones_shape[k] = add_1D_constant_layer(
+            network,
+            1,
+            dtype=np.int32,
+            is_scalar=False,
+            name=[paddle_op.name(), f'ones_shape_{k}'],
+        )
+        dim_concat = network.add_concatenation(ones_shape)
+        set_layer_name(dim_concat, paddle_op)
+
+        # fill with constant 1
+        fill_layer = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)
+        fill_layer.set_input(0, dim_concat.get_output(0))
+        value_input = add_1D_constant_layer(
+            network,
+            1,
+            dtype=np.float32,
+            is_scalar=True,
+            name=[paddle_op.name(), 'one_for_fill'],
+        )
+        fill_layer.set_input(1, value_input)
+        beta_vec = [0] * n
+        fill_layer.set_input(
+            2, add_1D_constant_layer(network, beta_vec, np.float32)
+        )
+
+        # --------------------------------
+        # step 3: element-wise multiplication
+        # --------------------------------
+        grid = network.add_elementwise(
+            x_reshaped.get_output(0),
+            fill_layer.get_output(0),
+            trt.ElementWiseOperation.PROD,
+        ).get_output(0)
+        outputs.append(grid)
+    return outputs
```
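The converter follows the classic broadcast construction of meshgrid: reshape the k-th 1-D input to [1, ..., Dk, ..., 1], build a ones tensor of shape [D1, ..., 1, ..., Dn], and multiply. A NumPy reference sketch of the same three steps (`meshgrid_numpy` is hypothetical, for illustration only):

```python
import numpy as np

def meshgrid_numpy(*inputs):
    # mirrors steps 1-3 of meshgrid_converter above, in plain NumPy
    n = len(inputs)
    dims = [len(v) for v in inputs]
    outputs = []
    for k, x in enumerate(inputs):
        reshape_dims = [1] * n
        reshape_dims[k] = dims[k]  # step 1: [1, ..., Dk, ..., 1]
        ones_shape = list(dims)
        ones_shape[k] = 1          # step 2: [D1, ..., 1, ..., Dn]
        # step 3: broadcasting multiply expands x to the full grid
        outputs.append(x.reshape(reshape_dims) * np.ones(ones_shape, x.dtype))
    return outputs

a, b = np.arange(3.0), np.arange(4.0)
for got, ref in zip(meshgrid_numpy(a, b), np.meshgrid(a, b, indexing='ij')):
    assert np.array_equal(got, ref)
```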

python/paddle/tensorrt/util.py (+12 −19)

```diff
@@ -52,19 +52,10 @@ def map_dtype(pd_dtype):
         raise TypeError(f"Unsupported dtype: {pd_dtype}")
 
 
-def all_ops_into_trt(program):
+def support_constant_folding_pass(program):
     for op in program.global_block().ops:
-        if (
-            op.name() == "pd_op.fetch"
-            or op.name() == "pd_op.data"
-            or op.name().split('.')[0] == "builtin"
-        ):
-            continue
-        if op.has_attr("__l_trt__") is False:
+        if op.name() == "pd_op.while" or op.name() == "pd_op.if":
             return False
-        if op.attrs()["__l_trt__"] is False:
-            return False
-    _logger.info("All ops convert to trt.")
     return True
 
@@ -107,7 +98,7 @@ def _add_pass_(pm, passes, disable_passes):
     # run other passes
     pm.clear()
     passes = []
-    if all_ops_into_trt(program):
+    if support_constant_folding_pass(program):
         # only run constant_folding_pass when all ops into trt
         passes.append(
             {
@@ -117,18 +108,19 @@ def _add_pass_(pm, passes, disable_passes):
                 }
             }
         )
-
+        passes.append(
+            {
+                'dead_code_elimination_pass': {
+                    "__place__": place,
+                    "__param_scope__": scope,
+                }
+            }
+        )
     passes.append({'conv2d_add_fuse_pass': {}})
     passes.append({'trt_op_marker_pass': {}})  # for ops created by passes
     _add_pass_(pm, passes, disable_passes)
     pm.run(program)
 
-    # delete unused op
-    for op in program.global_block().ops:
-        if op.name() == "builtin.constant" or op.name() == "builtin.parameter":
-            if op.results()[0].use_empty():
-                program.global_block().remove_op(op)
-
     return program
 
@@ -282,6 +274,7 @@ def weight_to_tensor(network, paddle_value, trt_tensor, use_op_name):
     # the following ops needn't cast trt.Weight to ITensor, because the layer needs the weight as input
     forbid_cast_op = [
        "pd_op.depthwise_conv2d",
+       "pd_op.conv2d",
        "pd_op.conv2d_transpose",
        "pd_op.conv3d",
        "pd_op.conv3d_transpose",
```

test/tensorrt/tensorrt_test_base.py (+9)

```diff
@@ -46,6 +46,7 @@ def __init__(self, methodName='runTest'):
         self.dynamic_shape_data = {}
         self.disable_passes = [
             "constant_folding_pass",
+            "dead_code_elimination_pass",
         ]
 
     def create_fake_program(self):
@@ -267,6 +268,14 @@ def check_trt_result(self, rtol=1e-5, atol=1e-5, precision_mode="fp32"):
             main_program,
             disable_passes=self.disable_passes,
         )
+        # delete unused op
+        for op in main_program.global_block().ops:
+            if (
+                op.name() == "builtin.constant"
+                or op.name() == "builtin.parameter"
+            ):
+                if op.results()[0].use_empty():
+                    main_program.global_block().remove_op(op)
 
         scope = paddle.static.global_scope()
         main_program = warmup_shape_infer(
```

test/tensorrt/test_converter_conv.py (+16 −4)

```diff
@@ -41,7 +41,11 @@ def setUp(self):
         self.min_shape = {"x": [1, 3, 8, 8]}
         self.opt_shape = {"x": [2, 3, 8, 8]}
         self.max_shape = {"x": [10, 3, 8, 8]}
-        self.disable_passes = ['constant_folding_pass', 'conv2d_add_fuse_pass']
+        self.disable_passes = [
+            'constant_folding_pass',
+            'conv2d_add_fuse_pass',
+            'dead_code_elimination_pass',
+        ]
 
     def test_trt_result_fp16(self):
         self.check_trt_result(precision_mode="fp16")
@@ -62,7 +66,11 @@ def setUp(self):
         self.min_shape = {"x": [1, 3, 8, 8]}
         self.opt_shape = {"x": [2, 3, 8, 8]}
         self.max_shape = {"x": [10, 3, 8, 8]}
-        self.disable_passes = ['constant_folding_pass', 'conv2d_add_fuse_pass']
+        self.disable_passes = [
+            'constant_folding_pass',
+            'conv2d_add_fuse_pass',
+            'dead_code_elimination_pass',
+        ]
 
     def test_trt_result(self):
         self.check_trt_result()
@@ -81,7 +89,11 @@ def setUp(self):
         self.min_shape = {"x": [1, 3, 8, 8]}
         self.opt_shape = {"x": [2, 3, 8, 8]}
         self.max_shape = {"x": [10, 3, 8, 8]}
-        self.disable_passes = ['constant_folding_pass', 'conv2d_add_fuse_pass']
+        self.disable_passes = [
+            'constant_folding_pass',
+            'conv2d_add_fuse_pass',
+            'dead_code_elimination_pass',
+        ]
 
     def test_trt_result(self):
         self.check_trt_result()
@@ -489,7 +501,7 @@ def setUp(self):
         self.min_shape = {"x": [1, 3, 8, 8]}
         self.opt_shape = {"x": [2, 3, 8, 8]}
         self.max_shape = {"x": [10, 3, 8, 8]}
-        self.disable_passes = []
+        self.disable_passes = ['dead_code_elimination_pass']
 
     def test_trt_result_fp16(self):
         self.check_trt_result(precision_mode="fp16")
```

test/tensorrt/test_converter_creation.py (+18)

```diff
@@ -251,5 +251,23 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+class TestMeshgridTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.meshgrid
+        self.api_args = {
+            "x": [
+                np.random.random([20]).astype("float32"),
+                np.random.random([30]).astype("float32"),
+            ],
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [[10], [20]]}
+        self.opt_shape = {"x": [[20], [30]]}
+        self.max_shape = {"x": [[30], [40]]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
 if __name__ == "__main__":
     unittest.main()
```

test/tensorrt/test_converter_manipulation.py (+36)

```diff
@@ -1101,5 +1101,41 @@ def test_fp16_result(self):
         self.check_trt_result(precision_mode="fp16")
 
 
+class TestUnsqueezeTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.unsqueeze
+        self.api_args = {
+            "x": np.random.random([5, 10]).astype("float32"),
+            "axis": 0,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {}
+        self.opt_shape = {}
+        self.max_shape = {}
+
+    def test_trt_result(self):
+        self.check_marker(expected_result=False)
+
+
+def unsqueeze_inplace_wrapper(x, axis):
+    return _C_ops.unsqueeze_(x, axis)
+
+
+class TestUnsqueeze_TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = unsqueeze_inplace_wrapper
+        self.api_args = {
+            "x": np.random.random([5, 10]).astype("float32"),
+            "axis": 0,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {}
+        self.opt_shape = {}
+        self.max_shape = {}
+
+    def test_trt_result(self):
+        self.check_marker(expected_result=False)
+
+
 if __name__ == '__main__':
     unittest.main()
```
