@@ -21,8 +21,6 @@ namespace paddle {
21
21
22
22
REGISTER_LAYER (mkldnn_batch_norm, MKLDNNBatchNormLayer);
23
23
24
- const real MKLDNNBatchNormLayer::EPS = 1E-5 ;
25
-
26
24
bool MKLDNNBatchNormLayer::init (const LayerMap& layerMap,
27
25
const ParameterMap& parameterMap) {
28
26
if (!MKLDNNLayer::init (layerMap, parameterMap)) {
@@ -50,6 +48,8 @@ bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
50
48
useGlobalStats_ = config_.use_global_stats ();
51
49
}
52
50
movingAvgFraction_ = config_.moving_average_fraction ();
51
+ epsilon_ = config_.epsilon ();
52
+
53
53
VLOG (MKLDNN_BASE) << " --- " << (useGlobalStats_ ? " use" : " do not use" )
54
54
<< " --- global stats" ;
55
55
VLOG (MKLDNN_BASE) << " Moving average fraction: " << movingAvgFraction_;
@@ -210,7 +210,7 @@ void MKLDNNBatchNormLayer::resetFwdPD(
210
210
if (wgt) {
211
211
flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
212
212
}
213
- auto fwdDesc = bn_fwd::desc (pk, in->getMemoryDesc (), EPS , flags_);
213
+ auto fwdDesc = bn_fwd::desc (pk, in->getMemoryDesc (), epsilon_ , flags_);
214
214
pd.reset (new bn_fwd::primitive_desc (fwdDesc, engine_));
215
215
CHECK_PRIMITIVE_DESC_EQ (out, pd->dst_primitive_desc ());
216
216
if (wgt) {
@@ -277,7 +277,7 @@ void MKLDNNBatchNormLayer::resetBwdPD(
277
277
}
278
278
CHECK_PRIMITIVE_DESC_EQ (out, in->getPrimitiveDesc ());
279
279
auto md = in->getMemoryDesc ();
280
- auto bwdDesc = bn_bwd::desc (prop_kind::backward, md, md, EPS , flags_);
280
+ auto bwdDesc = bn_bwd::desc (prop_kind::backward, md, md, epsilon_ , flags_);
281
281
pd.reset (new bn_bwd::primitive_desc (bwdDesc, engine_, *fwdPD_));
282
282
CHECK (pd->weights_primitive_desc () == fwdPD_->weights_primitive_desc ());
283
283
CHECK_PRIMITIVE_DESC_EQ (wgt, pd->diff_weights_primitive_desc ());
0 commit comments