Skip to content

Commit 5b2a3c4

Browse files
sfraczek authored and luotao1 committed
Conv concat relu quantization (#17466)
* add conv_concat_relu fuse test=develop * add test code test=develop * added missing include with unordered_map test=develop * review fixes for wojtuss test=develop * remove 'should (not) be fused' comment statements one of them was invalid anyway test=develop
1 parent bccb0ba commit 5b2a3c4

7 files changed

+405
-0
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ if(WITH_MKLDNN)
8686
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
8787
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
8888
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
89+
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
8990
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
9091
pass_library(cpu_quantize_placement_pass base mkldnn)
9192
pass_library(cpu_quantize_pass inference mkldnn)
@@ -116,6 +117,7 @@ if (WITH_MKLDNN)
116117
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
117118
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
118119
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
120+
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
119121
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
120122
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
121123
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)

paddle/fluid/framework/ir/graph_pattern_detector.cc

+40
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,46 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
11841184
return out_var;
11851185
}
11861186

1187+
PDNode *patterns::ConcatReLU::operator()() {
  // Pattern: concat -> (concat Out) -> relu -> (relu Out).
  auto concat = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
  auto relu = pattern->NewNode(relu_op_repr())->assert_is_op("relu");

  // Intermediate variable produced by concat and consumed by relu.
  auto concat_result =
      pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");

  // Final output of the matched subgraph.
  auto relu_result = pattern->NewNode(relu_out_repr())
                         ->AsOutput()
                         ->assert_is_op_output("relu", "Out");

  concat->LinksTo({concat_result});
  relu->LinksFrom({concat_result}).LinksTo({relu_result});

  return relu_result;
}
1203+
1204+
PDNode *patterns::ConvConcatReLU::operator()() {
  // Pattern: conv2d -> (conv Output) -> concat -> (concat Out) -> relu.
  auto conv = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
  auto concat = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
  auto relu = pattern->NewNode(relu_op_repr())->assert_is_op("relu");

  // Output variable of the conv2d feeding the concat.
  auto conv_result = pattern->NewNode(conv_out_repr())
                         ->assert_is_op_output("conv2d", "Output");

  // Intermediate variable: produced by concat, consumed by relu.
  auto concat_result = pattern->NewNode(concat_out_repr())
                           ->assert_is_op_output("concat", "Out")
                           ->assert_is_op_input("relu", "X");

  // Final output of the matched subgraph.
  auto relu_result = pattern->NewNode(relu_out_repr())
                         ->AsOutput()
                         ->assert_is_op_output("relu", "Out");

  conv->LinksTo({conv_result});
  concat->LinksFrom({conv_result}).LinksTo({concat_result});
  relu->LinksFrom({concat_result}).LinksTo({relu_result});

  return relu_result;
}
1226+
11871227
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
11881228

