Commit 0c39b97

Sand3r- authored and luotao1 committed
[MKL-DNN] Add Fully Connected Op for inference only (#15226)

* fuse mul and elementwise add to fc
* Reimplement the FC forward operator
* Fix FC MKLDNN integration by transposing weights
* Add FC MKLDNN Pass test=develop
* FC MKLDNN Pass: change memcpy to std::copy
* Fix MKLDNN FC handling of mismatch input and weights dims
* Lower tolerance for MKL-DNN in resnet50 test test=develop
* Adjust FC to support MKLDNN Op placement test=develop
* Adjust Placement Op to set use_mkldnn attribute for graph test=develop
* MKLDNN FC: fix weights format so that gemm version is called test=develop
* FC MKLDNN: Remove tolerance decrease from tester_helper
* FC MKL-DNN: Refactor the code, change input reorder to weight reorder
* MKL-DNN FC: Introduce operator caching test=develop
* FC MKL-DNN: Fix the tensor type in ExpectedKernelType test=develop
* FC MKL-DNN: fix style changes test=develop
* FC MKL-DNN: fallback to native on non-supported dim sizes test=develop
* FC MKLDNN: fix CMake paths test=develop
* FC MKLDNN: Refine placement pass graph mkldnn attribute test=develop
* Fix Transpiler error for fuse_conv_eltwise test=develop
* Fix missing STL includes in files test=develop
* FC MKL-DNN: Enable new output size computation; also, refine pass to comply with newest interface test=develop
* FC MKL-DNN: enable only when fc_mkldnn_pass is enabled
* FC MKL-DNN: Allow Weights to use oi or io format
* FC MKL-DNN: Adjust UT to work with correct dims test=develop
* Enable MKL DEBUG for resnet50 analyzer test=develop
* FC MKL-DNN: Improve Hashing function test=develop
* FC MKL-DNN: Fix shape for fc weights in transpiler
* FC MKL-DNN: Update input pointer in re-used fc primitive
* Add log for not handling fc fuse for unsupported dims test=develop
* FC MKL-DNN: Move transpose from pass to Op Kernel test=develop
* FC MKL-DNN: Disable transpose in unit test test=develop
* FC MKL-DNN: Remove fc_mkldnn_pass from default list
* Correct Flag for fake data analyzer tests test=develop
* FC MKL-DNN: Add comment about fc mkldnn pass disablement test=develop
* FC MKL-DNN: Disable fc in int8 tests test=develop
1 parent 21138eb commit 0c39b97

20 files changed: +502 −286 lines

cmake/generic.cmake

Lines changed: 1 addition & 1 deletion
@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
      set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
      set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
      set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
-     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
+     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
      # No unit test should exceed 10 minutes.
      set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
    endif()
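
Note: ${MKL_DEBUG_FLAG} expands to nothing unless a caller defines it; the inference tests' CMakeLists.txt later in this commit sets it to MKL_DEBUG_CPU_TYPE=7 inside a wrapper function before the call chain reaches cc_test. A minimal sketch of that interaction (the wrapper function name here is hypothetical; cc_test and the flag value are the ones used in this commit):

    # Hypothetical wrapper. Variables set inside a CMake function are visible to the
    # functions it calls, so MKL_DEBUG_FLAG reaches cc_test() and lands in the test's
    # ENVIRONMENT property only for tests created through this wrapper.
    function(my_mkl_debug_cc_test TARGET_NAME)
      set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)  # same value as in the analyzer tests below
      cc_test(${TARGET_NAME} SRCS ${ARGN})
    endfunction()

    # Tests created directly with cc_test() leave MKL_DEBUG_FLAG undefined, so the
    # appended ${MKL_DEBUG_FLAG} expands to nothing and their environment is unchanged.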

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ if(WITH_MKLDNN)
    pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
    pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
    pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
+   pass_library(fc_mkldnn_pass inference mkldnn)
    pass_library(cpu_quantize_placement_pass base mkldnn)
    pass_library(cpu_quantize_pass inference mkldnn)
    pass_library(cpu_quantize_squash_pass inference mkldnn)

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
  // limitations under the License.

  #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
+ #include <memory>
  #include <string>
  #include <unordered_set>
  #include <vector>
@@ -80,6 +81,7 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
    }

    desc.SetType("fc");
+
    auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
    GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 30 additions & 0 deletions
@@ -14,7 +14,10 @@

  #include <algorithm>
  #include <array>
+ #include <memory>
  #include <string>
+ #include <unordered_map>
+ #include <unordered_set>
  #include <vector>

  #include "paddle/fluid/framework/ir/graph_helper.h"
@@ -896,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
    }
  }

+ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
+                                        bool with_bias) {
+   // Create shared nodes.
+   x->assert_is_op_input("fc", "Input");
+
+   auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
+   // Create variables
+   // Filter
+   auto *fc_weight_var = pattern->NewNode(weights_repr())
+                             ->AsInput()
+                             ->assert_is_persistable_var()
+                             ->assert_is_op_input("fc", "W");
+   // Bias
+   auto *fc_bias_var = pattern->NewNode(bias_repr())
+                           ->AsInput()
+                           ->assert_is_persistable_var()
+                           ->assert_is_op_input("fc", "Bias");
+   // Output
+   auto *fc_out_var = pattern->NewNode(output_repr())
+                          ->AsOutput()
+                          ->assert_is_op_output("fc", "Out")
+                          ->assert_is_only_output_of_op("fc");
+
+   fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var});
+   return fc_out_var;
+ }
+
  PDNode *patterns::Embedding::operator()(PDNode *x) {
    x->assert_is_op_input("lookup_table", "Ids");
    auto *lookup_table_op =

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 19 additions & 0 deletions
@@ -517,6 +517,25 @@ struct FC : public PatternBase {
    PATTERN_DECL_NODE(Out);
  };

+ // MKL-DNN's FC with bias
+ // op: fc
+ // named node:
+ //   fc
+ //   w, bias, output
+ struct FCMKLDNN : public PatternBase {
+   FCMKLDNN(PDPattern* pattern, const std::string& name_scope)
+       : PatternBase(pattern, name_scope, "fc_mkldnn") {}
+
+   PDNode* operator()(PDNode* x, bool with_bias);
+
+   // declare operator node's name
+   PATTERN_DECL_NODE(fc);
+   // declare variable node's name
+   PATTERN_DECL_NODE(weights);
+   PATTERN_DECL_NODE(bias);
+   PATTERN_DECL_NODE(output);
+ };
+
  // Embedding
  struct Embedding : public PatternBase {
    Embedding(PDPattern* pattern, const std::string& name_scope)

paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc (new file)

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+ // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ //
+ // Licensed under the Apache License, Version 2.0 (the "License");
+ // you may not use this file except in compliance with the License.
+ // You may obtain a copy of the License at
+ //
+ //     http://www.apache.org/licenses/LICENSE-2.0
+ //
+ // Unless required by applicable law or agreed to in writing, software
+ // distributed under the License is distributed on an "AS IS" BASIS,
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ // See the License for the specific language governing permissions and
+ // limitations under the License.
+
+ #include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
+ #include <algorithm>
+ #include <memory>
+ #include <string>
+ #include <vector>
+ #include "paddle/fluid/framework/eigen.h"
+ #include "paddle/fluid/framework/lod_tensor.h"
+ #include "paddle/fluid/platform/enforce.h"
+
+ namespace paddle {
+ namespace framework {
+ namespace ir {
+
+ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
+   PADDLE_ENFORCE(graph);
+   Init("fc_mkldnn_pass", graph);
+
+   auto* scope = param_scope();
+   PADDLE_ENFORCE(scope);
+
+   GraphPatternDetector gpd;
+   auto* x = gpd.mutable_pattern()
+                 ->NewNode("fc_mkldnn_pass/x")
+                 ->AsInput()
+                 ->assert_is_op_input("fc", "Input");
+   patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
+   fc_pattern(x, true /*with bias*/);
+
+   int found_fc_count = 0;
+   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* g) {
+     VLOG(4) << "Handle FC MKL-DNN pass";
+     if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
+       VLOG(3) << "do not perform fc fuse";
+       return;
+     }
+     GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
+     GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
+     GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
+     GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
+
+     OpDesc* desc = fc->Op();
+     auto in_size = fc->inputs[0]->Var()->GetShape().size();
+     if (in_size != 2 && in_size != 4) {
+       VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4";
+       return;
+     }
+     desc->SetAttr("use_mkldnn", true);
+     PADDLE_ENFORCE(subgraph.count(x));
+
+     found_fc_count++;
+   };
+
+   gpd(graph, handler);
+
+   AddStatis(found_fc_count);
+ }
+
+ }  // namespace ir
+ }  // namespace framework
+ }  // namespace paddle
+
+ REGISTER_PASS(fc_mkldnn_pass, paddle::framework::ir::FCMKLDNNPass);

paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h (new file)

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+ // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ //
+ // Licensed under the Apache License, Version 2.0 (the "License");
+ // you may not use this file except in compliance with the License.
+ // You may obtain a copy of the License at
+ //
+ //     http://www.apache.org/licenses/LICENSE-2.0
+ //
+ // Unless required by applicable law or agreed to in writing, software
+ // distributed under the License is distributed on an "AS IS" BASIS,
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ // See the License for the specific language governing permissions and
+ // limitations under the License.
+ #pragma once
+ #include <memory>
+ #include "paddle/fluid/framework/ir/fuse_pass_base.h"
+ #include "paddle/fluid/framework/ir/graph.h"
+ #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+ #include "paddle/fluid/framework/ir/pass.h"
+
+ namespace paddle {
+ namespace framework {
+ namespace ir {
+
+ /*
+  * Transpose weights of FC to comply with MKL-DNN interface
+  */
+ class FCMKLDNNPass : public FusePassBase {
+  public:
+   virtual ~FCMKLDNNPass() {}
+
+  protected:
+   void ApplyImpl(ir::Graph* graph) const;
+ };
+
+ }  // namespace ir
+ }  // namespace framework
+ }  // namespace paddle

paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
  limitations under the License. */

  #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
+ #include <memory>
  #include <string>
  #include <unordered_set>

@@ -24,6 +25,9 @@ void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
    VLOG(3) << "Applies MKL-DNN placement strategy.";
    const auto& op_types_list =
        Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
+   if (!graph->Has("use_mkldnn")) {
+     graph->Set<bool>("use_mkldnn", new bool(true));
+   }
    for (const Node* n : graph->Nodes()) {
      if (n->IsOp()) {
        auto* op = n->Op();

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 13 additions & 10 deletions
@@ -146,16 +146,19 @@ void CpuPassStrategy::EnableMKLDNN() {
    if (!use_mkldnn_) {
      passes_.insert(passes_.begin(), "mkldnn_placement_pass");

-     for (auto &pass : std::vector<std::string>(
-              {"depthwise_conv_mkldnn_pass",    //
-               "conv_bn_fuse_pass",             // Execute BN passes again to
-               "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
-               "conv_bias_mkldnn_fuse_pass",    //
-               "conv3d_bias_mkldnn_fuse_pass",  //
-               "conv_elementwise_add_mkldnn_fuse_pass",
-               "conv_concat_relu_mkldnn_fuse_pass",
-               "conv_relu_mkldnn_fuse_pass",  //
-               "conv_brelu_mkldnn_fuse_pass"})) {
+     for (auto &pass : std::vector<std::string>({
+              "depthwise_conv_mkldnn_pass",    //
+              "conv_bn_fuse_pass",             // Execute BN passes again to
+              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
+              "conv_bias_mkldnn_fuse_pass",    //
+              "conv3d_bias_mkldnn_fuse_pass",  //
+              "conv_elementwise_add_mkldnn_fuse_pass",
+              "conv_concat_relu_mkldnn_fuse_pass",
+              "conv_relu_mkldnn_fuse_pass",  //
+              "conv_brelu_mkldnn_fuse_pass",  //
+              // Disabled due to topology-dependent speed-up
+              // "fc_mkldnn_pass"
+          })) {
        passes_.push_back(pass);
      }
    }
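
Since "fc_mkldnn_pass" stays commented out of the default MKL-DNN pass list above, an application has to opt in per predictor. A minimal sketch of that opt-in, assuming the usual AnalysisConfig workflow (EnableMKLDNN and pass_builder()->AppendPass are the calls used by the tester changes below; the header path, SetModel, and CreatePaddlePredictor are assumptions, not part of this diff):

    #include <memory>
    #include <string>
    #include "paddle/fluid/inference/api/paddle_inference_api.h"  // header path assumed

    // Build a CPU predictor that uses the new FC MKL-DNN kernel.
    std::unique_ptr<paddle::PaddlePredictor> MakeFcMkldnnPredictor(
        const std::string& model_dir) {
      paddle::AnalysisConfig cfg;
      cfg.SetModel(model_dir);
      cfg.EnableMKLDNN();  // inserts mkldnn_placement_pass, which also sets the graph-level
                           // "use_mkldnn" attribute that fc_mkldnn_pass checks (see diffs above)
      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");  // explicit opt-in, as in the testers below
      return paddle::CreatePaddlePredictor(cfg);         // factory call assumed
    }

The order matters: appending fc_mkldnn_pass only takes effect when mkldnn_placement_pass (added by EnableMKLDNN) runs first, because the new pass bails out unless the graph carries the "use_mkldnn" attribute.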

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 7 additions & 5 deletions
@@ -33,8 +33,10 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
                 --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
                 --iterations=2)
  endfunction()
-
- function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)
+ function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
+   if(mkl_debug)
+     set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
+   endif()
    download_model(${install_dir} ${model_name})
    inference_analysis_test(${target} SRCS ${filename}
        EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
@@ -143,15 +145,15 @@ inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose

  # googlenet
  inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
-     "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz")
+     "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" false)

  # resnet50
  inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
-     "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
+     "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" true)

  # mobilenet with depthwise_conv op
  inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
-     "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
+     "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" false)

  # int8 image classification tests
  if(WITH_MKLDNN)

paddle/fluid/inference/tests/api/analyzer_bert_tester.cc

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ void profile(bool use_mkldnn = false) {

    if (use_mkldnn) {
      config.EnableMKLDNN();
+     config.pass_builder()->AppendPass("fc_mkldnn_pass");
    }

    std::vector<std::vector<PaddleTensor>> outputs;

paddle/fluid/inference/tests/api/analyzer_dam_tester.cc

Lines changed: 2 additions & 1 deletion
@@ -200,8 +200,9 @@ void profile(bool use_mkldnn = false) {
      cfg.EnableMKLDNN();
      // Enable all the mkldnn supported ops except conv3d in dam
      std::unordered_set<std::string> op_list = {"softmax", "elementwise_add",
-                                                "relu"};
+                                                "relu", "fc"};
      cfg.SetMKLDNNOp(op_list);
+     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
    }

    std::vector<std::vector<PaddleTensor>> outputs;

paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ void profile(bool use_mkldnn = false) {

    if (use_mkldnn) {
      cfg.EnableMKLDNN();
+     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
    }

    std::vector<std::vector<PaddleTensor>> input_slots_all;

paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ void profile(bool use_mkldnn = false) {

    if (use_mkldnn) {
      cfg.EnableMKLDNN();
+     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
    }
    std::vector<std::vector<PaddleTensor>> outputs;

@@ -79,6 +80,7 @@ void compare(bool use_mkldnn = false) {
    SetConfig(&cfg);
    if (use_mkldnn) {
      cfg.EnableMKLDNN();
+     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
    }

    std::vector<std::vector<PaddleTensor>> input_slots_all;

paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc

Lines changed: 1 addition & 0 deletions
@@ -149,6 +149,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
    }
    if (use_mkldnn) {
      cfg->EnableMKLDNN();
+     cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
    }
    // Enable seqpool_concat_fuse_pass, disabled by default since it takes much
    // time

paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc

Lines changed: 1 addition & 0 deletions
@@ -189,6 +189,7 @@ void profile(bool use_mkldnn = false) {
    std::vector<std::vector<PaddleTensor>> outputs;
    if (use_mkldnn) {
      cfg.EnableMKLDNN();
+     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
    }

    std::vector<std::vector<PaddleTensor>> input_slots_all;

paddle/fluid/inference/tests/api/analyzer_vis_tester.cc

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ void profile(bool use_mkldnn = false) {
    SetConfig(&cfg);
    if (use_mkldnn) {
      cfg.EnableMKLDNN();
+     cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
    }
    // cfg.pass_builder()->TurnOnDebug();
    std::vector<std::vector<PaddleTensor>> outputs;
