Skip to content

Commit 64573f9

Browse files
fix found_inf bug for custom optimizer (#50158)
1 parent 8031054 commit 64573f9

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

python/paddle/amp/grad_scaler.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,16 @@ def minimize(self, optimizer, *args, **kwargs):
228228

229229
optimize_ops, params_grads = (None, None)
230230

231-
optimizer._set_auxiliary_var('found_inf', self._found_inf)
232-
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
233-
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
231+
if hasattr(optimizer, "_set_auxiliary_var"):
232+
optimizer._set_auxiliary_var('found_inf', self._found_inf)
233+
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
234+
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
235+
else:
236+
if self._found_inf:
237+
self._cache_founf_inf = True
238+
else:
239+
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
240+
self._cache_founf_inf = False
234241

235242
if self._use_dynamic_loss_scaling:
236243
# uopdate the scale
@@ -771,9 +778,16 @@ def step(self, optimizer):
771778
if optimizer_state["state"] is OptimizerState.INIT:
772779
self._unscale(optimizer)
773780

774-
optimizer._set_auxiliary_var('found_inf', self._found_inf)
775-
optimizer.step()
776-
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
781+
if hasattr(optimizer, "_set_auxiliary_var"):
782+
optimizer._set_auxiliary_var('found_inf', self._found_inf)
783+
optimizer.step()
784+
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
785+
else:
786+
if self._found_inf:
787+
self._cache_founf_inf = True
788+
else:
789+
optimizer.step()
790+
self._cache_founf_inf = False
777791

778792
optimizer_state["state"] = OptimizerState.STEPPED
779793

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,16 @@ def minimize(self, optimizer, *args, **kwargs):
4141

4242
optimize_ops, params_grads = (None, None)
4343

44-
optimizer._set_auxiliary_var('found_inf', self._found_inf)
45-
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
46-
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
44+
if hasattr(optimizer, "_set_auxiliary_var"):
45+
optimizer._set_auxiliary_var('found_inf', self._found_inf)
46+
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
47+
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
48+
else:
49+
if self._found_inf:
50+
self._cache_founf_inf = True
51+
else:
52+
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
53+
self._cache_founf_inf = False
4754

4855
if self._use_dynamic_loss_scaling:
4956
self._update()

0 commit comments

Comments
 (0)