Skip to content

Commit 1b9e246

Browse files
authored
【TRT_Converter】add where TRT_Converter (#68876)
* TRT_Converter * where * add test_converter_search * add int32 test & LT * merge * update where * update where converter * updata where del non_zero * updata where del non_zero * update where int64
1 parent a5cf82a commit 1b9e246

File tree

4 files changed

+71
-0
lines changed

4 files changed

+71
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,31 @@ class TanhOpPattern : public pir::OpRewritePattern<paddle::dialect::TanhOp> {
14751475
}
14761476
};
14771477

1478+
class WherePattern : public pir::OpRewritePattern<paddle::dialect::WhereOp> {
1479+
public:
1480+
using pir::OpRewritePattern<paddle::dialect::WhereOp>::OpRewritePattern;
1481+
bool MatchAndRewrite(paddle::dialect::WhereOp op,
1482+
pir::PatternRewriter &rewriter) const override {
1483+
if (op->HasAttribute(kCanRunTrtAttr) &&
1484+
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
1485+
return false;
1486+
}
1487+
pir::Value x = op.operand_source(1);
1488+
pir::Value y = op.operand_source(2);
1489+
if (x == nullptr || y == nullptr) {
1490+
VLOG(3) << "pd_op.where x or y tensor value is null";
1491+
return false;
1492+
}
1493+
#if IS_TRT_VERSION_LT(8400)
1494+
VLOG(3) << "where is not supported when TensorRT < 8.4";
1495+
return false;
1496+
#endif
1497+
1498+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1499+
return true;
1500+
}
1501+
};
1502+
14781503
class FullWithTensorPattern
14791504
: public pir::OpRewritePattern<paddle::dialect::FullWithTensorOp> {
14801505
public:
@@ -1666,6 +1691,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
16661691
ps.Add(std::make_unique<NearestInterV2Pattern>(context));
16671692
ps.Add(std::make_unique<StackOpPattern>(context));
16681693
ps.Add(std::make_unique<TanhOpPattern>(context));
1694+
ps.Add(std::make_unique<WherePattern>(context));
16691695
ps.Add(std::make_unique<FullWithTensorPattern>(context));
16701696
ps.Add(std::make_unique<StridedSliceOpPattern>(context));
16711697
ps.Add(std::make_unique<TopkOpPattern>(context));

python/paddle/tensorrt/impls/search.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ def argmax_converter(network, paddle_op, inputs):
6666
return squeeze_layer.get_output(0)
6767

6868

69+
@converter_registry.register("pd_op.where", trt_version="8.x")
70+
def where_converter(network, paddle_op, inputs):
71+
condition = inputs[0]
72+
x = inputs[1]
73+
y = inputs[2]
74+
75+
select_layer = network.add_select(condition, x, y)
76+
77+
return select_layer.get_output(0)
78+
79+
6980
@converter_registry.register("pd_op.topk", trt_version="8.x")
7081
def topk_converter(network, paddle_op, inputs):
7182
input_tensor = inputs[0]

test/tensorrt/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,7 @@ if(NOT WIN32 AND TENSORRT_FOUND)
2323
set_tests_properties(test_converter_creation PROPERTIES TIMEOUT "300")
2424
set_tests_properties(test_converter_attribute PROPERTIES TIMEOUT "300")
2525
set_tests_properties(test_converter_common PROPERTIES TIMEOUT "300")
26+
set_tests_properties(test_converter_search PROPERTIES TIMEOUT "300")
2627
set_tests_properties(test_converter_logic PROPERTIES TIMEOUT "300")
28+
2729
endif()

test/tensorrt/test_converter_search.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,38 @@ def test_trt_result(self):
3535
self.check_trt_result()
3636

3737

38+
class TestWhereTRTPatternCase1(TensorRTBaseTest):
39+
def setUp(self):
40+
self.python_api = paddle.where
41+
self.api_args = {
42+
"condition": np.random.choice([True, False], size=(2, 3)),
43+
"x": np.random.randn(2, 3).astype("float32"),
44+
"y": np.random.randn(2, 3).astype("float32"),
45+
}
46+
self.program_config = {"feed_list": ["condition", "x", "y"]}
47+
self.min_shape = {"condition": [1, 3], "x": [1, 3], "y": [1, 3]}
48+
self.max_shape = {"condition": [5, 3], "x": [5, 3], "y": [5, 3]}
49+
50+
def test_trt_result(self):
51+
self.check_trt_result()
52+
53+
54+
class TestWhereTRTPatternCase2(TensorRTBaseTest):
55+
def setUp(self):
56+
self.python_api = paddle.where
57+
self.api_args = {
58+
"condition": np.random.choice([True, False], size=(2, 3)),
59+
"x": np.random.randn(2, 3).astype("int64"),
60+
"y": np.random.randn(2, 3).astype("int64"),
61+
}
62+
self.program_config = {"feed_list": ["condition", "x", "y"]}
63+
self.min_shape = {"condition": [1, 3], "x": [1, 3], "y": [1, 3]}
64+
self.max_shape = {"condition": [5, 3], "x": [5, 3], "y": [5, 3]}
65+
66+
def test_trt_result(self):
67+
self.check_trt_result()
68+
69+
3870
class TestTopkCase1TRTPattern(TensorRTBaseTest):
3971
def setUp(self):
4072
self.python_api = paddle.topk

0 commit comments

Comments
 (0)