@@ -28,10 +28,17 @@ using string::Style;
28
28
29
29
size_t PDPattern::id_ = 0UL ;
30
30
31
+ #ifdef PADDLE_WITH_TENSORRT
32
+ namespace patterns {
33
+ thread_local std::unordered_map<std::string, size_t > KeyCounter::dic_;
34
+ }
35
+ #endif
36
+
31
37
PDNode *PDPattern::NewNode (const std::string &name) {
32
38
if (!name.empty ()) {
33
39
PADDLE_ENFORCE_EQ (
34
- node_map_.count (name), 0UL ,
40
+ node_map_.count (name),
41
+ 0UL ,
35
42
platform::errors::PreconditionNotMet (
36
43
" PDNode's name should be unique, get duplicate [%s]" , name));
37
44
}
@@ -45,7 +52,8 @@ PDNode *PDPattern::NewNode(const std::string &name) {
45
52
PDNode *PDPattern::NewNode (PDNode::teller_t &&teller, const std::string &name) {
46
53
if (!name.empty ()) {
47
54
PADDLE_ENFORCE_EQ (
48
- node_map_.count (name), 0UL ,
55
+ node_map_.count (name),
56
+ 0UL ,
49
57
platform::errors::PreconditionNotMet (
50
58
" PDNode's name should be unique, get duplicate [%s]" , name));
51
59
}
@@ -70,8 +78,10 @@ void PDPattern::AddEdge(PDNode *a, PDNode *b) {
70
78
a, platform::errors::NotFound (" PDNode %s is not found." , a->name ()));
71
79
PADDLE_ENFORCE_NOT_NULL (
72
80
b, platform::errors::NotFound (" PDNode %s is not found." , b->name ()));
73
- PADDLE_ENFORCE_NE (a, b, platform::errors::PermissionDenied (
74
- " Cannot connect the same node in the graph." ));
81
+ PADDLE_ENFORCE_NE (a,
82
+ b,
83
+ platform::errors::PermissionDenied (
84
+ " Cannot connect the same node in the graph." ));
75
85
edges_.emplace_back (a, b);
76
86
}
77
87
@@ -128,7 +138,8 @@ void GraphPatternDetector::ValidateByNodeRole(
128
138
129
139
subgraphs->erase (
130
140
std::remove_if (
131
- subgraphs->begin (), subgraphs->end (),
141
+ subgraphs->begin (),
142
+ subgraphs->end (),
132
143
[](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
133
144
// Collect the inputs and outputs.
134
145
std::set<Node *> ios;
@@ -310,7 +321,8 @@ void GraphPatternDetector::SortSubgraphs(
310
321
}
311
322
312
323
std::sort (
313
- subgraphs->begin (), subgraphs->end (),
324
+ subgraphs->begin (),
325
+ subgraphs->end (),
314
326
[](const GraphPatternDetector::subgraph_t &a,
315
327
const GraphPatternDetector::subgraph_t &b) {
316
328
for (auto &item : a) {
@@ -438,7 +450,8 @@ PDNode *PDNode::assert_is_persistable_var() {
438
450
}
439
451
440
452
PDNode *PDNode::assert_is_op_nth_input (const std::string &op_type,
441
- const std::string &argument, int nth) {
453
+ const std::string &argument,
454
+ int nth) {
442
455
assert_is_var ();
443
456
assert_is_op_input (op_type);
444
457
asserts_.emplace_back ([=](Node *x) {
@@ -453,7 +466,8 @@ PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
453
466
}
454
467
455
468
PDNode *PDNode::assert_is_op_nth_output (const std::string &op_type,
456
- const std::string &argument, int nth) {
469
+ const std::string &argument,
470
+ int nth) {
457
471
assert_is_var ();
458
472
asserts_.emplace_back ([=](Node *x) {
459
473
for (auto *op : x->inputs ) {
@@ -580,7 +594,8 @@ PDNode *PDNode::assert_is_ops(const std::unordered_set<std::string> &op_types) {
580
594
581
595
PDNode *PDNode::assert_is_ops_nth_input (
582
596
const std::unordered_set<std::string> &op_types,
583
- const std::string &argument, int nth) {
597
+ const std::string &argument,
598
+ int nth) {
584
599
assert_is_var ();
585
600
assert_is_ops_input (op_types);
586
601
asserts_.emplace_back ([=](Node *x) {
@@ -596,7 +611,8 @@ PDNode *PDNode::assert_is_ops_nth_input(
596
611
597
612
PDNode *PDNode::assert_is_ops_nth_output (
598
613
const std::unordered_set<std::string> &op_types,
599
- const std::string &argument, int nth) {
614
+ const std::string &argument,
615
+ int nth) {
600
616
assert_is_var ();
601
617
asserts_.emplace_back ([=](Node *x) {
602
618
for (auto *op : x->inputs ) {
@@ -693,11 +709,13 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
693
709
694
710
bool IsNthInput (Node *var, Node *op, const std::string &argument, size_t nth) {
695
711
PADDLE_ENFORCE_EQ (
696
- var->IsVar (), true ,
712
+ var->IsVar (),
713
+ true ,
697
714
platform::errors::InvalidArgument (
698
715
" First parameter of function IsNthInput must be Node::Var" ));
699
716
PADDLE_ENFORCE_EQ (
700
- op->IsOp (), true ,
717
+ op->IsOp (),
718
+ true ,
701
719
platform::errors::InvalidArgument (
702
720
" Second parameter of function IsNthInput must be Node::Op" ));
703
721
if (!HasInput (op, argument) || op->Op ()->Input (argument).size () <= nth)
@@ -707,7 +725,8 @@ bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
707
725
708
726
bool HasInput (Node *op, const std::string &argument) {
709
727
PADDLE_ENFORCE_EQ (
710
- op->IsOp (), true ,
728
+ op->IsOp (),
729
+ true ,
711
730
platform::errors::InvalidArgument (
712
731
" First parameter of function HasInput must be Node::Op" ));
713
732
auto const &names = op->Op ()->InputNames ();
@@ -718,7 +737,8 @@ bool HasInput(Node *op, const std::string &argument) {
718
737
719
738
bool HasOutput (Node *op, const std::string &argument) {
720
739
PADDLE_ENFORCE_EQ (
721
- op->IsOp (), true ,
740
+ op->IsOp (),
741
+ true ,
722
742
platform::errors::InvalidArgument (
723
743
" First parameter of function HasOuput must be Node::Op" ));
724
744
auto const &names = op->Op ()->OutputNames ();
@@ -729,11 +749,13 @@ bool HasOutput(Node *op, const std::string &argument) {
729
749
730
750
bool IsNthOutput (Node *var, Node *op, const std::string &argument, size_t nth) {
731
751
PADDLE_ENFORCE_EQ (
732
- var->IsVar (), true ,
752
+ var->IsVar (),
753
+ true ,
733
754
platform::errors::InvalidArgument (
734
755
" First parameter of function IsNthOutput must be Node::Var" ));
735
756
PADDLE_ENFORCE_EQ (
736
- op->IsOp (), true ,
757
+ op->IsOp (),
758
+ true ,
737
759
platform::errors::InvalidArgument (
738
760
" Second parameter of function IsNthOutput must be Node::Op" ));
739
761
if (!HasOutput (op, argument) || op->Op ()->Output (argument).size () <= nth)
@@ -875,22 +897,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
875
897
eltwise_op->LinksFrom ({conv_out_var, eltwise_y_in_var})
876
898
.LinksTo ({eltwise_out_var});
877
899
batch_norm_op
878
- ->LinksFrom ({eltwise_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
900
+ ->LinksFrom ({eltwise_out_var,
901
+ bn_scale_var,
902
+ bn_bias_var,
903
+ bn_mean_var,
879
904
bn_variance_var})
880
- .LinksTo ({bn_out_var, bn_mean_out_var, bn_variance_out_var,
881
- bn_saved_mean_var, bn_saved_variance_var});
905
+ .LinksTo ({bn_out_var,
906
+ bn_mean_out_var,
907
+ bn_variance_out_var,
908
+ bn_saved_mean_var,
909
+ bn_saved_variance_var});
882
910
} else {
883
911
batch_norm_op
884
- ->LinksFrom ({conv_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
912
+ ->LinksFrom ({conv_out_var,
913
+ bn_scale_var,
914
+ bn_bias_var,
915
+ bn_mean_var,
885
916
bn_variance_var})
886
- .LinksTo ({bn_out_var, bn_mean_out_var, bn_variance_out_var,
887
- bn_saved_mean_var, bn_saved_variance_var});
917
+ .LinksTo ({bn_out_var,
918
+ bn_mean_out_var,
919
+ bn_variance_out_var,
920
+ bn_saved_mean_var,
921
+ bn_saved_variance_var});
888
922
}
889
923
return bn_out_var;
890
924
}
891
925
892
926
PDNode *patterns::ConvActivation::operator ()(
893
- paddle::framework::ir::PDNode *conv_input, std::string conv_type,
927
+ paddle::framework::ir::PDNode *conv_input,
928
+ std::string conv_type,
894
929
std::string activation_type) {
895
930
// Create Operators
896
931
conv_input->assert_is_op_input (conv_type, " Input" );
@@ -920,7 +955,8 @@ PDNode *patterns::ConvActivation::operator()(
920
955
921
956
PDNode *patterns::ElementwiseActivation::operator ()(
922
957
paddle::framework::ir::PDNode *elementwise_a,
923
- const std::string &elementwise_type, const std::string &activation_type) {
958
+ const std::string &elementwise_type,
959
+ const std::string &activation_type) {
924
960
// Create Operators
925
961
elementwise_a->assert_is_op_input (elementwise_type, " X" );
926
962
auto *elementwise_op =
@@ -995,7 +1031,8 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
995
1031
}
996
1032
997
1033
PDNode *patterns::FC::operator ()(paddle::framework::ir::PDNode *x,
998
- bool with_bias, bool with_relu) {
1034
+ bool with_bias,
1035
+ bool with_relu) {
999
1036
// Create shared nodes.
1000
1037
x->assert_is_op_input (" mul" , " X" );
1001
1038
auto *mul = pattern->NewNode (mul_repr ())->assert_is_op (" mul" );
@@ -1261,8 +1298,12 @@ PDNode *patterns::BatchNormAct::operator()(
1261
1298
1262
1299
bn->LinksFrom (
1263
1300
{bn_x_var, bn_scale_var, bn_bias_var, bn_variance_var, bn_mean_var})
1264
- .LinksTo ({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
1265
- bn_saved_mean_var, bn_reserve_space, bn_out_var});
1301
+ .LinksTo ({bn_mean_out_var,
1302
+ bn_variance_out_var,
1303
+ bn_saved_variance_var,
1304
+ bn_saved_mean_var,
1305
+ bn_reserve_space,
1306
+ bn_out_var});
1266
1307
act->LinksFrom ({bn_out_var}).LinksTo ({act_out_var});
1267
1308
1268
1309
return act_out_var;
@@ -1319,8 +1360,13 @@ PDNode *patterns::BatchNormActGrad::operator()(
1319
1360
.LinksTo ({d_intermediate_var});
1320
1361
1321
1362
bn_grad
1322
- ->LinksFrom ({bn_x_var, d_intermediate_var, bn_scale_var, bn_bias_var,
1323
- bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
1363
+ ->LinksFrom ({bn_x_var,
1364
+ d_intermediate_var,
1365
+ bn_scale_var,
1366
+ bn_bias_var,
1367
+ bn_saved_mean_var,
1368
+ bn_saved_variance_var,
1369
+ bn_reserve_space})
1324
1370
.LinksTo ({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
1325
1371
1326
1372
return bn_grad;
@@ -1404,8 +1450,12 @@ PDNode *patterns::BatchNormAddAct::operator()(
1404
1450
pattern->NewNode (act_out_repr ())->assert_is_ops_output (act_types, " Out" );
1405
1451
1406
1452
bn->LinksFrom ({bn_x_var, bn_scale_var, bn_bias_var})
1407
- .LinksTo ({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
1408
- bn_saved_mean_var, bn_reserve_space, bn_out_var});
1453
+ .LinksTo ({bn_mean_out_var,
1454
+ bn_variance_out_var,
1455
+ bn_saved_variance_var,
1456
+ bn_saved_mean_var,
1457
+ bn_reserve_space,
1458
+ bn_out_var});
1409
1459
elewise_add->LinksFrom ({elewise_add_in_var, bn_out_var})
1410
1460
.LinksTo ({elewise_add_out_var});
1411
1461
act->LinksFrom ({elewise_add_out_var}).LinksTo ({act_out_var});
@@ -1484,8 +1534,13 @@ PDNode *patterns::BatchNormAddActGrad::operator()(
1484
1534
.LinksTo ({d_elewise_add_in_var, d_bn_out_var});
1485
1535
1486
1536
bn_grad
1487
- ->LinksFrom ({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var,
1488
- bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
1537
+ ->LinksFrom ({bn_x_var,
1538
+ d_bn_out_var,
1539
+ bn_scale_var,
1540
+ bn_bias_var,
1541
+ bn_saved_mean_var,
1542
+ bn_saved_variance_var,
1543
+ bn_reserve_space})
1489
1544
.LinksTo ({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
1490
1545
1491
1546
return bn_grad;
@@ -1558,7 +1613,8 @@ PDNode *patterns::ElewiseAddAct::operator()(
1558
1613
1559
1614
PDNode *patterns::LinearAct::operator ()(
1560
1615
paddle::framework::ir::PDNode *linear_x_var,
1561
- const std::unordered_set<std::string> &act_types, bool with_grad_link,
1616
+ const std::unordered_set<std::string> &act_types,
1617
+ bool with_grad_link,
1562
1618
bool is_act_grad_x_from_act) {
1563
1619
auto *matmul_w_var =
1564
1620
pattern->NewNode (matmul_w_repr ())->assert_is_op_input (" matmul_v2" , " Y" );
@@ -1621,7 +1677,8 @@ PDNode *patterns::LinearAct::operator()(
1621
1677
PDNode *patterns::ElewiseAddMatmulAct::operator ()(
1622
1678
paddle::framework::ir::PDNode *dout_var,
1623
1679
const std::unordered_set<std::string> &act_grad_types,
1624
- bool without_x_gradient, bool is_act_grad_x_from_act) {
1680
+ bool without_x_gradient,
1681
+ bool is_act_grad_x_from_act) {
1625
1682
auto *ele_grad_bias_var =
1626
1683
pattern->NewNode (ele_grad_bias_repr ())
1627
1684
->assert_is_op_input (" elementwise_add_grad" , " Y" );
@@ -2052,7 +2109,8 @@ PDNode *patterns::Pool::operator()() {
2052
2109
return output_var;
2053
2110
}
2054
2111
2055
- PDNode *patterns::Elementwise::operator ()(PDNode *x_var, PDNode *y_var,
2112
+ PDNode *patterns::Elementwise::operator ()(PDNode *x_var,
2113
+ PDNode *y_var,
2056
2114
const std::string elementwise_type) {
2057
2115
auto elementwise_op =
2058
2116
pattern->NewNode (elementwise_op_repr ())->assert_is_op (elementwise_type);
@@ -2084,7 +2142,9 @@ PDNode *patterns::ElementwiseOp::operator()(
2084
2142
}
2085
2143
2086
2144
PDNode *patterns::ResidualElementwise::operator ()(
2087
- PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
2145
+ PDNode *op_var,
2146
+ PDNode *residual_var,
2147
+ const std::string elementwise_type,
2088
2148
bool as_x) {
2089
2149
auto elementwise_op =
2090
2150
pattern->NewNode (elementwise_op_repr ())->assert_is_op (elementwise_type);
@@ -3065,7 +3125,8 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
3065
3125
}
3066
3126
3067
3127
PDNode *patterns::ReshapeTransposeMatmulPattern::operator ()(
3068
- const std::string &op_name, bool with_reshape_xshape,
3128
+ const std::string &op_name,
3129
+ bool with_reshape_xshape,
3069
3130
bool with_transpose_xshape) {
3070
3131
auto reshape_op =
3071
3132
pattern->NewNode (reshape_op_repr ())->assert_is_op (" reshape2" );
@@ -3098,11 +3159,10 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
3098
3159
transpose_out->assert_is_only_output_of_op (" transpose2" );
3099
3160
3100
3161
auto transpose_xshape =
3101
- with_transpose_xshape
3102
- ? pattern->NewNode (transpose_xshape_repr ())
3103
- ->AsIntermediate ()
3104
- ->assert_is_op_output (" transpose2" , " XShape" )
3105
- : nullptr ;
3162
+ with_transpose_xshape ? pattern->NewNode (transpose_xshape_repr ())
3163
+ ->AsIntermediate ()
3164
+ ->assert_is_op_output (" transpose2" , " XShape" )
3165
+ : nullptr ;
3106
3166
3107
3167
auto matmul_out = pattern->NewNode (matmul_out_repr ())
3108
3168
->AsOutput ()
0 commit comments