
Commit 6577760

Merge pull request #5692 from peterzhang2029/add_bn_eq
Make epsilon in BatchNormLayer a configurable variable.
2 parents 6ab78ae + 90e05a4 commit 6577760

14 files changed: +41 −24 lines

paddle/gserver/layers/BatchNormBaseLayer.cpp (+1)

@@ -41,6 +41,7 @@ bool BatchNormBaseLayer::init(const LayerMap& layerMap,
     useGlobalStats_ = config_.use_global_stats();
   }
   movingAvgFraction_ = config_.moving_average_fraction();
+  epsilon_ = config_.epsilon();
 
   weight_.reset(new Weight(1, channels_, parameters_[0]));
   movingMean_.reset(new Weight(1, channels_, parameters_[1]));

paddle/gserver/layers/BatchNormBaseLayer.h (+2)

@@ -94,6 +94,8 @@ class BatchNormBaseLayer : public Layer {
   bool useGlobalStats_;
   // use to compute moving mean and variance.
   real movingAvgFraction_;
+  // Epsilon is a small random noise used in batch normalization for stability.
+  real epsilon_;
 };
 
 }  // namespace paddle

paddle/gserver/layers/BatchNormalizationLayer.cpp (+2 −4)

@@ -22,8 +22,6 @@ namespace paddle {
 
 REGISTER_LAYER(batch_norm, BatchNormalizationLayer);
 
-const real BatchNormalizationLayer::EPS = 1E-5;
-
 bool BatchNormalizationLayer::init(const LayerMap& layerMap,
                                    const ParameterMap& parameterMap) {
   /* Initialize the basic parent class */
@@ -53,7 +51,7 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
 
   calMovingMeanAndVar();
 
-  savedInvVar_->subScalar(-EPS);
+  savedInvVar_->subScalar(-epsilon_);
   savedInvVar_->sqrt2(*savedInvVar_);
 }
 
@@ -74,7 +72,7 @@ void BatchNormalizationLayer::setMeanAndStd() {
   savedInvVar_->copyFrom(*(movingVar_->getW()));
   savedInvVar_->downClip(real(0.0));
 
-  savedInvVar_->subScalar(-EPS);
+  savedInvVar_->subScalar(-epsilon_);
   savedInvVar_->sqrt2(*savedInvVar_);
 }
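For readers following the CPU path: savedInvVar_ first holds the per-channel variance, subScalar(-epsilon_) adds the configured epsilon, and sqrt2() takes the square root, so normalization divides by sqrt(var + epsilon). A minimal NumPy sketch of that computation (illustrative names, not the Paddle API):

    import numpy as np

    def batch_norm_forward(x, gamma, beta, epsilon=1e-5):
        # x: (batch, channels); statistics are computed per channel.
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        # Equivalent of subScalar(-epsilon_) followed by sqrt2():
        std = np.sqrt(var + epsilon)
        return gamma * (x - mean) / std + beta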

paddle/gserver/layers/BatchNormalizationLayer.h (−3)

@@ -39,9 +39,6 @@ class BatchNormalizationLayer : public BatchNormBaseLayer {
   void backward(const UpdateCallback& callback = nullptr) override;
 
 protected:
-  /// Epsilon value used in the batch normalization formula.
-  static const real EPS;
-
   /// Load pre-calculated mean and std.
   void setMeanAndStd();

paddle/gserver/layers/CudnnBatchNormLayer.cpp (+10 −6)

@@ -21,8 +21,6 @@ namespace paddle {
 
 REGISTER_LAYER(cudnn_batch_norm, CudnnBatchNormLayer);
 
-const double CudnnBatchNormLayer::EPS = 1E-5;
-
 bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
                                const ParameterMap& parameterMap) {
   /* Initialize the basic parent class */
@@ -61,6 +59,9 @@ void CudnnBatchNormLayer::forward(PassType passType) {
   real* movingMean = movingMean_->getW()->getData();
   real* movingVar = movingVar_->getW()->getData();
 
+  // cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
+  eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
+
   if (!useGlobalStats_) {
     REGISTER_TIMER_INFO("CudnnBatchFwTimer", getName().c_str());
     real* savedMean = savedMean_->getData();
@@ -75,7 +76,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                  1.0 - movingAvgFraction_,
                  movingMean,
                  movingVar,
-                 EPS,
+                 eps_,
                  savedMean,
                  savedInvVar);
   } else {
@@ -90,7 +91,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                    beta,
                    movingMean,
                    movingVar,
-                   EPS);
+                   eps_);
     } else {
       // There is a limitation in cudnn library.
       // When the batch size is larger than 1024 in cuDNN v5.1,
@@ -101,7 +102,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
                    beta,
                    movingMean,
                    movingVar,
-                   EPS,
+                   eps_,
                    batchSize,
                    channels_,
                    imageH_ * imageD_,
@@ -128,6 +129,9 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
   real* savedMean = savedMean_->getData();
   real* savedInvVar = savedInvVar_->getData();
 
+  // cuDNN does not allow an epsilon value less than CUDNN_BN_MIN_EPSILON.
+  eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_));
+
   auto create = [](MatrixPtr& m, size_t h, size_t w, real** p) {
     Matrix::resizeOrCreate(m, h, w, false, true);
     m->zeroMem();
@@ -157,7 +161,7 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
                 gamma,
                 gammaGrad,
                 betaGrad,
-                EPS,
+                eps_,
                 savedMean,
                 savedInvVar);
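Both forward() and backward() clamp the configured value because cuDNN rejects an epsilon below CUDNN_BN_MIN_EPSILON (1e-5 in cudnn.h), and the same epsilon must be used in both passes. A tiny sketch of the guard, with the macro's value assumed from cudnn.h:

    CUDNN_BN_MIN_EPSILON = 1e-5  # assumed value of the cudnn.h macro

    def effective_cudnn_epsilon(epsilon):
        # Mirrors eps_ = std::max(CUDNN_BN_MIN_EPSILON, static_cast<double>(epsilon_)).
        return max(CUDNN_BN_MIN_EPSILON, float(epsilon))

    assert effective_cudnn_epsilon(1e-7) == 1e-5   # clamped up to the cuDNN minimum
    assert effective_cudnn_epsilon(1e-3) == 1e-3   # larger values pass through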

paddle/gserver/layers/CudnnBatchNormLayer.h (+4 −6)

@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <cudnn.h>
 #include "BatchNormBaseLayer.h"
 #include "Layer.h"
 #include "paddle/utils/Stat.h"
@@ -46,12 +47,9 @@ class CudnnBatchNormLayer : public BatchNormBaseLayer {
   void backward(const UpdateCallback& callback = nullptr) override;
 
 protected:
-  /**
-   * Epsilon value used in the batch normalization formula.
-   * Minimum allowed value is CUDNN_BN_MIN_EPSILON defined in cudnn.h.
-   * Same epsilon value should be used in forward and backward functions.
-   */
-  static const double EPS;
+  /// Epsilon value used in the batch normalization formula.
+  /// Same epsilon value should be used in forward and backward functions.
+  double eps_;
 
   /// Input/output tensor descriptor desc
   hl_tensor_descriptor ioDesc_;

paddle/gserver/layers/MKLDNNBatchNormLayer.cpp (+4 −4)

@@ -21,8 +21,6 @@ namespace paddle {
 
 REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
 
-const real MKLDNNBatchNormLayer::EPS = 1E-5;
-
 bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
                                 const ParameterMap& parameterMap) {
   if (!MKLDNNLayer::init(layerMap, parameterMap)) {
@@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
     useGlobalStats_ = config_.use_global_stats();
   }
   movingAvgFraction_ = config_.moving_average_fraction();
+  epsilon_ = config_.epsilon();
+
   VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
                     << " --- global stats";
   VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
@@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD(
   if (wgt) {
     flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
   }
-  auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
+  auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), epsilon_, flags_);
   pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
   CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
   if (wgt) {
@@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD(
   }
   CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc());
   auto md = in->getMemoryDesc();
-  auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
+  auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, epsilon_, flags_);
   pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
   CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
   CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc());

paddle/gserver/layers/MKLDNNBatchNormLayer.h (+2 −1)

@@ -32,7 +32,8 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer {
   std::shared_ptr<bn_fwd::primitive_desc> fwdPD_;
 
   // Epsilon value used in the batch normalization formula.
-  static const real EPS;
+  real epsilon_;
+
   // weight and bias in paddle
   std::unique_ptr<Weight> weight_;
   std::unique_ptr<Weight> biases_;

proto/ModelConfig.proto (+4)

@@ -540,6 +540,10 @@ message LayerConfig {
 
   // for switch order layer
   optional ReshapeConfig reshape_conf = 59;
+
+  // for batch normalization layer
+  // The small constant added to the variance to improve numeric stability.
+  optional double epsilon = 60 [ default = 0.00001 ];
 }
 
 message EvaluatorConfig {

python/paddle/trainer/config_parser.py (+4)

@@ -2412,6 +2412,7 @@ def __init__(self,
                  bias=True,
                  img3D=False,
                  use_global_stats=True,
+                 epsilon=1e-5,
                  moving_average_fraction=0.9,
                  batch_norm_type=None,
                  mean_var_names=None,
@@ -2460,6 +2461,9 @@ def __init__(self,
         self.config.use_global_stats = use_global_stats
         if moving_average_fraction is not None:
             self.config.moving_average_fraction = moving_average_fraction
+        if epsilon is not None:
+            assert epsilon >= 1e-5, "epsilon must be no less than 1e-5."
+            self.config.epsilon = epsilon
 
         input_layer = self.get_input_layer(0)
         image_conf = self.config.inputs[0].image_conf
python/paddle/trainer_config_helpers/layers.py (+5)

@@ -3118,6 +3118,7 @@ def batch_norm_layer(input,
                      param_attr=None,
                      layer_attr=None,
                      batch_norm_type=None,
+                     epsilon=1e-5,
                      moving_average_fraction=0.9,
                      use_global_stats=None,
                      mean_var_names=None):
@@ -3188,6 +3189,8 @@ def batch_norm_layer(input,
                              will use the mean and variance of the current batch
                              of test data.
     :type use_global_stats: bool | None.
+    :param epsilon: The small constant added to the variance to improve numeric stability.
+    :type epsilon: float.
     :param moving_average_fraction: Factor used in the moving average computation.
                                     :math:`runningMean = newMean*(1-factor) + runningMean*factor`
     :type moving_average_fraction: float.
@@ -3205,6 +3208,7 @@ def batch_norm_layer(input,
     assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \
            (batch_norm_type == "mkldnn_batch_norm") or \
            (batch_norm_type == "cudnn_batch_norm")
+
     l = Layer(
         name=name,
         img3D=img3D,
@@ -3214,6 +3218,7 @@ def batch_norm_layer(input,
         type=LayerType.BATCH_NORM_LAYER,
         batch_norm_type=batch_norm_type,
         bias=ParamAttr.to_bias(bias_attr),
+        epsilon=epsilon,
         moving_average_fraction=moving_average_fraction,
         use_global_stats=use_global_stats,
         mean_var_names=mean_var_names,
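With the new keyword wired through, a model config can set epsilon directly on the helper. A minimal usage sketch, assuming the v1 trainer-config-helpers environment; layer names and sizes are illustrative:

    from paddle.trainer_config_helpers import *

    data = data_layer(name='input', size=784)
    fc = fc_layer(input=data, size=256, act=LinearActivation())
    # epsilon defaults to 1e-5 and may not be set below 1e-5.
    bn = batch_norm_layer(input=fc, act=ReluActivation(), epsilon=1e-4)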

python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr (+1)

@@ -65,6 +65,7 @@ layers {
   height: 227
   width: 227
   depth: 1
+  epsilon: 1e-05
 }
 layers {
   name: "__crmnorm_0__"

python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr (+1)

@@ -65,6 +65,7 @@ layers {
   height: 256
   width: 256
   depth: 1
+  epsilon: 1e-05
 }
 layers {
   name: "__crmnorm_0__"

python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr (+1)

@@ -36,6 +36,7 @@ layers {
   height: 6
   width: 20
   depth: 3
+  epsilon: 1e-05
 }
 parameters {
   name: "___batch_norm_0__.w0"
