Skip to content

Commit 8462e2b

Browse files
Sand3r-luotao1
authored andcommitted
Disable MKLDNN FC in Resnet50 test (#18030)
1 parent 78e9328 commit 8462e2b

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

cmake/generic.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
385385
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
386386
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
387387
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
388-
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
388+
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
389389
# No unit test should exceed 10 minutes.
390390
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
391391
endif()

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,12 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
3333
--paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
3434
--iterations=2)
3535
endfunction()
36-
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
37-
if(mkl_debug)
38-
set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
39-
endif()
36+
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name disable_fc)
4037
download_model(${install_dir} ${model_name})
4138
inference_analysis_test(${target} SRCS ${filename}
4239
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
43-
ARGS --infer_model=${install_dir}/model)
40+
ARGS --infer_model=${install_dir}/model
41+
--disable_mkldnn_fc=${disable_fc})
4442
endfunction()
4543

4644
function(inference_analysis_api_test_with_refer_result target install_dir filename)

paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License. */
1616
#include <iostream>
1717
#include "paddle/fluid/inference/tests/api/tester_helper.h"
1818

19+
DEFINE_bool(disable_mkldnn_fc, false, "Disable usage of MKL-DNN's FC op");
20+
1921
namespace paddle {
2022
namespace inference {
2123
namespace analysis {
@@ -48,7 +50,8 @@ void profile(bool use_mkldnn = false) {
4850

4951
if (use_mkldnn) {
5052
cfg.EnableMKLDNN();
51-
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
53+
if (!FLAGS_disable_mkldnn_fc)
54+
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
5255
}
5356
std::vector<std::vector<PaddleTensor>> outputs;
5457

@@ -80,7 +83,8 @@ void compare(bool use_mkldnn = false) {
8083
SetConfig(&cfg);
8184
if (use_mkldnn) {
8285
cfg.EnableMKLDNN();
83-
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
86+
if (!FLAGS_disable_mkldnn_fc)
87+
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
8488
}
8589

8690
std::vector<std::vector<PaddleTensor>> input_slots_all;

0 commit comments

Comments
 (0)