
Commit 70ee2f0

modify graph_pattern to thread_local
1 parent 26187c2 commit 70ee2f0

File tree

3 files changed: +268 -129 lines changed


paddle/fluid/framework/ir/graph_pattern_detector.cc

+103 -43
@@ -28,10 +28,17 @@ using string::Style;
 
 size_t PDPattern::id_ = 0UL;
 
+#ifdef PADDLE_WITH_TENSORRT
+namespace patterns {
+thread_local std::unordered_map<std::string, size_t> KeyCounter::dic_;
+}
+#endif
+
 PDNode *PDPattern::NewNode(const std::string &name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(
-        node_map_.count(name), 0UL,
+        node_map_.count(name),
+        0UL,
         platform::errors::PreconditionNotMet(
             "PDNode's name should be unique, get duplicate [%s]", name));
   }
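
The functional change in this commit is the hunk above: KeyCounter::dic_ becomes thread_local, so each thread running the pattern detector keeps its own pattern-key counter map instead of sharing (and potentially racing on) one global map. Below is a minimal sketch of the idiom, assuming a simplified KeyCounter with a hypothetical IncCounter helper; only the dic_ member and its out-of-line definition mirror the diff.

// Sketch: thread_local static data member (simplified, not Paddle's full class).
#include <string>
#include <unordered_map>

class KeyCounter {
 public:
  // Hypothetical helper: each thread increments its own counter for `key`,
  // so no synchronization is needed across threads.
  static size_t IncCounter(const std::string &key) { return dic_[key]++; }

 private:
  static thread_local std::unordered_map<std::string, size_t> dic_;
};

// Out-of-line definition in the .cc file; `thread_local` must be repeated
// here, exactly as the added lines in the hunk above do.
thread_local std::unordered_map<std::string, size_t> KeyCounter::dic_;

The trade-off is that counters are now per-thread rather than global; presumably that is acceptable here because each pass invocation generates its node names within a single thread.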
@@ -45,7 +52,8 @@ PDNode *PDPattern::NewNode(const std::string &name) {
 PDNode *PDPattern::NewNode(PDNode::teller_t &&teller, const std::string &name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(
-        node_map_.count(name), 0UL,
+        node_map_.count(name),
+        0UL,
         platform::errors::PreconditionNotMet(
             "PDNode's name should be unique, get duplicate [%s]", name));
   }
@@ -70,8 +78,10 @@ void PDPattern::AddEdge(PDNode *a, PDNode *b) {
       a, platform::errors::NotFound("PDNode %s is not found.", a->name()));
   PADDLE_ENFORCE_NOT_NULL(
       b, platform::errors::NotFound("PDNode %s is not found.", b->name()));
-  PADDLE_ENFORCE_NE(a, b, platform::errors::PermissionDenied(
-                              "Cannot connect the same node in the graph."));
+  PADDLE_ENFORCE_NE(a,
+                    b,
+                    platform::errors::PermissionDenied(
+                        "Cannot connect the same node in the graph."));
   edges_.emplace_back(a, b);
 }
 
@@ -128,7 +138,8 @@ void GraphPatternDetector::ValidateByNodeRole(
 
   subgraphs->erase(
       std::remove_if(
-          subgraphs->begin(), subgraphs->end(),
+          subgraphs->begin(),
+          subgraphs->end(),
           [](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
             // Collect the inputs and outputs.
             std::set<Node *> ios;
@@ -310,7 +321,8 @@ void GraphPatternDetector::SortSubgraphs(
   }
 
   std::sort(
-      subgraphs->begin(), subgraphs->end(),
+      subgraphs->begin(),
+      subgraphs->end(),
       [](const GraphPatternDetector::subgraph_t &a,
          const GraphPatternDetector::subgraph_t &b) {
         for (auto &item : a) {
@@ -438,7 +450,8 @@ PDNode *PDNode::assert_is_persistable_var() {
 }
 
 PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
-                                       const std::string &argument, int nth) {
+                                       const std::string &argument,
+                                       int nth) {
   assert_is_var();
   assert_is_op_input(op_type);
   asserts_.emplace_back([=](Node *x) {
@@ -453,7 +466,8 @@ PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
 }
 
 PDNode *PDNode::assert_is_op_nth_output(const std::string &op_type,
-                                        const std::string &argument, int nth) {
+                                        const std::string &argument,
+                                        int nth) {
   assert_is_var();
   asserts_.emplace_back([=](Node *x) {
     for (auto *op : x->inputs) {
@@ -580,7 +594,8 @@ PDNode *PDNode::assert_is_ops(const std::unordered_set<std::string> &op_types) {
 
 PDNode *PDNode::assert_is_ops_nth_input(
     const std::unordered_set<std::string> &op_types,
-    const std::string &argument, int nth) {
+    const std::string &argument,
+    int nth) {
   assert_is_var();
   assert_is_ops_input(op_types);
   asserts_.emplace_back([=](Node *x) {
@@ -596,7 +611,8 @@ PDNode *PDNode::assert_is_ops_nth_input(
 
 PDNode *PDNode::assert_is_ops_nth_output(
     const std::unordered_set<std::string> &op_types,
-    const std::string &argument, int nth) {
+    const std::string &argument,
+    int nth) {
   assert_is_var();
   asserts_.emplace_back([=](Node *x) {
     for (auto *op : x->inputs) {
@@ -693,11 +709,13 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
 
 bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE_EQ(
-      var->IsVar(), true,
+      var->IsVar(),
+      true,
       platform::errors::InvalidArgument(
           "First parameter of function IsNthInput must be Node::Var"));
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
           "Second parameter of function IsNthInput must be Node::Op"));
   if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
@@ -707,7 +725,8 @@ bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
 
 bool HasInput(Node *op, const std::string &argument) {
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
           "First parameter of function HasInput must be Node::Op"));
   auto const &names = op->Op()->InputNames();
@@ -718,7 +737,8 @@ bool HasInput(Node *op, const std::string &argument) {
 
 bool HasOutput(Node *op, const std::string &argument) {
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
           "First parameter of function HasOuput must be Node::Op"));
   auto const &names = op->Op()->OutputNames();
@@ -729,11 +749,13 @@ bool HasOutput(Node *op, const std::string &argument) {
 
 bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE_EQ(
-      var->IsVar(), true,
+      var->IsVar(),
+      true,
       platform::errors::InvalidArgument(
           "First parameter of function IsNthOutput must be Node::Var"));
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
           "Second parameter of function IsNthOutput must be Node::Op"));
   if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)
@@ -875,22 +897,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
     eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var})
         .LinksTo({eltwise_out_var});
     batch_norm_op
-        ->LinksFrom({eltwise_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
+        ->LinksFrom({eltwise_out_var,
+                     bn_scale_var,
+                     bn_bias_var,
+                     bn_mean_var,
                      bn_variance_var})
-        .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
-                  bn_saved_mean_var, bn_saved_variance_var});
+        .LinksTo({bn_out_var,
+                  bn_mean_out_var,
+                  bn_variance_out_var,
+                  bn_saved_mean_var,
+                  bn_saved_variance_var});
   } else {
     batch_norm_op
-        ->LinksFrom({conv_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
+        ->LinksFrom({conv_out_var,
+                     bn_scale_var,
+                     bn_bias_var,
+                     bn_mean_var,
                      bn_variance_var})
-        .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
-                  bn_saved_mean_var, bn_saved_variance_var});
+        .LinksTo({bn_out_var,
+                  bn_mean_out_var,
+                  bn_variance_out_var,
+                  bn_saved_mean_var,
+                  bn_saved_variance_var});
   }
   return bn_out_var;
 }
 
 PDNode *patterns::ConvActivation::operator()(
-    paddle::framework::ir::PDNode *conv_input, std::string conv_type,
+    paddle::framework::ir::PDNode *conv_input,
+    std::string conv_type,
     std::string activation_type) {
   // Create Operators
   conv_input->assert_is_op_input(conv_type, "Input");
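
The ConvBN and ConvActivation hunks above are pure re-wrapping: the same LinksFrom/LinksTo chains, now with one argument per line. For readers new to the pattern DSL, the chaining works because each call returns the node it was invoked on. A toy illustration under simplified types follows; ToyNode is hypothetical and stands in for PDNode only to show the fluent style.

#include <initializer_list>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for PDNode: LinksFrom/LinksTo return a reference to
// the node they modified, which is what lets the calls chain in one statement.
struct ToyNode {
  std::string name;
  std::vector<ToyNode *> inputs;
  std::vector<ToyNode *> outputs;

  ToyNode &LinksFrom(std::initializer_list<ToyNode *> srcs) {
    inputs.insert(inputs.end(), srcs);
    return *this;
  }
  ToyNode &LinksTo(std::initializer_list<ToyNode *> dsts) {
    outputs.insert(outputs.end(), dsts);
    return *this;
  }
};

int main() {
  ToyNode conv_out{"conv_out"}, bn_scale{"bn_scale"}, bn{"batch_norm"},
      bn_out{"bn_out"};
  // Same shape as the ConvBN wiring above, minus the real asserts.
  bn.LinksFrom({&conv_out, &bn_scale}).LinksTo({&bn_out});
  std::cout << bn.inputs.size() << " inputs, " << bn.outputs.size()
            << " output\n";  // prints: 2 inputs, 1 output
}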
@@ -920,7 +955,8 @@ PDNode *patterns::ConvActivation::operator()(
 
 PDNode *patterns::ElementwiseActivation::operator()(
     paddle::framework::ir::PDNode *elementwise_a,
-    const std::string &elementwise_type, const std::string &activation_type) {
+    const std::string &elementwise_type,
+    const std::string &activation_type) {
   // Create Operators
   elementwise_a->assert_is_op_input(elementwise_type, "X");
   auto *elementwise_op =
@@ -995,7 +1031,8 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
 }
 
 PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
-                                 bool with_bias, bool with_relu) {
+                                 bool with_bias,
+                                 bool with_relu) {
   // Create shared nodes.
   x->assert_is_op_input("mul", "X");
   auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
@@ -1261,8 +1298,12 @@ PDNode *patterns::BatchNormAct::operator()(
 
   bn->LinksFrom(
         {bn_x_var, bn_scale_var, bn_bias_var, bn_variance_var, bn_mean_var})
-      .LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
-                bn_saved_mean_var, bn_reserve_space, bn_out_var});
+      .LinksTo({bn_mean_out_var,
+                bn_variance_out_var,
+                bn_saved_variance_var,
+                bn_saved_mean_var,
+                bn_reserve_space,
+                bn_out_var});
   act->LinksFrom({bn_out_var}).LinksTo({act_out_var});
 
   return act_out_var;
@@ -1319,8 +1360,13 @@ PDNode *patterns::BatchNormActGrad::operator()(
       .LinksTo({d_intermediate_var});
 
   bn_grad
-      ->LinksFrom({bn_x_var, d_intermediate_var, bn_scale_var, bn_bias_var,
-                   bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
+      ->LinksFrom({bn_x_var,
+                   d_intermediate_var,
+                   bn_scale_var,
+                   bn_bias_var,
+                   bn_saved_mean_var,
+                   bn_saved_variance_var,
+                   bn_reserve_space})
       .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
 
   return bn_grad;
@@ -1404,8 +1450,12 @@ PDNode *patterns::BatchNormAddAct::operator()(
       pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
 
   bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var})
-      .LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
-                bn_saved_mean_var, bn_reserve_space, bn_out_var});
+      .LinksTo({bn_mean_out_var,
+                bn_variance_out_var,
+                bn_saved_variance_var,
+                bn_saved_mean_var,
+                bn_reserve_space,
+                bn_out_var});
   elewise_add->LinksFrom({elewise_add_in_var, bn_out_var})
       .LinksTo({elewise_add_out_var});
   act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var});
@@ -1484,8 +1534,13 @@ PDNode *patterns::BatchNormAddActGrad::operator()(
       .LinksTo({d_elewise_add_in_var, d_bn_out_var});
 
   bn_grad
-      ->LinksFrom({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var,
-                   bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
+      ->LinksFrom({bn_x_var,
+                   d_bn_out_var,
+                   bn_scale_var,
+                   bn_bias_var,
+                   bn_saved_mean_var,
+                   bn_saved_variance_var,
+                   bn_reserve_space})
       .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
 
   return bn_grad;
@@ -1558,7 +1613,8 @@ PDNode *patterns::ElewiseAddAct::operator()(
 
 PDNode *patterns::LinearAct::operator()(
     paddle::framework::ir::PDNode *linear_x_var,
-    const std::unordered_set<std::string> &act_types, bool with_grad_link,
+    const std::unordered_set<std::string> &act_types,
+    bool with_grad_link,
     bool is_act_grad_x_from_act) {
   auto *matmul_w_var =
       pattern->NewNode(matmul_w_repr())->assert_is_op_input("matmul_v2", "Y");
@@ -1621,7 +1677,8 @@ PDNode *patterns::LinearAct::operator()(
 PDNode *patterns::ElewiseAddMatmulAct::operator()(
     paddle::framework::ir::PDNode *dout_var,
     const std::unordered_set<std::string> &act_grad_types,
-    bool without_x_gradient, bool is_act_grad_x_from_act) {
+    bool without_x_gradient,
+    bool is_act_grad_x_from_act) {
   auto *ele_grad_bias_var =
       pattern->NewNode(ele_grad_bias_repr())
           ->assert_is_op_input("elementwise_add_grad", "Y");
@@ -2052,7 +2109,8 @@ PDNode *patterns::Pool::operator()() {
   return output_var;
 }
 
-PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
+PDNode *patterns::Elementwise::operator()(PDNode *x_var,
+                                          PDNode *y_var,
                                           const std::string elementwise_type) {
   auto elementwise_op =
       pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
@@ -2084,7 +2142,9 @@ PDNode *patterns::ElementwiseOp::operator()(
 }
 
 PDNode *patterns::ResidualElementwise::operator()(
-    PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
+    PDNode *op_var,
+    PDNode *residual_var,
+    const std::string elementwise_type,
     bool as_x) {
   auto elementwise_op =
       pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
@@ -3065,7 +3125,8 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
 }
 
 PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
-    const std::string &op_name, bool with_reshape_xshape,
+    const std::string &op_name,
+    bool with_reshape_xshape,
     bool with_transpose_xshape) {
   auto reshape_op =
       pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
@@ -3098,11 +3159,10 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
   transpose_out->assert_is_only_output_of_op("transpose2");
 
   auto transpose_xshape =
-      with_transpose_xshape
-          ? pattern->NewNode(transpose_xshape_repr())
-                ->AsIntermediate()
-                ->assert_is_op_output("transpose2", "XShape")
-          : nullptr;
+      with_transpose_xshape ? pattern->NewNode(transpose_xshape_repr())
+                                  ->AsIntermediate()
+                                  ->assert_is_op_output("transpose2", "XShape")
+                            : nullptr;
 
   auto matmul_out = pattern->NewNode(matmul_out_repr())
                         ->AsOutput()