@@ -563,33 +563,40 @@ def __init__(self,
563
563
self ._bias_attr = bias_attr
564
564
self ._use_global_stats = use_global_stats
565
565
566
- if get_default_dtype () == 'float16' :
567
- set_default_dtype ('float32' )
566
+ self ._dtype = 'float32'
568
567
569
568
param_shape = [num_features ]
570
569
571
570
# create parameter
572
571
if weight_attr == False :
573
572
self .weight = self .create_parameter (
574
- attr = None , shape = param_shape , default_initializer = Constant (1.0 ))
573
+ attr = None ,
574
+ shape = param_shape ,
575
+ dtype = self ._dtype ,
576
+ default_initializer = Constant (1.0 ))
575
577
self .weight .stop_gradient = True
576
578
else :
577
579
self .weight = self .create_parameter (
578
580
attr = self ._weight_attr ,
579
581
shape = param_shape ,
582
+ dtype = self ._dtype ,
580
583
default_initializer = Constant (1.0 ))
581
584
self .weight .stop_gradient = self ._weight_attr != None and self ._weight_attr .learning_rate == 0.
582
585
583
586
if bias_attr == False :
584
587
self .bias = self .create_parameter (
585
588
attr = None ,
586
589
shape = param_shape ,
590
+ dtype = self ._dtype ,
587
591
default_initializer = Constant (0.0 ),
588
592
is_bias = True )
589
593
self .bias .stop_gradient = True
590
594
else :
591
595
self .bias = self .create_parameter (
592
- attr = self ._bias_attr , shape = param_shape , is_bias = True )
596
+ attr = self ._bias_attr ,
597
+ shape = param_shape ,
598
+ dtype = self ._dtype ,
599
+ is_bias = True )
593
600
self .bias .stop_gradient = self ._bias_attr != None and self ._bias_attr .learning_rate == 0.
594
601
595
602
moving_mean_name = None
@@ -600,6 +607,7 @@ def __init__(self,
600
607
moving_variance_name = name + "_variance"
601
608
602
609
self ._mean = self .create_parameter (
610
+ dtype = self ._dtype ,
603
611
attr = ParamAttr (
604
612
name = moving_mean_name ,
605
613
initializer = Constant (0.0 ),
@@ -609,6 +617,7 @@ def __init__(self,
609
617
self ._mean .stop_gradient = True
610
618
611
619
self ._variance = self .create_parameter (
620
+ dtype = self ._dtype ,
612
621
attr = ParamAttr (
613
622
name = moving_variance_name ,
614
623
initializer = Constant (1.0 ),
0 commit comments