@@ -228,9 +228,16 @@ def minimize(self, optimizer, *args, **kwargs):
228
228
229
229
optimize_ops , params_grads = (None , None )
230
230
231
- optimizer ._set_auxiliary_var ('found_inf' , self ._found_inf )
232
- optimize_ops , params_grads = optimizer .minimize (* args , ** kwargs )
233
- self ._cache_founf_inf = optimizer ._get_auxiliary_var ('found_inf' )
231
+ if hasattr (optimizer , "_set_auxiliary_var" ):
232
+ optimizer ._set_auxiliary_var ('found_inf' , self ._found_inf )
233
+ optimize_ops , params_grads = optimizer .minimize (* args , ** kwargs )
234
+ self ._cache_founf_inf = optimizer ._get_auxiliary_var ('found_inf' )
235
+ else :
236
+ if self ._found_inf :
237
+ self ._cache_founf_inf = True
238
+ else :
239
+ optimize_ops , params_grads = optimizer .minimize (* args , ** kwargs )
240
+ self ._cache_founf_inf = False
234
241
235
242
if self ._use_dynamic_loss_scaling :
236
243
# update the scale
@@ -771,9 +778,16 @@ def step(self, optimizer):
771
778
if optimizer_state ["state" ] is OptimizerState .INIT :
772
779
self ._unscale (optimizer )
773
780
774
- optimizer ._set_auxiliary_var ('found_inf' , self ._found_inf )
775
- optimizer .step ()
776
- self ._cache_founf_inf = optimizer ._get_auxiliary_var ('found_inf' )
781
+ if hasattr (optimizer , "_set_auxiliary_var" ):
782
+ optimizer ._set_auxiliary_var ('found_inf' , self ._found_inf )
783
+ optimizer .step ()
784
+ self ._cache_founf_inf = optimizer ._get_auxiliary_var ('found_inf' )
785
+ else :
786
+ if self ._found_inf :
787
+ self ._cache_founf_inf = True
788
+ else :
789
+ optimizer .step ()
790
+ self ._cache_founf_inf = False
777
791
778
792
optimizer_state ["state" ] = OptimizerState .STEPPED
779
793
0 commit comments