Skip to content

Commit 8e3a499

Browse files
committed
resolve conflict, test=develop, test=document_fix
2 parents 3147d31 + eb05db7 commit 8e3a499

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+4111
-1071
lines changed

cmake/inference_lib.cmake

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
if(WIN32)
1818
if(NOT PYTHON_EXECUTABLE)
19-
FIND_PACKAGE(PythonInterp REQUIRED)
19+
FIND_PACKAGE(PythonInterp REQUIRED)
2020
endif()
2121
endif()
2222

@@ -78,26 +78,26 @@ add_custom_target(inference_lib_dist DEPENDS ${inference_lib_deps})
7878

7979
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/eigen3")
8080
copy(inference_lib_dist
81-
SRCS ${EIGEN_INCLUDE_DIR}/Eigen/Core ${EIGEN_INCLUDE_DIR}/Eigen/src ${EIGEN_INCLUDE_DIR}/unsupported/Eigen
82-
DSTS ${dst_dir}/Eigen ${dst_dir}/Eigen ${dst_dir}/unsupported)
81+
SRCS ${EIGEN_INCLUDE_DIR}/Eigen/Core ${EIGEN_INCLUDE_DIR}/Eigen/src ${EIGEN_INCLUDE_DIR}/unsupported/Eigen
82+
DSTS ${dst_dir}/Eigen ${dst_dir}/Eigen ${dst_dir}/unsupported)
8383

8484
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/boost")
8585
copy(inference_lib_dist
86-
SRCS ${BOOST_INCLUDE_DIR}/boost
87-
DSTS ${dst_dir})
86+
SRCS ${BOOST_INCLUDE_DIR}/boost
87+
DSTS ${dst_dir})
8888

8989
if(WITH_MKLML)
9090
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/mklml")
9191
if(WIN32)
9292
copy(inference_lib_dist
93-
SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_SHARED_LIB}
93+
SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_SHARED_LIB}
9494
${MKLML_SHARED_LIB_DEPS} ${MKLML_SHARED_IOMP_LIB} ${MKLML_INC_DIR}
95-
DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir}/lib
95+
DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir}/lib
9696
${dst_dir}/lib ${dst_dir}/lib ${dst_dir})
9797
else()
9898
copy(inference_lib_dist
99-
SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_INC_DIR}
100-
DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir})
99+
SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_INC_DIR}
100+
DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir})
101101
endif()
102102
elseif (NOT CBLAS_FOUND OR WIN32)
103103
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/openblas")
@@ -107,16 +107,16 @@ elseif (NOT CBLAS_FOUND OR WIN32)
107107
endif ()
108108

109109
if(WITH_MKLDNN)
110-
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/mkldnn")
111-
if(WIN32)
112-
copy(inference_lib_dist
113-
SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB} ${MKLDNN_LIB}
114-
DSTS ${dst_dir} ${dst_dir}/lib ${dst_dir}/lib)
115-
else()
116-
copy(inference_lib_dist
117-
SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB}
118-
DSTS ${dst_dir} ${dst_dir}/lib)
119-
endif()
110+
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/mkldnn")
111+
if(WIN32)
112+
copy(inference_lib_dist
113+
SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB} ${MKLDNN_LIB}
114+
DSTS ${dst_dir} ${dst_dir}/lib ${dst_dir}/lib)
115+
else()
116+
copy(inference_lib_dist
117+
SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB}
118+
DSTS ${dst_dir} ${dst_dir}/lib)
119+
endif()
120120
endif()
121121

