@@ -4884,48 +4884,52 @@ def minimize(self, loss, startup_program=None):
4884
4884
inputs = {"X" : fast_var },
4885
4885
outputs = {"Out" : slow_var })
4886
4886
4887
- # Add Var k to main prog and startup prog
4888
- k = layers .create_global_var (
4889
- name = "lookahead_k" ,
4890
- shape = [1 ],
4891
- value = int (self .k ),
4892
- dtype = 'int32' ,
4893
- persistable = True )
4887
+ with framework .program_guard (main_block .program , startup_program ):
4888
+ # Add Var k to main prog and startup prog
4889
+ k = layers .create_global_var (
4890
+ name = "lookahead_k" ,
4891
+ shape = [1 ],
4892
+ value = int (self .k ),
4893
+ dtype = 'int32' ,
4894
+ persistable = True )
4894
4895
4895
- # Add Var alpha to main prog and startup prog
4896
- alpha = layers .create_global_var (
4897
- name = "lookahead_alpha" ,
4898
- shape = [1 ],
4899
- value = float (self .alpha ),
4900
- dtype = 'float32' ,
4901
- persistable = True )
4896
+ # Add Var alpha to main prog and startup prog
4897
+ alpha = layers .create_global_var (
4898
+ name = "lookahead_alpha" ,
4899
+ shape = [1 ],
4900
+ value = float (self .alpha ),
4901
+ dtype = 'float32' ,
4902
+ persistable = True )
4902
4903
4903
- # Add Var step
4904
- step = layers .create_global_var (
4905
- name = "lookahead_step" ,
4906
- shape = [1 ],
4907
- value = int (0 ),
4908
- dtype = 'int32' ,
4909
- persistable = True )
4910
- layers .increment (x = step , value = 1.0 , in_place = True )
4911
-
4912
- # lookahead
4913
- zero_var = layers .fill_constant (shape = [1 ], dtype = 'float32' , value = 0.0 )
4914
-
4915
- one_var = layers .fill_constant (shape = [1 ], dtype = 'float32' , value = 1.0 )
4916
-
4917
- mod = layers .elementwise_mod (step , k )
4918
- with layers .control_flow .Switch () as switch :
4919
- with switch .case (mod == zero_var ):
4920
- for param_name in params :
4921
- fast_var = main_block .var (param_name )
4922
- slow_var = param_to_slow [param_name ]
4923
- tmp_var = layers .elementwise_add (
4924
- layers .elementwise_mul (fast_var , alpha ),
4925
- layers .elementwise_mul (
4926
- slow_var , layers .elementwise_sub (one_var , alpha )))
4927
- layers .assign (input = tmp_var , output = slow_var )
4928
- layers .assign (input = tmp_var , output = fast_var )
4929
- with switch .default ():
4930
- pass
4904
+ # Add Var step
4905
+ step = layers .create_global_var (
4906
+ name = "lookahead_step" ,
4907
+ shape = [1 ],
4908
+ value = int (0 ),
4909
+ dtype = 'int32' ,
4910
+ persistable = True )
4911
+ layers .increment (x = step , value = 1.0 , in_place = True )
4912
+
4913
+ # lookahead
4914
+ zero_var = layers .fill_constant (
4915
+ shape = [1 ], dtype = 'float32' , value = 0.0 )
4916
+
4917
+ one_var = layers .fill_constant (
4918
+ shape = [1 ], dtype = 'float32' , value = 1.0 )
4919
+
4920
+ mod = layers .elementwise_mod (step , k )
4921
+ with layers .control_flow .Switch () as switch :
4922
+ with switch .case (mod == zero_var ):
4923
+ for param_name in params :
4924
+ fast_var = main_block .var (param_name )
4925
+ slow_var = param_to_slow [param_name ]
4926
+ tmp_var = layers .elementwise_add (
4927
+ layers .elementwise_mul (fast_var , alpha ),
4928
+ layers .elementwise_mul (
4929
+ slow_var ,
4930
+ layers .elementwise_sub (one_var , alpha )))
4931
+ layers .assign (input = tmp_var , output = slow_var )
4932
+ layers .assign (input = tmp_var , output = fast_var )
4933
+ with switch .default ():
4934
+ pass
4931
4935
return mini_out
0 commit comments