17
17
import contextlib
18
18
import copy
19
19
import inspect
20
+ import random
20
21
import weakref
21
22
from typing import TYPE_CHECKING , Any , TypedDict
22
23
24
+ import numpy as np
25
+
23
26
import paddle
24
27
from paddle import framework
25
28
from paddle .autograd import PyLayer
@@ -113,16 +116,23 @@ def check_recompute_necessary(inputs):
113
116
114
117
115
118
@contextlib.contextmanager
def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
    """Temporarily install the given RNG states, restoring the originals on exit.

    Used by activation recomputation: the backward pass replays the forward
    computation, so every RNG source touched during forward (Paddle's global
    generator, the per-tracker states, NumPy's global generator, and Python's
    ``random`` module) must be swapped in for the replay and then restored.

    Args:
        rng_state: Paddle RNG state, as returned by ``paddle.get_rng_state()``.
        tracker: states dict from ``get_rng_state_tracker().get_states_tracker()``.
        numpy_state: NumPy state tuple from ``np.random.get_state()``.
        random_state: Python state tuple from ``random.getstate()``.

    Yields:
        None. The body of the ``with`` block runs under the supplied states.
    """
    # Snapshot current states before overwriting, so they can be restored.
    orig_rng_state = paddle.get_rng_state()
    orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
    paddle.set_rng_state(rng_state)
    get_rng_state_tracker().set_states_tracker(tracker)

    orig_numpy_state = np.random.get_state()
    orig_random_state = random.getstate()
    np.random.set_state(numpy_state)
    random.setstate(random_state)
    try:
        yield
    finally:
        # Restore every source even if the managed block raised, so the
        # caller's RNG streams are unaffected by the recompute replay.
        paddle.set_rng_state(orig_rng_state)
        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
        np.random.set_state(orig_numpy_state)
        random.setstate(orig_random_state)
126
136
127
137
128
138
class RecomputeFunction (PyLayer ):
@@ -147,6 +157,8 @@ def forward(
147
157
ctx .fwd_rng_state_tracker = (
148
158
get_rng_state_tracker ().get_states_tracker ()
149
159
)
160
+ ctx .fwd_numpy_state = np .random .get_state ()
161
+ ctx .fwd_random_state = random .getstate ()
150
162
151
163
# TODO support AMP
152
164
tracer = framework ._dygraph_tracer ()
@@ -252,7 +264,10 @@ def backward(ctx, *args):
252
264
# need restore auto_cast state as well as w/b list
253
265
if ctx .preserve_rng_state :
254
266
with switch_rng_state_tracker (
255
- ctx .fw_rng_state , ctx .fwd_rng_state_tracker
267
+ ctx .fw_rng_state ,
268
+ ctx .fwd_rng_state_tracker ,
269
+ ctx .fwd_numpy_state ,
270
+ ctx .fwd_random_state ,
256
271
):
257
272
with paddle .amp .auto_cast (
258
273
enable = ctx .is_fw_autocast ,
@@ -353,6 +368,9 @@ def _recompute_without_reentrant(
353
368
fwd_cuda_rng_state_tracker = (
354
369
get_rng_state_tracker ().get_states_tracker ()
355
370
)
371
+ fwd_numpy_state = np .random .get_state ()
372
+ fwd_random_state = random .getstate ()
373
+
356
374
tracer = framework ._dygraph_tracer ()
357
375
is_fw_autocast = False if tracer ._amp_level == core .AmpLevel .O0 else True
358
376
if tracer ._amp_level == core .AmpLevel .O2 :
@@ -423,7 +441,10 @@ def inner_unpack(inner_x):
423
441
424
442
if preserve_rng_state :
425
443
with switch_rng_state_tracker (
426
- fw_cuda_rng_state , fwd_cuda_rng_state_tracker
444
+ fw_cuda_rng_state ,
445
+ fwd_cuda_rng_state_tracker ,
446
+ fwd_numpy_state ,
447
+ fwd_random_state ,
427
448
):
428
449
with paddle .set_grad_enabled (True ):
429
450
with paddle .amp .auto_cast (
0 commit comments