122122
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/gflags")
@@ -156,20 +156,20 @@ endif ()
156156
if (TENSORRT_FOUND)
157157
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/tensorrt")
158158
copy(inference_lib_dist
159-
SRCS ${TENSORRT_ROOT}/include/Nv*.h ${TENSORRT_ROOT}/lib/*nvinfer*
160-
DSTS ${dst_dir}/include ${dst_dir}/lib)
159+
SRCS ${TENSORRT_ROOT}/include/Nv*.h ${TENSORRT_ROOT}/lib/*nvinfer*
160+
DSTS ${dst_dir}/include ${dst_dir}/lib)
161161
endif ()
162162

163163
if (ANAKIN_FOUND)
164164
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/install/anakin")
165165
copy(inference_lib_dist
166-
SRCS ${ANAKIN_ROOT}/*
167-
DSTS ${dst_dir})
166+
SRCS ${ANAKIN_ROOT}/*
167+
DSTS ${dst_dir})
168168
endif ()
169169

170170
copy(inference_lib_dist
171-
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
172-
DSTS ${FLUID_INFERENCE_INSTALL_DIR})
171+
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
172+
DSTS ${FLUID_INFERENCE_INSTALL_DIR})
173173

174174
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
175175
if(WIN32)
@@ -179,8 +179,8 @@ else(WIN32)
179179
endif(WIN32)
180180

181181
copy(inference_lib_dist
182-
SRCS ${src_dir}/inference/api/paddle_*.h ${paddle_fluid_lib}
183-
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib)
182+
SRCS ${src_dir}/inference/api/paddle_*.h ${paddle_fluid_lib}
183+
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib)
184184

185185

186186
# fluid library for both train and inference
@@ -190,17 +190,23 @@ add_custom_target(fluid_lib_dist ALL DEPENDS ${fluid_lib_deps})
190190
set(dst_dir "${FLUID_INSTALL_DIR}/paddle/fluid")
191191
set(module "inference")
192192
copy(fluid_lib_dist
193-
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/api/paddle_*.h ${paddle_fluid_lib}
194-
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
195-
)
193+
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/api/paddle_*.h ${paddle_fluid_lib}
194+
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
195+
)
196196

197197
set(module "framework")
198198
set(framework_lib_deps framework_proto)
199199
add_dependencies(fluid_lib_dist ${framework_lib_deps})
200200
copy(fluid_lib_dist
201-
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/data_feed.pb.h ${src_dir}/${module}/ir/memory_optimize_pass/*.h
202-
${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h
203-
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet)
201+
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/trainer_desc.pb.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/data_feed.pb.h ${src_dir}/${module}/ir/memory_optimize_pass/*.h
202+
${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h
203+
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet)
204+
205+
set(module "operators")
206+
copy(fluid_lib_dist
207+
SRCS ${src_dir}/${module}/reader/blocking_queue.h
208+
DSTS ${dst_dir}/${module}/reader/
209+
)
204210

205211
set(module "memory")
206212
copy(fluid_lib_dist
@@ -252,4 +258,4 @@ function(version version_file)
252258
endif ()
253259
endfunction()
254260
version(${FLUID_INSTALL_DIR}/version.txt)
255-
version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)
261+
version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)

paddle/fluid/API.spec

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ paddle.fluid.load_op_library (ArgSpec(args=['lib_filename'], varargs=None, keywo
3030
paddle.fluid.Executor ('paddle.fluid.executor.Executor', ('document', '34e8c1769313fbeff7817212dda6259e'))
3131
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
3232
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '3a584496aa1343f36eebf3c46b323a74'))
33-
paddle.fluid.Executor.infer_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', 'bedc29ad01c1b911e99032ee1e19ac59'))
33+
paddle.fluid.Executor.infer_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period', 'fetch_handler'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100, None)), ('document', '4ff256774ecaeee01c840a5fb5de8f7a'))
3434
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', '4cfcd9c15b766a51b584cc46d38f1ad8'))
35-
paddle.fluid.Executor.train_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', '28f50904a0213f110947a30e0438529c'))
35+
paddle.fluid.Executor.train_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period', 'fetch_handler'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100, None)), ('document', '73024c79f46b4f14f1060edeaa4919c8'))
3636
paddle.fluid.global_scope (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'f65788d9ead293ada47551339df12203'))
3737
paddle.fluid.scope_guard (ArgSpec(args=['scope'], varargs=None, keywords=None, defaults=None), ('document', 'e6c073ed237001aaba7bff976b62b122'))
3838
paddle.fluid.DistributeTranspiler ('paddle.fluid.transpiler.distribute_transpiler.DistributeTranspiler', ('document', 'b2b19821c5dffcd11473d6a4eef089af'))
@@ -153,8 +153,8 @@ paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'moment
153153
paddle.fluid.layers.instance_norm (ArgSpec(args=['input', 'epsilon', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None)), ('document', '5e2d18e85599ede7e71b06ed64d0f69e'))
154154
paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', '2460b30fb87037555208fa8ac6fc1787'))
155155
paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '83e08f21af41ac8bac37aeab1f86fdd0'))
156-
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'de4f1dffa8245f010a5f7e8f3952e90c'))
157-
paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'fbbd2eab215f00a6ccd51d54d30dba87'))
156+
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCHW')), ('document', '0ca6c549ac2b63096bdc7832a08b4431'))
157+
paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCDHW')), ('document', '3393afeec7bf6fb6ebff086eecbc244a'))
158158
paddle.fluid.layers.sequence_expand (ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '10e122eb755c2bd1f78ef2332b28f1a0'))
159159
paddle.fluid.layers.sequence_expand_as (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '858c432e7cbd8bb952cc2eb555457d50'))
160160
paddle.fluid.layers.sequence_pad (ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'df08b9c499ab3a90f95d08ab5b6c6c62'))

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
123123

124124
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
125125
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
126-
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog data_feed_proto
126+
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
127127
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
128128

129129
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)

paddle/fluid/framework/executor.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
140140
}
141141
}
142142

143-
void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
144-
Dataset* dataset,
145-
const std::string& trainer_desc_str) {
143+
std::shared_ptr<TrainerBase> Executor::InitForDataset(
144+
const ProgramDesc& main_program, const std::string& trainer_desc_str,
145+
Scope* scope, Dataset* dataset) {
146146
VLOG(3) << "Start to RunFromDataset in executor";
147147
TrainerDesc trainer_desc;
148148
bool success = trainer_desc.ParseFromString(trainer_desc_str);
149-
PADDLE_ENFORCE(success, "Fail to parse TrainerDesc from string:\n%s",
150-
trainer_desc_str.c_str());
149+
PADDLE_ENFORCE_EQ(success, true, "Fail to parse TrainerDesc from string:\n%s",
150+
trainer_desc_str.c_str());
151151
VLOG(3) << "Going to create trainer, trainer class is "
152152
<< trainer_desc.class_name();
153153
std::shared_ptr<TrainerBase> trainer;
@@ -162,12 +162,17 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
162162
trainer->InitTrainerEnv(main_program, place_);
163163
VLOG(3) << "Try to init other environment";
164164
trainer->InitOtherEnv(main_program);
165+
return trainer;
166+
}
167+
168+
void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
169+
PADDLE_ENFORCE_NE(trainer, nullptr,
170+
"Trainer is nullptr, invoke InitForDataset first");
165171
// training and finalize training
166172
VLOG(3) << "Trainer starts to run";
167173
trainer->Run();
168174
VLOG(3) << "Trainer going to finalize";
169175
trainer->Finalize();
170-
return;
171176
}
172177

173178
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,

paddle/fluid/framework/executor.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License. */
2626
#include "paddle/fluid/framework/program_desc.h"
2727
#include "paddle/fluid/framework/scope.h"
2828
#include "paddle/fluid/framework/tensor.h"
29+
#include "paddle/fluid/framework/trainer.h"
2930
#include "paddle/fluid/platform/device_context.h"
3031

3132
namespace paddle {
@@ -119,8 +120,10 @@ class Executor {
119120

120121
void EnableMKLDNN(const ProgramDesc& program);
121122

122-
void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
123-
Dataset* dataset, const std::string& trainer_desc_str);
123+
std::shared_ptr<TrainerBase> InitForDataset(
124+
const ProgramDesc& main_program, const std::string& trainer_desc_str,
125+
Scope* scope, Dataset* dataset);
126+
void RunFromDataset(std::shared_ptr<TrainerBase> trainer);
124127

125128
private:
126129
const platform::Place place_;

paddle/fluid/framework/multi_trainer.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
6262
}
6363
}
6464

