@@ -959,6 +959,47 @@ def __init__(self,
         super(DGCMomentumOptimizer, self).__init__(
             learning_rate, momentum, use_nesterov, regularization, name)

+    def _is_use_dgc(self, param_var, grad_var):
+        var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
+        if var_numel < 16384 or \
+           param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
+           grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
+           param_var.dtype != core.VarDesc.VarType.FP32:
+            return False
+        return True
+
+    def _append_optimize_op(self, block, param_and_grad):
+        assert isinstance(block, framework.Block)
+
+        if not self._is_use_dgc(param_and_grad[0], param_and_grad[1]):
+            return super(DGCMomentumOptimizer, self)._append_optimize_op(
+                block, param_and_grad)
+
+        velocity_acc = self._get_accumulator(self._velocity_acc_str,
+                                             param_and_grad[0])
+        # create the dgc momentum optimize op
+        dgc_momentum_op = block.append_op(
+            type="dgc_momentum",
+            inputs={
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "Velocity": velocity_acc,
+                "LearningRate": self._create_param_lr(param_and_grad),
+                "current_step": self._global_step_var,
+            },
+            outputs={
+                "ParamOut": param_and_grad[0],
+                "VelocityOut": velocity_acc
+            },
+            attrs={
+                "mu": self._momentum,
+                "use_nesterov": self._use_nesterov,
+                "rampup_begin_step": float(self._rampup_begin_step)
+            },
+            stop_gradient=True)
+
+        return dgc_momentum_op
+
     def _add_auto_increment_var(self, counter_name, begin, step=1):
         helper = LayerHelper('global_step_counter')
         counter, is_new_var = helper.create_or_get_global_variable(
@@ -997,11 +1038,7 @@ def _append_dgc_ops(self, param_and_grads):
             force_cpu=True)

         for param_var, grad_var in param_and_grads:
-            var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
-            if var_numel < 16384 or \
-               param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
-               grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
-               param_var.dtype != core.VarDesc.VarType.FP32:
+            if not self._is_use_dgc(param_var, grad_var):
                 continue

             u_var = tensor.create_global_var(
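
For reference, the gating rule that the new _is_use_dgc helper centralizes can be read in isolation: DGC is applied only to dense FP32 parameters with at least 16384 elements, and everything else falls back to the plain momentum update. The sketch below restates that rule with plain Python values; the function and argument names are illustrative only and are not part of the Paddle API.

from functools import reduce

def uses_dgc(shape, is_selected_rows, is_fp32):
    # Mirrors _is_use_dgc: skip DGC for small, sparse (SELECTED_ROWS), or non-FP32 variables.
    numel = abs(reduce(lambda x, y: x * y, shape))
    if numel < 16384 or is_selected_rows or not is_fp32:
        return False
    return True

assert uses_dgc([1024, 1024], False, True)        # large dense FP32 param -> dgc_momentum op
assert not uses_dgc([128, 64], False, True)       # fewer than 16384 elements -> regular momentum
assert not uses_dgc([1024, 1024], True, True)     # SELECTED_ROWS (sparse) -> regular momentum
assert not uses_dgc([1024, 1024], False, False)   # non-FP32 dtype -> regular momentum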