import numpy as np

-from paddle import _C_ops, _legacy_C_ops
+from paddle import _legacy_C_ops
from paddle.fluid import core, in_dygraph_mode
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import to_variable
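Note on the import change: with the tensor-level `bitwise_or` calls removed from `_unscale` (see the hunks below), `_C_ops` is no longer referenced anywhere in this module, so the import is trimmed to `_legacy_C_ops` only.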
@@ -228,16 +228,11 @@ def minimize(self, optimizer, *args, **kwargs):

        optimize_ops, params_grads = (None, None)

-        if hasattr(optimizer, "_set_auxiliary_var"):
-            optimizer._set_auxiliary_var('found_inf', self._found_inf)
-            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if self._found_inf:
+            self._cache_founf_inf = True
        else:
-            if self._found_inf:
-                self._cache_founf_inf = True
-            else:
-                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-                self._cache_founf_inf = False
+            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+            self._cache_founf_inf = False

        if self._use_dynamic_loss_scaling:
            # update the scale
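For readers following the control flow: after this hunk, `minimize` consults the scaler's own `_found_inf` flag directly instead of round-tripping it through the optimizer's `_set_auxiliary_var`/`_get_auxiliary_var` hooks. A minimal runnable sketch of the resulting behavior (the dummy classes are hypothetical stand-ins, not part of the PR):

class _DummyOpt:
    # Hypothetical stand-in so the sketch runs without Paddle.
    def minimize(self, *args, **kwargs):
        return [], []

class _DummyScaler:
    _found_inf = False
    _cache_founf_inf = None  # attribute name spelled as in the source

def minimize_flow(scaler, optimizer, *args, **kwargs):
    # Mirrors the simplified branch above: consult the scaler's own
    # found_inf flag; skip the parameter update entirely when it is set.
    optimize_ops, params_grads = None, None
    if scaler._found_inf:
        scaler._cache_founf_inf = True
    else:
        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
        scaler._cache_founf_inf = False
    return optimize_ops, params_grads

print(minimize_flow(_DummyScaler(), _DummyOpt()))  # ([], [])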
@@ -335,9 +330,6 @@ def _unscale(self, optimizer):
                    param_grads_fp16,
                    self._temp_found_inf_fp16,
                )
-                self._found_inf = _C_ops.bitwise_or(
-                    self._found_inf, self._temp_found_inf_fp16
-                )
            if len(param_grads_bf16):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_bf16,
@@ -346,9 +338,6 @@ def _unscale(self, optimizer):
                    param_grads_bf16,
                    self._temp_found_inf_bf16,
                )
-                self._found_inf = _C_ops.bitwise_or(
-                    self._found_inf, self._temp_found_inf_bf16
-                )
            if len(param_grads_fp32):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_fp32,
@@ -357,9 +346,6 @@ def _unscale(self, optimizer):
                    param_grads_fp32,
                    self._temp_found_inf_fp32,
                )
-                self._found_inf = _C_ops.bitwise_or(
-                    self._found_inf, self._temp_found_inf_fp32
-                )
        else:
            if len(param_grads_fp16):
                _legacy_C_ops.check_finite_and_unscale(
@@ -368,29 +354,26 @@ def _unscale(self, optimizer):
                    param_grads_fp16,
                    self._temp_found_inf_fp16,
                )
-                self._found_inf = _C_ops.bitwise_or(
-                    self._found_inf, self._temp_found_inf_fp16
-                )
            if len(param_grads_bf16):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_bf16,
                    self._scale,
                    param_grads_bf16,
                    self._temp_found_inf_bf16,
                )
-                self._found_inf = _C_ops.bitwise_or(
-                    self._found_inf, self._temp_found_inf_bf16
-                )
            if len(param_grads_fp32):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_fp32,
                    self._scale,
                    param_grads_fp32,
                    self._temp_found_inf_fp32,
                )
-                self._found_inf = _C_ops.bitwise_or(
-                    self._found_inf, self._temp_found_inf_fp32
-                )
+
+        self._found_inf = (
+            self._temp_found_inf_fp16
+            or self._temp_found_inf_bf16
+            or self._temp_found_inf_fp32
+        )

        optimizer_state["state"] = OptimizerState.UNSCALED
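Previously the per-dtype flags were folded into `self._found_inf` with three separate `_C_ops.bitwise_or` kernel calls, one after each `check_finite_and_unscale`. The new code aggregates them once at the end with Python's short-circuit `or`, which reads each one-element flag's truth value. A small illustration of why the two agree for these flags (the tensor names are stand-ins for the `_temp_found_inf_*` members):

import paddle

# One-element boolean flags, as check_finite_and_unscale produces.
found_inf_fp16 = paddle.to_tensor([False])
found_inf_bf16 = paddle.to_tensor([True])
found_inf_fp32 = paddle.to_tensor([False])

# Python's `or` converts each one-element tensor to a bool and
# short-circuits, so the combined truth value matches a bitwise OR.
found_inf = found_inf_fp16 or found_inf_bf16 or found_inf_fp32
print(bool(found_inf))  # True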
@@ -778,16 +761,11 @@ def step(self, optimizer):
        if optimizer_state["state"] is OptimizerState.INIT:
            self._unscale(optimizer)

-        if hasattr(optimizer, "_set_auxiliary_var"):
-            optimizer._set_auxiliary_var('found_inf', self._found_inf)
-            optimizer.step()
-            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if self._found_inf:
+            self._cache_founf_inf = True
        else:
-            if self._found_inf:
-                self._cache_founf_inf = True
-            else:
-                optimizer.step()
-                self._cache_founf_inf = False
+            optimizer.step()
+            self._cache_founf_inf = False

        optimizer_state["state"] = OptimizerState.STEPPED
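For context on where `step` sits in the public API: it runs after gradients have been produced from a loss scaled by `scale(...)`, and `update()` then adjusts the loss scale. A typical dynamic-loss-scaling loop looks like the following generic usage sketch (not taken from this PR):

import paddle

model = paddle.nn.Linear(10, 1)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 10])
with paddle.amp.auto_cast():
    loss = model(data).mean()

scaled = scaler.scale(loss)  # multiply the loss by the current scale
scaled.backward()            # gradients come out scaled
scaler.step(optimizer)       # unscale; skip the update if inf/nan was found
scaler.update()              # grow or shrink the scale for the next iteration
optimizer.clear_grad()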