65+
Scope* MultiTrainer::GetWorkerScope(int thread_id) {
66+
return workers_[thread_id]->GetThreadScope();
67+
}
68+
6569
void MultiTrainer::Run() {
6670
VLOG(3) << "Going to run";
6771
for (int thidx = 0; thidx < thread_num_; ++thidx) {

paddle/fluid/framework/pipeline_trainer.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,10 @@ void PipelineTrainer::Finalize() {
261261
root_scope_->DropKids();
262262
}
263263

264+
Scope* PipelineTrainer::GetWorkerScope(int thread_id) {
265+
return pipeline_scopes_[thread_id];
266+
}
267+
264268
} // end namespace framework
265269
} // end namespace paddle
266270
#endif

paddle/fluid/framework/trainer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class TrainerBase {
5050
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
5151
virtual void Run() = 0;
5252
virtual void Finalize() = 0;
53+
virtual Scope* GetWorkerScope(int thread_id) = 0;
5354

5455
protected:
5556
Scope* root_scope_;
@@ -70,6 +71,7 @@ class MultiTrainer : public TrainerBase {
7071
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
7172
virtual void Run();
7273
virtual void Finalize();
74+
virtual Scope* GetWorkerScope(int thread_id);
7375

7476
protected:
7577
int thread_num_;
@@ -92,6 +94,7 @@ class DistMultiTrainer : public MultiTrainer {
9294
virtual void FinalizeDumpEnv();
9395
virtual void InitDumpEnv();
9496
virtual void DumpWork();
97+
virtual Scope* GetWorkerScope(int thread_id) { return root_scope_; }
9598

9699
protected:
97100
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
@@ -117,6 +120,7 @@ class PipelineTrainer : public TrainerBase {
117120
void InitOtherEnv(const ProgramDesc& main_program) override {}
118121
void Run() override;
119122
void Finalize() override;
123+
virtual Scope* GetWorkerScope(int thread_id);
120124

121125
protected:
122126
int section_num_;

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,18 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
5858
const DataLayout data_layout = framework::StringToDataLayout(
5959
ctx->Attrs().Get<std::string>("data_layout"));
6060

61-
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
62-
"Input X must have 2 to 5 dimensions.");
61+
PADDLE_ENFORCE_GE(
62+
x_dims.size(), 2,
63+
"ShapeError: the dimension of input X must greater than or equal to 2."
64+
"But received: the shape of input X = [%s], the dimension of input X ="
65+
"[%d]",
66+
x_dims, x_dims.size());
67+
PADDLE_ENFORCE_LE(
68+
x_dims.size(), 5,
69+
"ShapeError: the dimension of input X must smaller than or equal to 5."
70+
"But received: the shape of input X = [%s], the dimension of input X ="
71+
"[%d]",
72+
x_dims, x_dims.size());
6373

6474
const int64_t C =
6575
(data_layout == DataLayout::kNCHW ? x_dims[1]
@@ -68,8 +78,16 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
6878
auto scale_dim = ctx->GetInputDim("Scale");
6979
auto bias_dim = ctx->GetInputDim("Bias");
7080

71-
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
72-
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
81+
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL,
82+
"ShapeError: the dimension of scale must equal to 1."
83+
"But received: the shape of scale is [%s], the dimension "
84+
"of scale is [%d]",
85+
scale_dim, scale_dim.size());
86+
PADDLE_ENFORCE_EQ(
87+
bias_dim.size(), 1UL,
88+
"ShapeError: the dimension of bias must equal to 1."
89+
"But received: the shape of bias is [%s],the dimension of bias is [%d]",
90+
bias_dim, bias_dim.size());
7391

7492
bool check = true;
7593
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
@@ -78,8 +96,14 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
7896
}
7997

8098
if (check) {
81-
PADDLE_ENFORCE_EQ(scale_dim[0], C);
82-
PADDLE_ENFORCE_EQ(scale_dim[0], C);
99+
PADDLE_ENFORCE_EQ(scale_dim[0], C,
100+
"ShapeError: the shape of scale must equal to [%d]"
101+
"But received: the shape of scale is [%d]",
102+
C, scale_dim[0]);
103+
PADDLE_ENFORCE_EQ(bias_dim[0], C,
104+
"ShapeError: the shape of bias must equal to [%d]"
105+
"But received: the shape of bias is [%d]",
106+
C, bias_dim[0]);
83107
}
84108
ctx->SetOutputDim("Y", x_dims);
85109
ctx->SetOutputDim("MeanOut", {C});

0 commit comments

Comments
 (0)