Skip to content

Commit 4dfbdb0

Browse files
authored
add paddle-trt convert op: greater_equal (#52000)
1 parent 978d544 commit 4dfbdb0

File tree

4 files changed

+192
-2
lines changed

4 files changed

+192
-2
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

+1
Original file line numberDiff line numberDiff line change
@@ -2404,6 +2404,7 @@ USE_TRT_CONVERTER(logical_or);
24042404
USE_TRT_CONVERTER(logical_xor);
24052405
USE_TRT_CONVERTER(logical_and);
24062406
USE_TRT_CONVERTER(less_equal);
2407+
USE_TRT_CONVERTER(greater_equal);
24072408
USE_TRT_CONVERTER(transpose);
24082409
USE_TRT_CONVERTER(transpose2);
24092410
USE_TRT_CONVERTER(flatten);

paddle/fluid/inference/tensorrt/convert/elementwise_op.cc

100755100644
+27
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,26 @@ class ElementwiseTensorOpConverter : public OpConverter {
162162
*(equal_layer->getOutput(0)),
163163
nvinfer1::ElementWiseOperation::kOR);
164164

165+
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
166+
} else if (op_type_ == "greater_equal") {
167+
auto* greater_layer =
168+
TRT_ENGINE_ADD_LAYER(engine_,
169+
ElementWise,
170+
*X,
171+
*reshape_y_tensor,
172+
nvinfer1::ElementWiseOperation::kGREATER);
173+
auto* equal_layer =
174+
TRT_ENGINE_ADD_LAYER(engine_,
175+
ElementWise,
176+
*X,
177+
*reshape_y_tensor,
178+
nvinfer1::ElementWiseOperation::kEQUAL);
179+
auto* layer = TRT_ENGINE_ADD_LAYER(engine_,
180+
ElementWise,
181+
*(greater_layer->getOutput(0)),
182+
*(equal_layer->getOutput(0)),
183+
nvinfer1::ElementWiseOperation::kOR);
184+
165185
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
166186
} else if (op_type_ == "mod") {
167187
auto* div_layer =
@@ -290,6 +310,11 @@ class ElementwiseTensorLessEqualOpConverter
290310
public:
291311
ElementwiseTensorLessEqualOpConverter() { op_type_ = "less_equal"; }
292312
};
313+
// Converter entry for the Paddle `greater_equal` op. The class adds no logic
// of its own: its constructor only sets op_type_ to "greater_equal" on the
// ElementwiseTensorOpConverter base it derives from. The matching
// `op_type_ == "greater_equal"` branch shown earlier in this diff combines a
// kGREATER and a kEQUAL element-wise layer with kOR.
class ElementwiseTensorGreaterEqualOpConverter
314+
: public ElementwiseTensorOpConverter {
315+
public:
316+
ElementwiseTensorGreaterEqualOpConverter() { op_type_ = "greater_equal"; }
317+
};
293318
class ElementwiseTensorModOpConverter : public ElementwiseTensorOpConverter {
294319
public:
295320
ElementwiseTensorModOpConverter() { op_type_ = "mod"; }
@@ -342,3 +367,5 @@ REGISTER_TRT_OP_CONVERTER(logical_or, ElementwiseTensorLogicalOrOpConverter);
342367
REGISTER_TRT_OP_CONVERTER(logical_xor, ElementwiseTensorLogicalXorOpConverter);
343368
REGISTER_TRT_OP_CONVERTER(logical_and, ElementwiseTensorLogicalAndOpConverter);
344369
REGISTER_TRT_OP_CONVERTER(less_equal, ElementwiseTensorLessEqualOpConverter);
370+
REGISTER_TRT_OP_CONVERTER(greater_equal,
371+
ElementwiseTensorGreaterEqualOpConverter);

paddle/fluid/inference/tensorrt/op_teller.cc

+5-2
Original file line numberDiff line numberDiff line change
@@ -1427,7 +1427,8 @@ struct SimpleOpTypeSetTeller : public Teller {
14271427

14281428
if (op_type == "less_than" || op_type == "greater_than" ||
14291429
op_type == "logical_or" || op_type == "logical_xor" ||
1430-
op_type == "logical_and" || op_type == "less_equal") {
1430+
op_type == "logical_and" || op_type == "less_equal" ||
1431+
op_type == "greater_equal") {
14311432
#if IS_TRT_VERSION_GE(8400)
14321433
// TRT does not support kEQUAL/kGREATER/kLESS work with implicit batch
14331434
if (!with_dynamic_shape) {
@@ -1448,7 +1449,7 @@ struct SimpleOpTypeSetTeller : public Teller {
14481449
}
14491450
}
14501451
if (op_type == "less_than" || op_type == "greater_than" ||
1451-
op_type == "less_equal") {
1452+
op_type == "less_equal" || op_type == "greater_equal") {
14521453
if (x_dtype == framework::proto::VarType::BOOL ||
14531454
y_dtype == framework::proto::VarType::BOOL) {
14541455
VLOG(3)
@@ -2767,6 +2768,7 @@ struct SimpleOpTypeSetTeller : public Teller {
27672768
"logical_xor",
27682769
"logical_and",
27692770
"less_equal",
2771+
"greater_equal",
27702772
"dropout",
27712773
"fill_any_like",
27722774
"prelu",
@@ -2923,6 +2925,7 @@ struct SimpleOpTypeSetTeller : public Teller {
29232925
"logical_xor",
29242926
"logical_and",
29252927
"less_equal",
2928+
"greater_equal",
29262929
"dropout",
29272930
"fill_any_like",
29282931
"prelu",

python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_compare_and_logical.py

+159
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,165 @@ def test(self):
481481
self.run_test()
482482

483483

484+
class TrtConvertGreaterEqualTest(TrtLayerAutoScanTest):
    """Auto-scan test for the Paddle-TRT ``greater_equal`` converter.

    Each sampled program is ``cast -> cast -> greater_equal -> cast``: two
    float32 inputs are cast to int32, compared element-wise, and the result
    is cast again before being returned as ``output_data``. Programs are
    generated for 2-D, 3-D and 4-D inputs and run under static-shape and
    dynamic-shape predictor configs in both FP32 and FP16 precision.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        # Every sampled program is accepted; nothing is filtered out.
        return True

    def sample_program_configs(self):
        """Yield one ``ProgramConfig`` per input shape."""

        def generate_input(shape):
            # Integer-valued float32 data. The first two cast ops convert the
            # inputs to int32, so values drawn from [0, 1) would all collapse
            # to 0 and greater_equal would be trivially True everywhere.
            # Drawing whole numbers keeps the cast exact and produces a mix
            # of greater / equal / less pairs.
            return np.random.randint(-10, 10, size=shape).astype(np.float32)

        for shape in [[2, 16], [2, 16, 32], [1, 32, 16, 32]]:
            for op_type in ["greater_equal"]:
                for axis in [-1]:
                    # Remembered for sample_predictor_configs, which picks
                    # the dynamic-shape profile by input rank.
                    self.dims = len(shape)
                    dics = [
                        {"axis": axis},
                        # Input casts; outputs are declared as np.int32 below.
                        {"in_dtype": 5, "out_dtype": 2},
                        # Cast applied to the comparison result.
                        # NOTE(review): presumably bool -> float32; confirm
                        # against the framework's dtype enum.
                        {"in_dtype": 0, "out_dtype": 5},
                    ]
                    ops_config = [
                        {
                            "op_type": "cast",
                            "op_inputs": {"X": ["input_data1"]},
                            "op_outputs": {"Out": ["cast_output_data1"]},
                            "op_attrs": dics[1],
                            "outputs_dtype": {"cast_output_data1": np.int32},
                        },
                        {
                            "op_type": "cast",
                            "op_inputs": {"X": ["input_data2"]},
                            "op_outputs": {"Out": ["cast_output_data2"]},
                            "op_attrs": dics[1],
                            "outputs_dtype": {"cast_output_data2": np.int32},
                        },
                        {
                            "op_type": op_type,
                            "op_inputs": {
                                "X": ["cast_output_data1"],
                                "Y": ["cast_output_data2"],
                            },
                            "op_outputs": {"Out": ["cast_output_data0"]},
                            "op_attrs": dics[0],
                        },
                        {
                            "op_type": "cast",
                            "op_inputs": {"X": ["cast_output_data0"]},
                            "op_outputs": {"Out": ["output_data"]},
                            "op_attrs": dics[2],
                        },
                    ]
                    ops = self.generate_op_config(ops_config)

                    program_config = ProgramConfig(
                        ops=ops,
                        weights={},
                        inputs={
                            "input_data1": TensorConfig(
                                data_gen=partial(generate_input, shape)
                            ),
                            "input_data2": TensorConfig(
                                data_gen=partial(generate_input, shape)
                            ),
                        },
                        outputs=["output_data"],
                    )

                    yield program_config

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        """Yield ``(config, expected node counts, tolerance)`` tuples.

        Four configurations are produced, in order: static shape FP32,
        static shape FP16, dynamic shape FP32, dynamic shape FP16.
        """

        def generate_dynamic_shape(attrs):
            # min / max / opt are identical for each rank, i.e. the profile
            # pins the exact shape that sample_program_configs generated.
            if self.dims == 2:
                self.dynamic_shape.min_input_shape = {
                    "input_data1": [2, 16],
                    "input_data2": [2, 16],
                }
                self.dynamic_shape.max_input_shape = {
                    "input_data1": [2, 16],
                    "input_data2": [2, 16],
                }
                self.dynamic_shape.opt_input_shape = {
                    "input_data1": [2, 16],
                    "input_data2": [2, 16],
                }
            if self.dims == 3:
                self.dynamic_shape.min_input_shape = {
                    "input_data1": [2, 16, 32],
                    "input_data2": [2, 16, 32],
                }
                self.dynamic_shape.max_input_shape = {
                    "input_data1": [2, 16, 32],
                    "input_data2": [2, 16, 32],
                }
                self.dynamic_shape.opt_input_shape = {
                    "input_data1": [2, 16, 32],
                    "input_data2": [2, 16, 32],
                }
            if self.dims == 4:
                self.dynamic_shape.min_input_shape = {
                    "input_data1": [1, 32, 16, 32],
                    "input_data2": [1, 32, 16, 32],
                }
                self.dynamic_shape.max_input_shape = {
                    "input_data1": [1, 32, 16, 32],
                    "input_data2": [1, 32, 16, 32],
                }
                self.dynamic_shape.opt_input_shape = {
                    "input_data1": [1, 32, 16, 32],
                    "input_data2": [1, 32, 16, 32],
                }

        def clear_dynamic_shape():
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            # Expected node counts differ when the compiled TRT version code
            # is below 8400 or the run is static-shape.
            # NOTE(review): the pair is presumably (TRT engine count, Paddle
            # op count) as consumed by the auto-scan harness; confirm there.
            ver = paddle_infer.get_trt_compile_version()
            if (
                ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8400
                or not dynamic_shape
            ):
                return 2, 5
            else:
                return 1, 3

        attrs = [
            program_config.ops[i].attrs for i in range(len(program_config.ops))
        ]

        # for static_shape
        clear_dynamic_shape()
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False
        ), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False
        ), (1e-3, 1e-3)

        # for dynamic_shape
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), (1e-3, 1e-3)

    def add_skip_trt_case(self):
        # No skip cases are registered for this op.
        pass

    def test(self):
        self.add_skip_trt_case()
        self.run_test()
641+
642+
484643
class TrtConvertCompareSkipTest(TrtLayerAutoScanTest):
485644
def is_program_valid(self, program_config: ProgramConfig) -> bool:
486645
return True

0 commit comments

Comments
 (0)