11891229
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {

paddle/fluid/framework/ir/graph_pattern_detector.h

+33
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,39 @@ struct ElementwiseAdd : public PatternBase {
728728
PATTERN_DECL_NODE(elementwise_add_out);
729729
};
730730

731+
// Concat + ReLU
// named nodes:
// concat_op, concat_out, relu_op, relu_out
struct ConcatReLU : public PatternBase {
  ConcatReLU(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "concat_relu") {}

  // Builds the concat->relu pattern on the wrapped PDPattern and
  // returns the relu output node.
  PDNode* operator()();

  PATTERN_DECL_NODE(concat_op);
  PATTERN_DECL_NODE(concat_out);
  PATTERN_DECL_NODE(relu_op);
  PATTERN_DECL_NODE(relu_out);
};
745+
746+
// Conv + Concat + ReLU
// named nodes:
// conv_op, conv_out
// concat_op, concat_out, relu_op, relu_out
struct ConvConcatReLU : public PatternBase {
  ConvConcatReLU(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_concat_relu") {}

  // Builds the conv2d->concat->relu pattern and returns the relu
  // output node. Note: matches one conv branch feeding the concat;
  // the detector fires once per (conv, concat, relu) triple.
  PDNode* operator()();

  PATTERN_DECL_NODE(conv_op);
  PATTERN_DECL_NODE(conv_out);
  PATTERN_DECL_NODE(concat_op);
  PATTERN_DECL_NODE(concat_out);
  PATTERN_DECL_NODE(relu_op);
  PATTERN_DECL_NODE(relu_out);
};
763+
731764
// Conv + ElementwiseAdd + an activation
732765
// This pattern can futher fuse the conv related ops after the conv+bn fusion.
733766
struct ConvElementwiseaddAct : public PatternBase {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
16+
#include <vector>
17+
#include "paddle/fluid/platform/enforce.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
// Scans the graph for concat->relu pairs whose every concat input is
// produced by a fusable conv2d, and records the number of such convs
// per concat node in *concat_with_convs_counter.
void ConvConcatReLUFusePass::FindConcatWithConvs(
    ir::Graph* graph,
    std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
  GraphPatternDetector gpd;
  patterns::ConcatReLU concat_relu_pattern{gpd.mutable_pattern(),
                                           "concat_relu"};
  concat_relu_pattern();

  int found_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Find Concats with Convs";
    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_relu_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, concat_relu_pattern);

    auto concat_inputs = concat_op->inputs;

    for (auto node : concat_inputs) {
      auto prev_op_node = node->inputs;
      // A concat input without exactly one producing op (e.g. a graph
      // input variable, which has none) cannot come from a conv2d.
      // Skip this concat instead of aborting the whole pass.
      if (prev_op_node.size() != 1) return;
      auto* conv_op = prev_op_node[0];
      if (conv_op->Op()->Type() != "conv2d") return;

      // Respect per-op fuse hints (e.g. quantization placement).
      FuseOptions fuse_option = FindFuseOption(*conv_op, *relu_op);
      if (fuse_option == DO_NOT_FUSE) {
        return;
      }
    }

    // All inputs are fusable convs: remember how many must be fused
    // before the trailing relu can be removed.
    (*concat_with_convs_counter)[concat_op] = concat_inputs.size();
    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}
58+
59+
// For every conv2d->concat->relu match whose concat was validated by
// FindConcatWithConvs, folds the relu into the conv (fuse_relu attr).
// Once the last conv of a concat is transformed, rewires the concat to
// the relu output and removes the now-redundant relu op.
void ConvConcatReLUFusePass::FuseConvConcatReLU(
    ir::Graph* graph,
    std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ConvConcatReLU conv_concat_relu(pattern, name_scope_);
  conv_concat_relu();

  int found_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvConcatReLU fuse";

    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_concat_relu);
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_concat_relu);
    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, conv_concat_relu);
    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, conv_concat_relu);
    GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, conv_concat_relu);
    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_concat_relu);

    // Only fuse concats pre-approved by FindConcatWithConvs (all inputs
    // are conv2d outputs and none of the convs forbids fusing).
    if (!concat_with_convs_counter->count(concat_op)) {
      VLOG(4) << "this concat has input from non-conv2d operator";
      return;
    }

    // Transform Conv node into ConvReLU node.
    OpDesc* conv_desc = conv_op->Op();
    conv_desc->SetAttr("fuse_relu", true);

    // Remove ReLU when all Convs were transformed.
    // The handler fires once per conv branch; the counter tracks how
    // many branches of this concat are still unfused.
    auto number_of_unfused_convs_left =
        --(*concat_with_convs_counter)[concat_op];
    if (number_of_unfused_convs_left == 0) {
      OpDesc* concat_desc = concat_op->Op();
      // Rewire concat to write directly into the former relu output,
      // then drop the relu op and the intermediate variable node.
      concat_desc->SetOutput("Out",
                             std::vector<std::string>({relu_out->Name()}));
      GraphSafeRemoveNodes(graph, {relu_op, concat_out});
      IR_NODE_LINK_TO(concat_op, relu_out);
    }

    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}
104+
105+
void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE(graph);
  FusePassBase::Init(name_scope_, graph);

  // Phase 1: collect concats whose inputs are all fusable convs,
  // keyed by concat node with the number of conv branches.
  // Phase 2: perform the actual conv+relu fusion per branch.
  std::unordered_map<const Node*, int> convs_per_concat;
  FindConcatWithConvs(graph, &convs_per_concat);
  FuseConvConcatReLU(graph, &convs_per_concat);
}
113+
114+
} // namespace ir
115+
} // namespace framework
116+
} // namespace paddle
117+
118+
// Registers the pass under this name in the global pass registry so it
// can be enabled by inference/build configs.
REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
              paddle::framework::ir::ConvConcatReLUFusePass);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include <unordered_map>
19+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
20+
#include "paddle/fluid/framework/ir/graph.h"
21+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
22+
#include "paddle/fluid/framework/ir/pass.h"
23+
24+
namespace paddle {
25+
namespace framework {
26+
namespace ir {
27+
28+
/*
 * Fuse the (multi conv) -> Concat -> ReLU -> next_op
 * to a:
 * (multi ConvReLU) -> Concat -> next_op.
 */
class ConvConcatReLUFusePass : public FusePassBase {
 public:
  virtual ~ConvConcatReLUFusePass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;

  // Phase 1: records, for every concat->relu whose inputs all come from
  // fusable conv2d ops, the number of those convs (keyed by concat node).
  void FindConcatWithConvs(
      Graph* graph,
      std::unordered_map<const Node*, int>* concat_with_convs_counter) const;

  // Phase 2: sets fuse_relu on each conv feeding a recorded concat and
  // removes the relu once every conv of that concat has been transformed.
  void FuseConvConcatReLU(
      Graph* graph,
      std::unordered_map<const Node*, int>* concat_with_convs_counter) const;

  const std::string name_scope_{"conv_concat_relu_mkldnn_fuse"};
};
50+
51+
} // namespace ir
52+
} // namespace framework
53+
} // namespace paddle

0 commit comments

Comments
 (0)