Skip to content

Commit c36eb75

Browse files
committed
fix BatchNorm for fp16
1 parent ec148ca commit c36eb75

File tree

1 file changed: +13 additions, −4 deletions

python/paddle/nn/layer/norm.py

+13-4
Original file line number · Diff line number · Diff line change
@@ -563,33 +563,40 @@ def __init__(self,
563563
self._bias_attr = bias_attr
564564
self._use_global_stats = use_global_stats
565565

566-
if get_default_dtype() == 'float16':
567-
set_default_dtype('float32')
566+
self._dtype = 'float32'
568567

569568
param_shape = [num_features]
570569

571570
# create parameter
572571
if weight_attr == False:
573572
self.weight = self.create_parameter(
574-
attr=None, shape=param_shape, default_initializer=Constant(1.0))
573+
attr=None,
574+
shape=param_shape,
575+
dtype=self._dtype,
576+
default_initializer=Constant(1.0))
575577
self.weight.stop_gradient = True
576578
else:
577579
self.weight = self.create_parameter(
578580
attr=self._weight_attr,
579581
shape=param_shape,
582+
dtype=self._dtype,
580583
default_initializer=Constant(1.0))
581584
self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
582585

583586
if bias_attr == False:
584587
self.bias = self.create_parameter(
585588
attr=None,
586589
shape=param_shape,
590+
dtype=self._dtype,
587591
default_initializer=Constant(0.0),
588592
is_bias=True)
589593
self.bias.stop_gradient = True
590594
else:
591595
self.bias = self.create_parameter(
592-
attr=self._bias_attr, shape=param_shape, is_bias=True)
596+
attr=self._bias_attr,
597+
shape=param_shape,
598+
dtype=self._dtype,
599+
is_bias=True)
593600
self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.
594601

595602
moving_mean_name = None
@@ -600,6 +607,7 @@ def __init__(self,
600607
moving_variance_name = name + "_variance"
601608

602609
self._mean = self.create_parameter(
610+
dtype=self._dtype,
603611
attr=ParamAttr(
604612
name=moving_mean_name,
605613
initializer=Constant(0.0),
@@ -609,6 +617,7 @@ def __init__(self,
609617
self._mean.stop_gradient = True
610618

611619
self._variance = self.create_parameter(
620+
dtype=self._dtype,
612621
attr=ParamAttr(
613622
name=moving_variance_name,
614623
initializer=Constant(1.0),

0 commit comments

Comments (0)