Skip to content

Commit dfed246

Browse files
[Fix] fix recompute to save random and numpy status (PaddlePaddle#72364)
* [Fix] fix recompute to save random and numpy status * fix code style
1 parent 4fce57b commit dfed246

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
import contextlib
1818
import copy
1919
import inspect
20+
import random
2021
import weakref
2122
from typing import TYPE_CHECKING, Any, TypedDict
2223

24+
import numpy as np
25+
2326
import paddle
2427
from paddle import framework
2528
from paddle.autograd import PyLayer
@@ -113,16 +116,23 @@ def check_recompute_necessary(inputs):
113116

114117

115118
@contextlib.contextmanager
def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
    """Temporarily install the given RNG states for the ``with`` block.

    Used by recompute so the backward-pass re-execution of the forward
    function sees exactly the RNG states (Paddle, the per-device RNG state
    tracker, NumPy's global generator, and Python's ``random`` module) that
    the original forward pass saw.

    Args:
        rng_state: Paddle RNG state, as returned by ``paddle.get_rng_state()``.
        tracker: tracker states, as returned by
            ``get_rng_state_tracker().get_states_tracker()``.
        numpy_state: NumPy global RNG state, as returned by
            ``np.random.get_state()``.
        random_state: Python ``random`` module state, as returned by
            ``random.getstate()``.

    Yields:
        None. On exit — normal or via exception — every saved state is
        restored.
    """
    # Snapshot ALL current states before mutating anything, and perform the
    # installs inside try/finally: if installing any one state raises, the
    # states already installed are still rolled back (the original code set
    # the Paddle state before the try block, so a failure while installing
    # the numpy/random state left the Paddle state corrupted).
    orig_rng_state = paddle.get_rng_state()
    orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
    orig_numpy_state = np.random.get_state()
    orig_random_state = random.getstate()
    try:
        paddle.set_rng_state(rng_state)
        get_rng_state_tracker().set_states_tracker(tracker)
        np.random.set_state(numpy_state)
        random.setstate(random_state)
        yield
    finally:
        paddle.set_rng_state(orig_rng_state)
        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
        np.random.set_state(orig_numpy_state)
        random.setstate(orig_random_state)
126136

127137

128138
class RecomputeFunction(PyLayer):
@@ -147,6 +157,8 @@ def forward(
147157
ctx.fwd_rng_state_tracker = (
148158
get_rng_state_tracker().get_states_tracker()
149159
)
160+
ctx.fwd_numpy_state = np.random.get_state()
161+
ctx.fwd_random_state = random.getstate()
150162

151163
# TODO support AMP
152164
tracer = framework._dygraph_tracer()
@@ -252,7 +264,10 @@ def backward(ctx, *args):
252264
# need restore auto_cast state as well as w/b list
253265
if ctx.preserve_rng_state:
254266
with switch_rng_state_tracker(
255-
ctx.fw_rng_state, ctx.fwd_rng_state_tracker
267+
ctx.fw_rng_state,
268+
ctx.fwd_rng_state_tracker,
269+
ctx.fwd_numpy_state,
270+
ctx.fwd_random_state,
256271
):
257272
with paddle.amp.auto_cast(
258273
enable=ctx.is_fw_autocast,
@@ -353,6 +368,9 @@ def _recompute_without_reentrant(
353368
fwd_cuda_rng_state_tracker = (
354369
get_rng_state_tracker().get_states_tracker()
355370
)
371+
fwd_numpy_state = np.random.get_state()
372+
fwd_random_state = random.getstate()
373+
356374
tracer = framework._dygraph_tracer()
357375
is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
358376
if tracer._amp_level == core.AmpLevel.O2:
@@ -423,7 +441,10 @@ def inner_unpack(inner_x):
423441

424442
if preserve_rng_state:
425443
with switch_rng_state_tracker(
426-
fw_cuda_rng_state, fwd_cuda_rng_state_tracker
444+
fw_cuda_rng_state,
445+
fwd_cuda_rng_state_tracker,
446+
fwd_numpy_state,
447+
fwd_random_state,
427448
):
428449
with paddle.set_grad_enabled(True):
429450
with paddle.amp.auto_cast(

python/paddle/distributed/fleet/recompute/recompute_hybrid.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import random
1617
from typing import TYPE_CHECKING, Any, TypedDict
1718

19+
import numpy as np
20+
1821
import paddle
1922
from paddle import framework
2023
from paddle.autograd import PyLayer
@@ -109,6 +112,8 @@ def forward(
109112
# store the rng states
110113
ctx.fwd_rng_state = paddle.get_rng_state()
111114
ctx.fwd_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
115+
ctx.fwd_numpy_state = np.random.get_state()
116+
ctx.fwd_random_state = random.getstate()
112117

113118
# save config info
114119
ctx.mp_group = mp_group
@@ -214,7 +219,10 @@ def backward(ctx, *args):
214219

215220
# need restore auto_cast state as well as w/b list
216221
with switch_rng_state_tracker(
217-
ctx.fwd_rng_state, ctx.fwd_rng_state_tracker
222+
ctx.fwd_rng_state,
223+
ctx.fwd_rng_state_tracker,
224+
ctx.fwd_numpy_state,
225+
ctx.fwd_random_state,
218226
):
219227
if ctx.is_fw_autocast:
220228
with paddle.amp.auto_cast(

0 commit comments

Comments
 (0)